diff --git a/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx b/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx index 61609f080e80677c807a8497d3dc3eb5b6dc6146..9e26e768962bcc06761a0e6a92c49c91353026e0 100644 --- a/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx @@ -45,6 +45,11 @@ public: typedef ConfusionMatrixCalculatorType::MapOfIndicesType MapOfIndicesType; typedef ConfusionMatrixCalculatorType::ClassLabelType ClassLabelType; +protected : + TrainVectorClassifier() : TrainVectorBase() + { + m_ClassifierCategory = Supervised; + } private: void DoTrainInit() @@ -78,7 +83,18 @@ private: // Nothing to do here } - void DoTrainExecute() + void DoBeforeTrainExecute() + { + // Enforce the need of class field name in supervised mode + featuresInfo.SetClassFieldNames( GetChoiceNames( "cfield" ), GetSelectedItems( "cfield" ) ); + + if( featuresInfo.m_SelectedCFieldIdx.empty() && m_ClassifierCategory == Supervised ) + { + otbAppLogFATAL( << "No field has been selected for data labelling!" ); + } + } + + void DoAfterTrainExecute() { ConfusionMatrixCalculatorType::Pointer confMatCalc = ComputeConfusionmatrix( predictedList, classificationListSamples.labeledListSample ); @@ -86,6 +102,28 @@ private: } + ListSamples ExtractClassificationListSamples(const StatisticsMeasurement &measurement) + { + ListSamples performanceSample; + ListSamples validationListSamples = ExtractListSamples( "valid.vd", "valid.layer", measurement ); + //Test the input validation set size + if( validationListSamples.labeledListSample->Size() != 0 ) + { + performanceSample.listSample = validationListSamples.listSample; + performanceSample.labeledListSample = validationListSamples.labeledListSample; + } + else + { + otbAppLogWARNING( + "The validation set is empty. The performance estimation is done using the input training set in this case." 
); + performanceSample.listSample = trainingListSamples.listSample; + performanceSample.labeledListSample = trainingListSamples.labeledListSample; + } + + return performanceSample; + } + + ConfusionMatrixCalculatorType::Pointer ComputeConfusionmatrix(const TargetListSampleType::Pointer &predictedListSample, const TargetListSampleType::Pointer &performanceLabeledListSample) @@ -285,7 +323,6 @@ private: otbAppLogINFO( "Confusion matrix (rows = reference labels, columns = produced labels):\n" << os.str() ); } - }; } } diff --git a/Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx b/Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx index eec6927bd0327e7f3c4bcf1bc68d200397cd5e1f..49acbbc2b3d6e320fcce46145d6a8311bedc2894 100644 --- a/Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx @@ -16,18 +16,15 @@ =========================================================================*/ #include "otbTrainVectorBase.h" -// Validation -#include "otbConfusionMatrixCalculator.h" - namespace otb { namespace Wrapper { -class TrainVectorClassifier : public TrainVectorBase +class TrainVectorClustering : public TrainVectorBase { public: - typedef TrainVectorClassifier Self; + typedef TrainVectorClustering Self; typedef TrainVectorBase Superclass; typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; @@ -39,10 +36,31 @@ public: typedef Superclass::ListSampleType ListSampleType; typedef Superclass::TargetListSampleType TargetListSampleType; +protected : + TrainVectorClustering() : TrainVectorBase() + { + m_ClassifierCategory = Unsupervised; + } + private: void DoTrainInit() { - // Nothing to do here + SetName( "TrainVectorClustering" ); + SetDescription( "Train a classifier based on labeled or unlabeled geometries and a list of features to consider." 
); + + SetDocName( "Train Vector Clustering" ); + SetDocLongDescription( "This application trains a classifier based on " + "labeled or unlabeled geometries and a list of features to consider for classification." ); + SetDocLimitations( " " ); + SetDocAuthors( "OTB Team" ); + SetDocSeeAlso( " " ); + + // Doc example parameter settings + SetDocExampleParameterValue( "io.vd", "vectorData.shp" ); + SetDocExampleParameterValue( "io.stats", "meanVar.xml" ); + SetDocExampleParameterValue( "io.out", "svmModel.svm" ); + SetDocExampleParameterValue( "feat", "perimeter area width" ); + } void DoTrainUpdateParameters() @@ -50,7 +68,12 @@ private: // Nothing to do here } - void DoTrainExecute() + void DoBeforeTrainExecute() + { + // Nothing to do here + } + + void DoAfterTrainExecute() { // Nothing to do here } diff --git a/Modules/Applications/AppClassification/include/otbTrainVectorBase.h b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h index 0bff79f53827944fc27e6dc1a91a2814d98b34a2..e9d3be1eb5275af9b1cbab7be204d314fb4f6db7 100644 --- a/Modules/Applications/AppClassification/include/otbTrainVectorBase.h +++ b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h @@ -48,11 +48,13 @@ bool IsNotAlphaNum(char c) class TrainVectorBase : public LearningApplicationBase<float, int> { public: + /** Standard class typedefs. 
*/ typedef TrainVectorBase Self; typedef LearningApplicationBase<float, int> Superclass; typedef itk::SmartPointer <Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; + /** Standard macro */ itkTypeMacro(Self, Superclass) typedef Superclass::SampleType SampleType; @@ -96,29 +98,31 @@ protected: class FeaturesInfo { public: - /** Index for class field */ - std::vector<int> m_SelectedCFieldIdx; /** Selected Index */ std::vector<int> m_SelectedIdx; + /** Index for class field */ + std::vector<int> m_SelectedCFieldIdx; /** Selected class field name */ std::string m_SelectedCFieldName; /** Selected names */ std::vector <std::string> m_SelectedNames; unsigned int m_NbFeatures; - FeaturesInfo(std::vector <std::string> fieldNames, std::vector <std::string> cFieldNames, - std::vector<int> selectedIdx, std::vector<int> selectedCFieldIdx) - : m_SelectedIdx( selectedIdx ), m_SelectedCFieldIdx( selectedCFieldIdx ) + void SetFieldNames(std::vector <std::string> fieldNames, std::vector<int> selectedIdx) { + m_SelectedIdx = selectedIdx; m_NbFeatures = static_cast<unsigned int>(selectedIdx.size()); m_SelectedNames = std::vector<std::string>( m_NbFeatures ); for( unsigned int i = 0; i < m_NbFeatures; ++i ) { m_SelectedNames[i] = fieldNames[selectedIdx[i]]; } - + } + void SetClassFieldNames(std::vector<std::string> cFieldNames, std::vector<int> selectedCFieldIdx) + { + m_SelectedCFieldIdx = selectedCFieldIdx; + // Handle only one class field name, if several are provided only the first one is used. m_SelectedCFieldName = cFieldNames[selectedCFieldIdx.front()]; - } }; @@ -126,12 +130,11 @@ protected: protected: /** - * Function which extract and store all samples for Training, Classification and Validation. + * Function which extract and store all samples for Training and Classification. 
* \param measurement statics measurement (mean/stddev) * \param featuresInfo information about the features - * \return sample list used for training */ - virtual void ExtractSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo); + virtual void ExtractAllSamples(const StatisticsMeasurement &measurement); /** * Extract the training sample list @@ -139,60 +142,50 @@ protected: * \param featuresInfo information about the features * \return sample list used for training */ - virtual ListSamples ExtractTrainingListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo); + virtual ListSamples ExtractTrainingListSamples(const StatisticsMeasurement &measurement); /** - * Extract the validation sample list - * \param measurement statics measurement (mean/stddev) - * \param featuresInfo information about the features - * \return sample list used for validation - */ - virtual ListSamples ExtractValidationListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo); - - /** - * Extract the sample list classification + * Extract the classification sample list * \param measurement statics measurement (mean/stddev) * \param featuresInfo information about the features * \return sample list used for classification */ - virtual ListSamples ExtractClassificationListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo); + virtual ListSamples ExtractClassificationListSamples(const StatisticsMeasurement &measurement); + + + /** Extract samples from input file for corresponding field name + * + * \param parameterName the name of the input file option in the input application parameters + * \param parameterLayer the name of the layer option in the input application parameters + * \param measurement statistics measurement (mean/stddev) + * \param nbFeatures the number of features. + * \return the list of samples and their corresponding labels.
+ */ + ListSamples + ExtractListSamples(std::string parameterName, std::string parameterLayer, const StatisticsMeasurement &measurement); + + + /** + * Retrieve statistics mean and standard deviation if input statistics are provided. + * Otherwise mean is set to 0 and standard deviation to 1 for each Features. + * \param nbFeatures + */ + StatisticsMeasurement ComputeStatistics(unsigned int nbFeatures); ListSamples trainingListSamples; - ListSamples validationListSamples; ListSamples classificationListSamples; TargetListSampleType::Pointer predictedList; + FeaturesInfo featuresInfo; private: virtual void DoTrainInit() = 0; - virtual void DoTrainExecute() = 0; + virtual void DoBeforeTrainExecute() = 0; + virtual void DoAfterTrainExecute() = 0; virtual void DoTrainUpdateParameters() = 0; - void DoInit(); - void DoUpdateParameters(); - void DoExecute(); - - /** Extract samples from input file for corresponding field name - * - * \param parameterName the name of the input file option in the input application parameters - * \param parameterLayer the name of the layer option in the input application parameters - * \param measurement statics measurement (mean/stddev) - * \param nbFeatures the number of features. - * \return the list of samples and their corresponding labels. - */ - ListSamples ExtractListSamples(std::string parameterName, std::string parameterLayer, - const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo); - - - - ListSamples ExtractClassificationListSamples(ListSamples &validationListSamples, ListSamples &trainingListSamples); - - - /** - * Retrieve statistics mean and standard deviation if input statistics are provided. - * Otherwise mean is set to 0 and standard deviation to 1 for each Features. 
- * \param nbFeatures - */ - StatisticsMeasurement ComputeStatistics(unsigned int nbFeatures); + void DoInit() ITK_OVERRIDE; + void DoUpdateParameters() ITK_OVERRIDE; + void DoExecute() ITK_OVERRIDE; }; diff --git a/Modules/Applications/AppClassification/include/otbTrainVectorBase.txx b/Modules/Applications/AppClassification/include/otbTrainVectorBase.txx index c45d51ad78666bc6e5f48909646ddf5795bc4211..97706bf56c20f4971670ad80b7d0537869f8cd3c 100644 --- a/Modules/Applications/AppClassification/include/otbTrainVectorBase.txx +++ b/Modules/Applications/AppClassification/include/otbTrainVectorBase.txx @@ -56,7 +56,10 @@ void TrainVectorBase::DoInit() MandatoryOff( "layer" ); SetDefaultParameterInt( "layer", 0 ); - //Can be in both Supervised and Unsupervised ? + AddParameter(ParameterType_ListView, "feat", "Field names for training features."); + SetParameterDescription("feat","List of field names in the input vector data to be used as features for training."); + + // Add validation data used to compute confusion matrix or contingency table AddParameter( ParameterType_Group, "valid", "Validation data" ); SetParameterDescription( "valid", "This group of parameters defines validation data." ); @@ -70,14 +73,13 @@ void TrainVectorBase::DoInit() MandatoryOff( "valid.layer" ); SetDefaultParameterInt( "valid.layer", 0 ); - AddParameter(ParameterType_ListView, "feat", "Field names for training features."); - SetParameterDescription("feat","List of field names in the input vector data to be used as features for training."); - + // Add class field if we used validation AddParameter(ParameterType_ListView,"cfield","Field containing the class id for supervision"); SetParameterDescription("cfield","Field containing the class id for supervision. 
" "Only geometries with this field available will be taken into account."); SetListViewSingleSelectionMode("cfield",true); + // Add parameters for the classifier choice Superclass::DoInit(); @@ -92,7 +94,7 @@ void TrainVectorBase::DoUpdateParameters() { std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" ); ogr::DataSource::Pointer ogrDS = ogr::DataSource::New( vectorFileList[0], ogr::DataSource::Modes::Read ); - ogr::Layer layer = ogrDS->GetLayer( this->GetParameterInt( "layer" ) ); + ogr::Layer layer = ogrDS->GetLayer( static_cast<size_t>( this->GetParameterInt( "layer" ) ) ); ogr::Feature feature = layer.ogr().GetNextFeature(); ClearChoices( "feat" ); @@ -109,12 +111,12 @@ void TrainVectorBase::DoUpdateParameters() if( fieldType == OFTInteger || ogr::version_proxy::IsOFTInteger64( fieldType ) || fieldType == OFTReal ) { - std::string tmpKey = "feat." + key.substr( 0, end - key.begin() ); + std::string tmpKey = "feat." + key.substr( 0, static_cast<unsigned long>( end - key.begin() ) ); AddChoice( tmpKey, item ); } if( fieldType == OFTString || fieldType == OFTInteger || ogr::version_proxy::IsOFTInteger64( fieldType ) ) { - std::string tmpKey = "cfield." + key.substr( 0, end - key.begin() ); + std::string tmpKey = "cfield." 
+ key.substr( 0, static_cast<unsigned long>( end - key.begin() ) ); AddChoice( tmpKey, item ); } } @@ -125,12 +127,9 @@ void TrainVectorBase::DoUpdateParameters() void TrainVectorBase::DoExecute() { - typedef int LabelPixelType; - typedef itk::FixedArray<LabelPixelType, 1> LabelSampleType; - typedef itk::Statistics::ListSample<LabelSampleType> LabelListSampleType; + DoBeforeTrainExecute(); - FeaturesInfo featuresInfo( GetChoiceNames( "feat" ), GetChoiceNames( "cfield" ), GetSelectedItems( "feat" ), - GetSelectedItems( "cfield" ) ); + featuresInfo.SetFieldNames( GetChoiceNames( "feat" ), GetSelectedItems( "feat" )); // Check input parameters if( featuresInfo.m_SelectedIdx.empty() ) @@ -138,64 +137,34 @@ void TrainVectorBase::DoExecute() otbAppLogFATAL( << "No features have been selected to train the classifier on!" ); } - // Todo only Log warning and set CFieldName to 0, 1, 2, 3... (default behavior) - if( featuresInfo.m_SelectedCFieldIdx.empty() ) - { - otbAppLogFATAL( << "No field has been selected for data labelling!" 
); - } - StatisticsMeasurement measurement = ComputeStatistics( featuresInfo.m_NbFeatures ); - ExtractSamples(measurement, featuresInfo); + ExtractAllSamples( measurement ); this->Train( trainingListSamples.listSample, trainingListSamples.labeledListSample, GetParameterString( "io.out" ) ); predictedList = TargetListSampleType::New(); this->Classify( classificationListSamples.listSample, predictedList, GetParameterString( "io.out" ) ); - DoTrainExecute(); + DoAfterTrainExecute(); } -void TrainVectorBase::ExtractSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo) +void TrainVectorBase::ExtractAllSamples(const StatisticsMeasurement &measurement) { - trainingListSamples = ExtractTrainingListSamples(measurement, featuresInfo); - validationListSamples = ExtractValidationListSamples(measurement, featuresInfo); - classificationListSamples = ExtractClassificationListSamples(measurement, featuresInfo); + trainingListSamples = ExtractTrainingListSamples(measurement); + classificationListSamples = ExtractClassificationListSamples(measurement); } TrainVectorBase::ListSamples -TrainVectorBase::ExtractTrainingListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo) +TrainVectorBase::ExtractTrainingListSamples(const StatisticsMeasurement &measurement) { - return ExtractListSamples( "io.vd", "layer", measurement, featuresInfo ); + return ExtractListSamples( "io.vd", "layer", measurement); } TrainVectorBase::ListSamples -TrainVectorBase::ExtractValidationListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo) -{ - return ExtractListSamples( "valid.vd", "valid.layer", measurement, featuresInfo ); -} - - -TrainVectorBase::ListSamples -TrainVectorBase::ExtractClassificationListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo) +TrainVectorBase::ExtractClassificationListSamples(const StatisticsMeasurement &itkNotUsed(measurement)) { - ListSamples performanceSample; - - //Test the 
input validation set size - if( validationListSamples.labeledListSample->Size() != 0 ) - { - performanceSample.listSample = validationListSamples.listSample; - performanceSample.labeledListSample = validationListSamples.labeledListSample; - } - else - { - otbAppLogWARNING( - "The validation set is empty. The performance estimation is done using the input training set in this case." ); - performanceSample.listSample = trainingListSamples.listSample; - performanceSample.labeledListSample = trainingListSamples.labeledListSample; - } - - return performanceSample; + return trainingListSamples; } @@ -224,7 +193,7 @@ TrainVectorBase::ComputeStatistics(unsigned int nbFeatures) TrainVectorBase::ListSamples TrainVectorBase::ExtractListSamples(std::string parameterName, std::string parameterLayer, - const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo) + const StatisticsMeasurement &measurement) { ListSamples listSamples; if( HasValue( parameterName ) && IsParameterEnabled( parameterName ) ) @@ -249,12 +218,15 @@ TrainVectorBase::ExtractListSamples(std::string parameterName, std::string param } // Check all needed fields are present : - // - check class field + // - check class field if we use supervised classification or if class field name is not empty int cFieldIndex = feature.ogr().GetFieldIndex( featuresInfo.m_SelectedCFieldName.c_str() ); - if( cFieldIndex < 0 ) + if( cFieldIndex < 0 && !featuresInfo.m_SelectedCFieldName.empty()) + { otbAppLogFATAL( "The field name for class label (" << featuresInfo.m_SelectedCFieldName << ") has not been found in the vector file " << validFileList[k] ); + } + // - check feature fields std::vector<int> featureFieldIndex( featuresInfo.m_NbFeatures, -1 ); for( unsigned int i = 0; i < featuresInfo.m_NbFeatures; i++ ) @@ -266,18 +238,22 @@ TrainVectorBase::ExtractListSamples(std::string parameterName, std::string param << validFileList[k] ); } + while( goesOn ) { - if( feature.ogr().IsFieldSet( cFieldIndex ) ) - { - 
MeasurementType mv; - mv.SetSize( featuresInfo.m_NbFeatures ); - for( unsigned int idx = 0; idx < featuresInfo.m_NbFeatures; ++idx ) - mv[idx] = feature.ogr().GetFieldAsDouble( featureFieldIndex[idx] ); + // Retrieve all the features for each field in the ogr layer. + MeasurementType mv; + mv.SetSize( featuresInfo.m_NbFeatures ); + for( unsigned int idx = 0; idx < featuresInfo.m_NbFeatures; ++idx ) + mv[idx] = feature.ogr().GetFieldAsDouble( featureFieldIndex[idx] ); - input->PushBack( mv ); + input->PushBack( mv ); + + if( feature.ogr().IsFieldSet( cFieldIndex ) ) target->PushBack( feature.ogr().GetFieldAsInteger( cFieldIndex ) ); - } + else + target->PushBack( 0 ); + feature = layer.ogr().GetNextFeature(); goesOn = feature.addr() != 0; }