From bec02c1fa304d05298e3f10442052fcf9f5444a0 Mon Sep 17 00:00:00 2001 From: Ludovic Hussonnois <ludovic.hussonnois@c-s.fr> Date: Wed, 15 Feb 2017 15:11:51 +0100 Subject: [PATCH] REFAC: Refactoring TrainVectorClassifier. Inherit TrainVectorClassifier from TrainVectorBase and use Non-Virtual Function Idiom to provide common behavior for Unsupervised and Supervised classification. --- .../app/otbTrainVectorClassifier.cxx | 676 +++++------------- .../app/otbTrainVectorClustering.cxx | 64 ++ .../include/otbTrainVectorBase.h | 207 ++++++ .../include/otbTrainVectorBase.txx | 305 ++++++++ 4 files changed, 772 insertions(+), 480 deletions(-) create mode 100644 Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx create mode 100644 Modules/Applications/AppClassification/include/otbTrainVectorBase.h create mode 100644 Modules/Applications/AppClassification/include/otbTrainVectorBase.txx diff --git a/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx b/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx index fa5209552e..61609f080e 100644 --- a/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx @@ -14,60 +14,30 @@ PURPOSE. See the above copyright notices for more information. 
=========================================================================*/ -#include "otbWrapperApplication.h" -#include "otbWrapperApplicationFactory.h" - -#include "otbLearningApplicationBase.h" - -#include "otbOGRDataSourceWrapper.h" -#include "otbOGRFeatureWrapper.h" -#include "otbStatisticsXMLFileWriter.h" - -#include "itkVariableLengthVector.h" -#include "otbStatisticsXMLFileReader.h" - -#include "itkListSample.h" -#include "otbShiftScaleSampleListFilter.h" +#include "otbTrainVectorBase.h" // Validation #include "otbConfusionMatrixCalculator.h" -#include <algorithm> -#include <locale> - namespace otb { namespace Wrapper { -/** Utility function to negate std::isalnum */ -bool IsNotAlphaNum(char c) - { - return !std::isalnum(c); - } - -class TrainVectorClassifier : public LearningApplicationBase<float,int> +class TrainVectorClassifier : public TrainVectorBase { public: typedef TrainVectorClassifier Self; - typedef LearningApplicationBase<float, int> Superclass; + typedef TrainVectorBase Superclass; typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; - itkNewMacro(Self) - - itkTypeMacro(Self, Superclass) - - typedef Superclass::SampleType SampleType; - typedef Superclass::ListSampleType ListSampleType; - typedef Superclass::TargetListSampleType TargetListSampleType; - typedef Superclass::SampleImageType SampleImageType; - - typedef double ValueType; - typedef itk::VariableLengthVector<ValueType> MeasurementType; + itkNewMacro( Self ) - typedef otb::StatisticsXMLFileReader<SampleType> StatisticsReader; + itkTypeMacro( Self, Superclass ) - typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType; + typedef Superclass::SampleType SampleType; + typedef Superclass::ListSampleType ListSampleType; + typedef Superclass::TargetListSampleType TargetListSampleType; // Estimate performance on validation sample typedef otb::ConfusionMatrixCalculator<TargetListSampleType, 
TargetListSampleType> ConfusionMatrixCalculatorType; @@ -75,503 +45,249 @@ public: typedef ConfusionMatrixCalculatorType::MapOfIndicesType MapOfIndicesType; typedef ConfusionMatrixCalculatorType::ClassLabelType ClassLabelType; + private: - void DoInit() + void DoTrainInit() { - SetName("TrainVectorClassifier"); - SetDescription("Train a classifier based on labeled geometries and a list of features to consider."); - - SetDocName("Train Vector Classifier"); - SetDocLongDescription("This application trains a classifier based on " - "labeled geometries and a list of features to consider for classification."); - SetDocLimitations(" "); - SetDocAuthors("OTB Team"); - SetDocSeeAlso(" "); - - //Group IO - AddParameter(ParameterType_Group, "io", "Input and output data"); - SetParameterDescription("io", "This group of parameters allows setting input and output data."); - - AddParameter(ParameterType_InputVectorDataList, "io.vd", "Input Vector Data"); - SetParameterDescription("io.vd", "Input geometries used for training (note : all geometries from the layer will be used)"); - - AddParameter(ParameterType_InputFilename, "io.stats", "Input XML image statistics file"); - MandatoryOff("io.stats"); - SetParameterDescription("io.stats", "XML file containing mean and variance of each feature."); - - AddParameter(ParameterType_OutputFilename, "io.confmatout", "Output confusion matrix"); - SetParameterDescription("io.confmatout", "Output file containing the confusion matrix (.csv format)."); - MandatoryOff("io.confmatout"); - - AddParameter(ParameterType_OutputFilename, "io.out", "Output model"); - SetParameterDescription("io.out", "Output file containing the model estimated (.txt format)."); - - AddParameter(ParameterType_ListView, "feat", "Field names for training features."); - SetParameterDescription("feat","List of field names in the input vector data to be used as features for training."); - - AddParameter(ParameterType_ListView,"cfield","Field containing the class id for 
supervision"); - SetParameterDescription("cfield","Field containing the class id for supervision. " - "Only geometries with this field available will be taken into account."); - SetListViewSingleSelectionMode("cfield",true); - - AddParameter(ParameterType_Int, "layer", "Layer Index"); - SetParameterDescription("layer", "Index of the layer to use in the input vector file."); - MandatoryOff("layer"); - SetDefaultParameterInt("layer",0); - - AddParameter(ParameterType_Group, "valid", "Validation data"); - SetParameterDescription("valid", "This group of parameters defines validation data."); - - AddParameter(ParameterType_InputVectorDataList, "valid.vd", "Validation Vector Data"); - SetParameterDescription("valid.vd", "Geometries used for validation " - "(must contain the same fields used for training, all geometries from the layer will be used)"); - MandatoryOff("valid.vd"); - - AddParameter(ParameterType_Int, "valid.layer", "Layer Index"); - SetParameterDescription("valid.layer", "Index of the layer to use in the validation vector file."); - MandatoryOff("valid.layer"); - SetDefaultParameterInt("valid.layer",0); - - // Add parameters for the classifier choice - Superclass::DoInit(); - - AddRANDParameter(); + SetName( "TrainVectorClassifier" ); + SetDescription( "Train a classifier based on labeled geometries and a list of features to consider." ); + + SetDocName( "Train Vector Classifier" ); + SetDocLongDescription( "This application trains a classifier based on " + "labeled geometries and a list of features to consider for classification." ); + SetDocLimitations( " " ); + SetDocAuthors( "OTB Team" ); + SetDocSeeAlso( " " ); + + // Add a new parameter to compute confusion matrix + AddParameter( ParameterType_OutputFilename, "io.confmatout", "Output confusion matrix" ); + SetParameterDescription( "io.confmatout", "Output file containing the confusion matrix (.csv format)." 
); + MandatoryOff( "io.confmatout" ); + // Doc example parameter settings - SetDocExampleParameterValue("io.vd", "vectorData.shp"); - SetDocExampleParameterValue("io.stats", "meanVar.xml"); - SetDocExampleParameterValue("io.out", "svmModel.svm"); - SetDocExampleParameterValue("feat", "perimeter area width"); - SetDocExampleParameterValue("cfield", "predicted"); + SetDocExampleParameterValue( "io.vd", "vectorData.shp" ); + SetDocExampleParameterValue( "io.stats", "meanVar.xml" ); + SetDocExampleParameterValue( "io.out", "svmModel.svm" ); + SetDocExampleParameterValue( "feat", "perimeter area width" ); + SetDocExampleParameterValue( "cfield", "predicted" ); + } - void DoUpdateParameters() + void DoTrainUpdateParameters() { - if ( HasValue("io.vd") ) - { - std::vector<std::string> vectorFileList = GetParameterStringList("io.vd"); - ogr::DataSource::Pointer ogrDS = - ogr::DataSource::New(vectorFileList[0], ogr::DataSource::Modes::Read); - ogr::Layer layer = ogrDS->GetLayer(this->GetParameterInt("layer")); - ogr::Feature feature = layer.ogr().GetNextFeature(); - - ClearChoices("feat"); - ClearChoices("cfield"); - - for(int iField=0; iField<feature.ogr().GetFieldCount(); iField++) - { - std::string key, item = feature.ogr().GetFieldDefnRef(iField)->GetNameRef(); - key = item; - std::string::iterator end = std::remove_if(key.begin(),key.end(),IsNotAlphaNum); - std::transform(key.begin(), end, key.begin(), tolower); - - OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(iField)->GetType(); - - if(fieldType == OFTInteger || ogr::version_proxy::IsOFTInteger64(fieldType) || fieldType == OFTReal) - { - std::string tmpKey="feat."+key.substr(0, end - key.begin()); - AddChoice(tmpKey,item); - } - if(fieldType == OFTString || fieldType == OFTInteger || ogr::version_proxy::IsOFTInteger64(fieldType)) - { - std::string tmpKey="cfield."+key.substr(0, end - key.begin()); - AddChoice(tmpKey,item); - } - } - } + // Nothing to do here + } + + void DoTrainExecute() + { + 
ConfusionMatrixCalculatorType::Pointer confMatCalc = ComputeConfusionmatrix( predictedList, + classificationListSamples.labeledListSample ); + WriteConfusionMatrix( confMatCalc ); } -void LogConfusionMatrix(ConfusionMatrixCalculatorType* confMatCalc) -{ - ConfusionMatrixCalculatorType::ConfusionMatrixType matrix = confMatCalc->GetConfusionMatrix(); + ConfusionMatrixCalculatorType::Pointer + ComputeConfusionmatrix(const TargetListSampleType::Pointer &predictedListSample, + const TargetListSampleType::Pointer &performanceLabeledListSample) + { + ConfusionMatrixCalculatorType::Pointer confMatCalc = ConfusionMatrixCalculatorType::New(); + + otbAppLogINFO( "Predicted list size : " << predictedListSample->Size() ); + otbAppLogINFO( "ValidationLabeledListSample size : " << performanceLabeledListSample->Size() ); + confMatCalc->SetReferenceLabels( performanceLabeledListSample ); + confMatCalc->SetProducedLabels( predictedListSample ); + confMatCalc->Compute(); - // Compute minimal width - size_t minwidth = 0; + otbAppLogINFO( "training performances" ); + LogConfusionMatrix( confMatCalc ); - for (unsigned int i = 0; i < matrix.Rows(); i++) - { - for (unsigned int j = 0; j < matrix.Cols(); j++) + for( unsigned int itClasses = 0; itClasses < confMatCalc->GetNumberOfClasses(); itClasses++ ) { - std::ostringstream os; - os << matrix(i, j); - size_t size = os.str().size(); + ConfusionMatrixCalculatorType::ClassLabelType classLabel = confMatCalc->GetMapOfIndices()[itClasses]; - if (size > minwidth) - { - minwidth = size; - } + otbAppLogINFO( "Precision of class [" << classLabel << "] vs all: " << confMatCalc->GetPrecisions()[itClasses] ); + otbAppLogINFO( "Recall of class [" << classLabel << "] vs all: " << confMatCalc->GetRecalls()[itClasses] ); + otbAppLogINFO( + "F-score of class [" << classLabel << "] vs all: " << confMatCalc->GetFScores()[itClasses] << "\n" ); } - } + otbAppLogINFO( "Global performance, Kappa index: " << confMatCalc->GetKappaIndex() ); + return confMatCalc; 
+ } - MapOfIndicesType mapOfIndices = confMatCalc->GetMapOfIndices(); + /** + * Write the confidence matrix into a file if output is provided. + * \param confMatCalc the input matrix to write. + */ + void WriteConfusionMatrix(const ConfusionMatrixCalculatorType::Pointer &confMatCalc) + { + if( this->HasValue( "io.confmatout" ) ) + { + // Writing the confusion matrix in the output .CSV file - MapOfIndicesType::const_iterator it = mapOfIndices.begin(); - MapOfIndicesType::const_iterator end = mapOfIndices.end(); + MapOfIndicesType::iterator itMapOfIndicesValid, itMapOfIndicesPred; + ClassLabelType labelValid = 0; - for (; it != end; ++it) - { - std::ostringstream os; - os << "[" << it->second << "]"; + ConfusionMatrixType confusionMatrix = confMatCalc->GetConfusionMatrix(); + MapOfIndicesType mapOfIndicesValid = confMatCalc->GetMapOfIndices(); - size_t size = os.str().size(); - if (size > minwidth) - { - minwidth = size; - } - } + unsigned long nbClassesPred = mapOfIndicesValid.size(); - // Generate matrix string, with 'minwidth' as size specifier - std::ostringstream os; + ///////////////////////////////////////////// + // Filling the 2 headers for the output file + const std::string commentValidStr = "#Reference labels (rows):"; + const std::string commentPredStr = "#Produced labels (columns):"; + const char separatorChar = ','; + std::ostringstream ossHeaderValidLabels, ossHeaderPredLabels; - // Header line - for (size_t i = 0; i < minwidth; ++i) - os << " "; - os << " "; - - it = mapOfIndices.begin(); - end = mapOfIndices.end(); - for (; it != end; ++it) - { - os << "[" << it->second << "]" << " "; - } - - os << std::endl; - - // Each line of confusion matrix - for (unsigned int i = 0; i < matrix.Rows(); i++) - { - ConfusionMatrixCalculatorType::ClassLabelType label = mapOfIndices[i]; - os << "[" << std::setw(minwidth - 2) << label << "]" << " "; - for (unsigned int j = 0; j < matrix.Cols(); j++) - { - os << std::setw(minwidth) << matrix(i, j) << " "; - } - os << 
std::endl; - } + // Filling ossHeaderValidLabels and ossHeaderPredLabels for the output file + ossHeaderValidLabels << commentValidStr; + ossHeaderPredLabels << commentPredStr; - otbAppLogINFO("Confusion matrix (rows = reference labels, columns = produced labels):\n" << os.str()); -} + itMapOfIndicesValid = mapOfIndicesValid.begin(); + while( itMapOfIndicesValid != mapOfIndicesValid.end() ) + { + // labels labelValid of mapOfIndicesValid are already sorted in otbConfusionMatrixCalculator + labelValid = itMapOfIndicesValid->second; -void DoExecute() - { - typedef int LabelPixelType; - typedef itk::FixedArray<LabelPixelType,1> LabelSampleType; - typedef itk::Statistics::ListSample <LabelSampleType> LabelListSampleType; - - // Prepare selected field names (their position may change between two inputs) - std::vector<int> selectedIdx = GetSelectedItems("feat"); - std::vector<int> selectedCFieldIdx = GetSelectedItems("cfield"); - - if(selectedIdx.empty()) - { - otbAppLogFATAL(<<"No features have been selected to train the classifier on!"); - } - - if(selectedCFieldIdx.empty()) - { - otbAppLogFATAL(<<"No field has been selected for data labelling!"); - } - - const unsigned int nbFeatures = selectedIdx.size(); - std::vector<std::string> fieldNames = GetChoiceNames("feat"); - std::vector<std::string> cFieldNames = GetChoiceNames("cfield"); - std::vector<std::string> selectedNames(nbFeatures); - for (unsigned int i=0 ; i<nbFeatures ; i++) - { - selectedNames[i] = fieldNames[selectedIdx[i]]; - } - - std::string selectedCFieldName = cFieldNames[selectedCFieldIdx.front()]; - - std::vector<int> featureFieldIndex(nbFeatures, -1); - int cFieldIndex = -1; - - // Statistics for shift/scale - MeasurementType meanMeasurementVector; - MeasurementType stddevMeasurementVector; - if (HasValue("io.stats") && IsParameterEnabled("io.stats")) - { - StatisticsReader::Pointer statisticsReader = StatisticsReader::New(); - std::string XMLfile = GetParameterString("io.stats"); - 
statisticsReader->SetFileName(XMLfile); - meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean"); - stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev"); - } - else - { - meanMeasurementVector.SetSize(nbFeatures); - meanMeasurementVector.Fill(0.); - stddevMeasurementVector.SetSize(nbFeatures); - stddevMeasurementVector.Fill(1.); - } - - ListSampleType::Pointer input = ListSampleType::New(); - LabelListSampleType::Pointer target = LabelListSampleType::New(); - input->SetMeasurementVectorSize(nbFeatures); - - std::vector<std::string> vectorFileList = GetParameterStringList("io.vd"); - for (unsigned int k=0 ; k<vectorFileList.size() ; k++) - { - otbAppLogINFO("Reading input vector file "<<k+1<<"/"<<vectorFileList.size()); - ogr::DataSource::Pointer source = ogr::DataSource::New(vectorFileList[k], ogr::DataSource::Modes::Read); - ogr::Layer layer = source->GetLayer(this->GetParameterInt("layer")); - ogr::Feature feature = layer.ogr().GetNextFeature(); - bool goesOn = feature.addr() != 0; - if (!goesOn) - { - otbAppLogWARNING("The layer "<<GetParameterInt("layer")<<" of " - <<vectorFileList[k]<<" is empty, input is skipped."); - continue; - } + otbAppLogINFO( "mapOfIndicesValid[" << itMapOfIndicesValid->first << "] = " << labelValid ); - // Check all needed fields are present : - // - check class field - cFieldIndex = feature.ogr().GetFieldIndex(selectedCFieldName.c_str()); - if (cFieldIndex < 0) - otbAppLogFATAL("The field name for class label ("<<selectedCFieldName - <<") has not been found in the input vector file "<<vectorFileList[k]); - // - check feature fields - for (unsigned int i=0 ; i<nbFeatures ; i++) - { - featureFieldIndex[i] = feature.ogr().GetFieldIndex(selectedNames[i].c_str()); - if (featureFieldIndex[i] < 0) - otbAppLogFATAL("The field name for feature "<<selectedNames[i] - <<" has not been found in the input vector file "<<vectorFileList[k]); - } + ossHeaderValidLabels << labelValid; + ossHeaderPredLabels 
<< labelValid; - while(goesOn) - { - if(feature.ogr().IsFieldSet(cFieldIndex)) - { - MeasurementType mv; - mv.SetSize(nbFeatures); - for(unsigned int idx=0; idx < nbFeatures; ++idx) - mv[idx] = feature.ogr().GetFieldAsDouble(featureFieldIndex[idx]); + ++itMapOfIndicesValid; - input->PushBack(mv); - target->PushBack(feature.ogr().GetFieldAsInteger(cFieldIndex)); - } - feature = layer.ogr().GetNextFeature(); - goesOn = feature.addr() != 0; - } - } - - ShiftScaleFilterType::Pointer trainingShiftScaleFilter = ShiftScaleFilterType::New(); - trainingShiftScaleFilter->SetInput(input); - trainingShiftScaleFilter->SetShifts(meanMeasurementVector); - trainingShiftScaleFilter->SetScales(stddevMeasurementVector); - trainingShiftScaleFilter->Update(); - - ListSampleType::Pointer trainingListSample= trainingShiftScaleFilter->GetOutput(); - TargetListSampleType::Pointer trainingLabeledListSample = target; - - //-------------------------- - // Estimate model - //-------------------------- - this->Train(trainingListSample,trainingLabeledListSample,GetParameterString("io.out")); - - //-------------------------- - // Performances estimation - //-------------------------- - ListSampleType::Pointer validationListSample=ListSampleType::New(); - TargetListSampleType::Pointer validationLabeledListSample = TargetListSampleType::New(); - - // Import validation data - if (HasValue("valid.vd") && IsParameterEnabled("valid.vd")) - { - input = ListSampleType::New(); - target = LabelListSampleType::New(); - input->SetMeasurementVectorSize(nbFeatures); - - std::vector<std::string> validFileList = this->GetParameterStringList("valid.vd"); - for (unsigned int k=0 ; k<validFileList.size() ; k++) - { - otbAppLogINFO("Reading validation vector file "<<k+1<<"/"<<validFileList.size()); - ogr::DataSource::Pointer source = ogr::DataSource::New(validFileList[k], ogr::DataSource::Modes::Read); - ogr::Layer layer = source->GetLayer(this->GetParameterInt("valid.layer")); - ogr::Feature feature = 
layer.ogr().GetNextFeature(); - bool goesOn = feature.addr() != 0; - if (!goesOn) - { - otbAppLogWARNING("The layer "<<GetParameterInt("valid.layer")<<" of " - <<validFileList[k]<<" is empty, input is skipped."); - continue; + if( itMapOfIndicesValid != mapOfIndicesValid.end() ) + { + ossHeaderValidLabels << separatorChar; + ossHeaderPredLabels << separatorChar; + } + else + { + ossHeaderValidLabels << std::endl; + ossHeaderPredLabels << std::endl; + } } - // Check all needed fields are present : - // - check class field - cFieldIndex = feature.ogr().GetFieldIndex(selectedCFieldName.c_str()); - if (cFieldIndex < 0) - otbAppLogFATAL("The field name for class label ("<<selectedCFieldName - <<") has not been found in the validation vector file "<<validFileList[k]); - // - check feature fields - for (unsigned int i=0 ; i<nbFeatures ; i++) - { - featureFieldIndex[i] = feature.ogr().GetFieldIndex(selectedNames[i].c_str()); - if (featureFieldIndex[i] < 0) - otbAppLogFATAL("The field name for feature "<<selectedNames[i] - <<" has not been found in the validation vector file "<<validFileList[k]); - } + std::ofstream outFile; + outFile.open( this->GetParameterString( "io.confmatout" ).c_str() ); + outFile << std::fixed; + outFile.precision( 10 ); + + ///////////////////////////////////// + // Writing the 2 headers + outFile << ossHeaderValidLabels.str(); + outFile << ossHeaderPredLabels.str(); + ///////////////////////////////////// + + unsigned int indexLabelValid = 0, indexLabelPred = 0; - while(goesOn) + for( itMapOfIndicesValid = mapOfIndicesValid.begin(); + itMapOfIndicesValid != mapOfIndicesValid.end(); ++itMapOfIndicesValid ) { - if(feature.ogr().IsFieldSet(cFieldIndex)) - { - MeasurementType mv; - mv.SetSize(nbFeatures); - for(unsigned int idx=0; idx < nbFeatures; ++idx) - mv[idx] = feature.ogr().GetFieldAsDouble(featureFieldIndex[idx]); + indexLabelPred = 0; - input->PushBack(mv); - target->PushBack(feature.ogr().GetFieldAsInteger(cFieldIndex)); + for( 
itMapOfIndicesPred = mapOfIndicesValid.begin(); + itMapOfIndicesPred != mapOfIndicesValid.end(); ++itMapOfIndicesPred ) + { + // Writing the confusion matrix (sorted in otbConfusionMatrixCalculator) in the output file + outFile << confusionMatrix( indexLabelValid, indexLabelPred ); + if( indexLabelPred < ( nbClassesPred - 1 ) ) + { + outFile << separatorChar; + } + else + { + outFile << std::endl; + } + ++indexLabelPred; } - feature = layer.ogr().GetNextFeature(); - goesOn = feature.addr() != 0; + + ++indexLabelValid; } + + outFile.close(); } + } + + /** + * Display the log of the confusion matrix computed with + * \param confMatCalc the input confusion matrix to display + */ + void LogConfusionMatrix(ConfusionMatrixCalculatorType *confMatCalc) + { + ConfusionMatrixCalculatorType::ConfusionMatrixType matrix = confMatCalc->GetConfusionMatrix(); + + // Compute minimal width + size_t minwidth = 0; - ShiftScaleFilterType::Pointer validShiftScaleFilter = ShiftScaleFilterType::New(); - validShiftScaleFilter->SetInput(input); - validShiftScaleFilter->SetShifts(meanMeasurementVector); - validShiftScaleFilter->SetScales(stddevMeasurementVector); - validShiftScaleFilter->Update(); - - validationListSample = validShiftScaleFilter->GetOutput(); - validationLabeledListSample = target; - } - - //Test the input validation set size - TargetListSampleType::Pointer predictedList = TargetListSampleType::New(); - ListSampleType::Pointer performanceListSample; - TargetListSampleType::Pointer performanceLabeledListSample; - if(validationLabeledListSample->Size() != 0) - { - performanceListSample = validationListSample; - performanceLabeledListSample = validationLabeledListSample; - } - else - { - otbAppLogWARNING("The validation set is empty. 
The performance estimation is done using the input training set in this case."); - performanceListSample = trainingListSample; - performanceLabeledListSample = trainingLabeledListSample; - } - - this->Classify(performanceListSample, predictedList, GetParameterString("io.out")); - - ConfusionMatrixCalculatorType::Pointer confMatCalc = ConfusionMatrixCalculatorType::New(); - - otbAppLogINFO("Predicted list size : " << predictedList->Size()); - otbAppLogINFO("ValidationLabeledListSample size : " << performanceLabeledListSample->Size()); - confMatCalc->SetReferenceLabels(performanceLabeledListSample); - confMatCalc->SetProducedLabels(predictedList); - confMatCalc->Compute(); - - otbAppLogINFO("training performances"); - LogConfusionMatrix(confMatCalc); - - for (unsigned int itClasses = 0; itClasses < confMatCalc->GetNumberOfClasses(); itClasses++) - { - ConfusionMatrixCalculatorType::ClassLabelType classLabel = confMatCalc->GetMapOfIndices()[itClasses]; - - otbAppLogINFO("Precision of class [" << classLabel << "] vs all: " << confMatCalc->GetPrecisions()[itClasses]); - otbAppLogINFO("Recall of class [" << classLabel << "] vs all: " << confMatCalc->GetRecalls()[itClasses]); - otbAppLogINFO( - "F-score of class [" << classLabel << "] vs all: " << confMatCalc->GetFScores()[itClasses] << "\n"); - } - otbAppLogINFO("Global performance, Kappa index: " << confMatCalc->GetKappaIndex()); - - - if (this->HasValue("io.confmatout")) - { - // Writing the confusion matrix in the output .CSV file - - MapOfIndicesType::iterator itMapOfIndicesValid, itMapOfIndicesPred; - ClassLabelType labelValid = 0; - - ConfusionMatrixType confusionMatrix = confMatCalc->GetConfusionMatrix(); - MapOfIndicesType mapOfIndicesValid = confMatCalc->GetMapOfIndices(); - - unsigned int nbClassesPred = mapOfIndicesValid.size(); - - ///////////////////////////////////////////// - // Filling the 2 headers for the output file - const std::string commentValidStr = "#Reference labels (rows):"; - const std::string 
commentPredStr = "#Produced labels (columns):"; - const char separatorChar = ','; - std::ostringstream ossHeaderValidLabels, ossHeaderPredLabels; - - // Filling ossHeaderValidLabels and ossHeaderPredLabels for the output file - ossHeaderValidLabels << commentValidStr; - ossHeaderPredLabels << commentPredStr; - - itMapOfIndicesValid = mapOfIndicesValid.begin(); - - while (itMapOfIndicesValid != mapOfIndicesValid.end()) + for( unsigned int i = 0; i < matrix.Rows(); i++ ) { - // labels labelValid of mapOfIndicesValid are already sorted in otbConfusionMatrixCalculator - labelValid = itMapOfIndicesValid->second; + for( unsigned int j = 0; j < matrix.Cols(); j++ ) + { + std::ostringstream os; + os << matrix( i, j ); + size_t size = os.str().size(); - otbAppLogINFO("mapOfIndicesValid[" << itMapOfIndicesValid->first << "] = " << labelValid); + if( size > minwidth ) + { + minwidth = size; + } + } + } - ossHeaderValidLabels << labelValid; - ossHeaderPredLabels << labelValid; + MapOfIndicesType mapOfIndices = confMatCalc->GetMapOfIndices(); - ++itMapOfIndicesValid; + MapOfIndicesType::const_iterator it = mapOfIndices.begin(); + MapOfIndicesType::const_iterator end = mapOfIndices.end(); - if (itMapOfIndicesValid != mapOfIndicesValid.end()) - { - ossHeaderValidLabels << separatorChar; - ossHeaderPredLabels << separatorChar; - } - else + for( ; it != end; ++it ) + { + std::ostringstream os; + os << "[" << it->second << "]"; + + size_t size = os.str().size(); + if( size > minwidth ) { - ossHeaderValidLabels << std::endl; - ossHeaderPredLabels << std::endl; + minwidth = size; } } - std::ofstream outFile; - outFile.open(this->GetParameterString("io.confmatout").c_str()); - outFile << std::fixed; - outFile.precision(10); - - ///////////////////////////////////// - // Writing the 2 headers - outFile << ossHeaderValidLabels.str(); - outFile << ossHeaderPredLabels.str(); - ///////////////////////////////////// + // Generate matrix string, with 'minwidth' as size specifier + 
std::ostringstream os; - unsigned int indexLabelValid = 0, indexLabelPred = 0; + // Header line + for( size_t i = 0; i < minwidth; ++i ) + os << " "; + os << " "; - for (itMapOfIndicesValid = mapOfIndicesValid.begin(); itMapOfIndicesValid != mapOfIndicesValid.end(); ++itMapOfIndicesValid) + it = mapOfIndices.begin(); + end = mapOfIndices.end(); + for( ; it != end; ++it ) { - indexLabelPred = 0; + os << "[" << it->second << "]" << " "; + } + + os << std::endl; - for (itMapOfIndicesPred = mapOfIndicesValid.begin(); itMapOfIndicesPred != mapOfIndicesValid.end(); ++itMapOfIndicesPred) + // Each line of confusion matrix + for( unsigned int i = 0; i < matrix.Rows(); i++ ) + { + ConfusionMatrixCalculatorType::ClassLabelType label = mapOfIndices[i]; + os << "[" << std::setw( minwidth - 2 ) << label << "]" << " "; + for( unsigned int j = 0; j < matrix.Cols(); j++ ) { - // Writing the confusion matrix (sorted in otbConfusionMatrixCalculator) in the output file - outFile << confusionMatrix(indexLabelValid, indexLabelPred); - if (indexLabelPred < (nbClassesPred - 1)) - { - outFile << separatorChar; - } - else - { - outFile << std::endl; - } - ++indexLabelPred; + os << std::setw( minwidth ) << matrix( i, j ) << " "; } - - ++indexLabelValid; + os << std::endl; } - outFile.close(); - } // END if (this->HasValue("io.confmatout")) + otbAppLogINFO( "Confusion matrix (rows = reference labels, columns = produced labels):\n" << os.str() ); } + }; } } -OTB_APPLICATION_EXPORT(otb::Wrapper::TrainVectorClassifier) +OTB_APPLICATION_EXPORT( otb::Wrapper::TrainVectorClassifier ) diff --git a/Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx b/Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx new file mode 100644 index 0000000000..eec6927bd0 --- /dev/null +++ b/Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx @@ -0,0 +1,64 @@ +/*========================================================================= + Program: ORFEO Toolbox + 
Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + + =========================================================================*/ +#include "otbTrainVectorBase.h" + +// Validation +#include "otbConfusionMatrixCalculator.h" + +namespace otb +{ +namespace Wrapper +{ +/** Clustering training application built on TrainVectorBase; its three DoTrain* hooks are currently empty. */ +class TrainVectorClustering : public TrainVectorBase +{ +public: + typedef TrainVectorClustering Self; + typedef TrainVectorBase Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + itkNewMacro( Self ) + + itkTypeMacro( Self, Superclass ) + + typedef Superclass::SampleType SampleType; + typedef Superclass::ListSampleType ListSampleType; + typedef Superclass::TargetListSampleType TargetListSampleType; + +private: + void DoTrainInit() + { + // Nothing to do here + } + + void DoTrainUpdateParameters() + { + // Nothing to do here + } + + void DoTrainExecute() + { + // Nothing to do here + } + + + +}; +} +} + +OTB_APPLICATION_EXPORT( otb::Wrapper::TrainVectorClustering ) diff --git a/Modules/Applications/AppClassification/include/otbTrainVectorBase.h b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h new file mode 100644 index 0000000000..0bff79f538 --- /dev/null +++ b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h @@ -0,0 +1,207 @@ +/*========================================================================= + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details.
+ + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + + =========================================================================*/ +#ifndef otbTrainVectorBase_h +#define otbTrainVectorBase_h + +#include "otbLearningApplicationBase.h" +#include "otbWrapperApplication.h" +#include "otbWrapperApplicationFactory.h" + +#include "otbOGRDataSourceWrapper.h" +#include "otbOGRFeatureWrapper.h" +#include "otbStatisticsXMLFileWriter.h" + +#include "itkVariableLengthVector.h" +#include "otbStatisticsXMLFileReader.h" + +#include "itkListSample.h" +#include "otbShiftScaleSampleListFilter.h" + +#include <algorithm> +#include <locale> + +namespace otb +{ +namespace Wrapper +{ + +/** Returns true when c is not alphanumeric. Declared inline because it is defined in a header; c is cast to unsigned char because std::isalnum is undefined for negative values. */ +inline bool IsNotAlphaNum(char c) +{ + return !std::isalnum( static_cast<unsigned char>( c ) ); +} + +class TrainVectorBase : public LearningApplicationBase<float, int> +{ +public: + typedef TrainVectorBase Self; + typedef LearningApplicationBase<float, int> Superclass; + typedef itk::SmartPointer <Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + + itkTypeMacro(Self, Superclass) + + typedef Superclass::SampleType SampleType; + typedef Superclass::ListSampleType ListSampleType; + typedef Superclass::TargetListSampleType TargetListSampleType; + + typedef double ValueType; + typedef itk::VariableLengthVector <ValueType> MeasurementType; + + typedef otb::StatisticsXMLFileReader<SampleType> StatisticsReader; + + typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType; + +protected: + + /** Class used to store statistics Measurement (mean/stddev) */ + class StatisticsMeasurement + { + public: + MeasurementType meanMeasurementVector; + MeasurementType stddevMeasurementVector; + }; + + /** Class used to store a list of sample and the corresponding label */ + class ListSamples + { +
public: + ListSampleType::Pointer listSample; + TargetListSampleType::Pointer labeledListSample; + ListSamples() + { + listSample = ListSampleType::New(); + labeledListSample = TargetListSampleType::New(); + } + }; + + /** + * Features information class used to store informations + * about the field and class name/id of an input vector + */ + class FeaturesInfo + { + public: + /** Index for class field */ + std::vector<int> m_SelectedCFieldIdx; + /** Selected Index */ + std::vector<int> m_SelectedIdx; + /** Selected class field name */ + std::string m_SelectedCFieldName; + /** Selected names */ + std::vector <std::string> m_SelectedNames; + unsigned int m_NbFeatures; + + FeaturesInfo(std::vector <std::string> fieldNames, std::vector <std::string> cFieldNames, + std::vector<int> selectedIdx, std::vector<int> selectedCFieldIdx) + : m_SelectedIdx( selectedIdx ), m_SelectedCFieldIdx( selectedCFieldIdx ) + { + m_NbFeatures = static_cast<unsigned int>(selectedIdx.size()); + m_SelectedNames = std::vector<std::string>( m_NbFeatures ); + for( unsigned int i = 0; i < m_NbFeatures; ++i ) + { + m_SelectedNames[i] = fieldNames[selectedIdx[i]]; + } + + m_SelectedCFieldName = cFieldNames[selectedCFieldIdx.front()]; + + } + }; + + +protected: + + /** + * Function which extract and store all samples for Training, Classification and Validation. 
+ * \param measurement statics measurement (mean/stddev) + * \param featuresInfo information about the features + * \return sample list used for training + */ + virtual void ExtractSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo); + + /** + * Extract the training sample list + * \param measurement statics measurement (mean/stddev) + * \param featuresInfo information about the features + * \return sample list used for training + */ + virtual ListSamples ExtractTrainingListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo); + + /** + * Extract the validation sample list + * \param measurement statics measurement (mean/stddev) + * \param featuresInfo information about the features + * \return sample list used for validation + */ + virtual ListSamples ExtractValidationListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo); + + /** + * Extract the sample list classification + * \param measurement statics measurement (mean/stddev) + * \param featuresInfo information about the features + * \return sample list used for classification + */ + virtual ListSamples ExtractClassificationListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo); + + ListSamples trainingListSamples; + ListSamples validationListSamples; + ListSamples classificationListSamples; + TargetListSampleType::Pointer predictedList; + +private: + virtual void DoTrainInit() = 0; + virtual void DoTrainExecute() = 0; + virtual void DoTrainUpdateParameters() = 0; + + void DoInit(); + void DoUpdateParameters(); + void DoExecute(); + + /** Extract samples from input file for corresponding field name + * + * \param parameterName the name of the input file option in the input application parameters + * \param parameterLayer the name of the layer option in the input application parameters + * \param measurement statics measurement (mean/stddev) + * \param nbFeatures the number of features. 
+ * \return the list of samples and their corresponding labels. + */ + ListSamples ExtractListSamples(std::string parameterName, std::string parameterLayer, + const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo); + + + + ListSamples ExtractClassificationListSamples(ListSamples &validationListSamples, ListSamples &trainingListSamples); + + + /** + * Retrieve statistics mean and standard deviation if input statistics are provided. + * Otherwise mean is set to 0 and standard deviation to 1 for each Features. + * \param nbFeatures + */ + StatisticsMeasurement ComputeStatistics(unsigned int nbFeatures); + +}; + +} +} + +#ifndef OTB_MANUAL_INSTANTIATION +#include "otbTrainVectorBase.txx" +#endif + +#endif + diff --git a/Modules/Applications/AppClassification/include/otbTrainVectorBase.txx b/Modules/Applications/AppClassification/include/otbTrainVectorBase.txx new file mode 100644 index 0000000000..c45d51ad78 --- /dev/null +++ b/Modules/Applications/AppClassification/include/otbTrainVectorBase.txx @@ -0,0 +1,305 @@ +/*========================================================================= + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + + =========================================================================*/ +#ifndef otbTrainVectorBase_txx +#define otbTrainVectorBase_txx + +#include "otbTrainVectorBase.h" + +namespace otb +{ +namespace Wrapper +{ + +void TrainVectorBase::DoInit() +{ + SetName( "TrainVectorClassifier" ); + SetDescription( "Train a classifier based on labeled geometries and a list of features to consider." 
); + + SetDocName( "Train Vector Classifier" ); + SetDocLongDescription( "This application trains a classifier based on " + "labeled geometries and a list of features to consider for classification." ); + SetDocLimitations( " " ); + SetDocAuthors( "OTB Team" ); + SetDocSeeAlso( " " ); + + // Common Parameters for all Learning Application + AddParameter( ParameterType_Group, "io", "Input and output data" ); + SetParameterDescription( "io", "This group of parameters allows setting input and output data." ); + + AddParameter( ParameterType_InputVectorDataList, "io.vd", "Input Vector Data" ); + SetParameterDescription( "io.vd", + "Input geometries used for training (note : all geometries from the layer will be used)" ); + + AddParameter( ParameterType_InputFilename, "io.stats", "Input XML image statistics file" ); + MandatoryOff( "io.stats" ); + SetParameterDescription( "io.stats", "XML file containing mean and variance of each feature." ); + + AddParameter( ParameterType_OutputFilename, "io.out", "Output model" ); + SetParameterDescription( "io.out", "Output file containing the model estimated (.txt format)." ); + + AddParameter( ParameterType_Int, "layer", "Layer Index" ); + SetParameterDescription( "layer", "Index of the layer to use in the input vector file." ); + MandatoryOff( "layer" ); + SetDefaultParameterInt( "layer", 0 ); + + //Can be in both Supervised and Unsupervised ? + AddParameter( ParameterType_Group, "valid", "Validation data" ); + SetParameterDescription( "valid", "This group of parameters defines validation data." 
); + + AddParameter( ParameterType_InputVectorDataList, "valid.vd", "Validation Vector Data" ); + SetParameterDescription( "valid.vd", "Geometries used for validation " + "(must contain the same fields used for training, all geometries from the layer will be used)" ); + MandatoryOff( "valid.vd" ); + + AddParameter( ParameterType_Int, "valid.layer", "Layer Index" ); + SetParameterDescription( "valid.layer", "Index of the layer to use in the validation vector file." ); + MandatoryOff( "valid.layer" ); + SetDefaultParameterInt( "valid.layer", 0 ); + + AddParameter(ParameterType_ListView, "feat", "Field names for training features."); + SetParameterDescription("feat","List of field names in the input vector data to be used as features for training."); + + AddParameter(ParameterType_ListView,"cfield","Field containing the class id for supervision"); + SetParameterDescription("cfield","Field containing the class id for supervision. " + "Only geometries with this field available will be taken into account."); + SetListViewSingleSelectionMode("cfield",true); + + // Add parameters for the classifier choice + Superclass::DoInit(); + + AddRANDParameter(); + + DoTrainInit(); +} + +void TrainVectorBase::DoUpdateParameters() +{ + if( HasValue( "io.vd" ) ) + { + std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" ); + ogr::DataSource::Pointer ogrDS = ogr::DataSource::New( vectorFileList[0], ogr::DataSource::Modes::Read ); + ogr::Layer layer = ogrDS->GetLayer( this->GetParameterInt( "layer" ) ); + ogr::Feature feature = layer.ogr().GetNextFeature(); + + ClearChoices( "feat" ); + ClearChoices( "cfield" ); + + for( int iField = 0; iField < feature.ogr().GetFieldCount(); iField++ ) + { + std::string key, item = feature.ogr().GetFieldDefnRef( iField )->GetNameRef(); + key = item; + std::string::iterator end = std::remove_if( key.begin(), key.end(), IsNotAlphaNum ); + std::transform( key.begin(), end, key.begin(), tolower ); + + OGRFieldType fieldType = 
feature.ogr().GetFieldDefnRef( iField )->GetType(); + + if( fieldType == OFTInteger || ogr::version_proxy::IsOFTInteger64( fieldType ) || fieldType == OFTReal ) + { + std::string tmpKey = "feat." + key.substr( 0, end - key.begin() ); + AddChoice( tmpKey, item ); + } + if( fieldType == OFTString || fieldType == OFTInteger || ogr::version_proxy::IsOFTInteger64( fieldType ) ) + { + std::string tmpKey = "cfield." + key.substr( 0, end - key.begin() ); + AddChoice( tmpKey, item ); + } + } + } + + DoTrainUpdateParameters(); +} + +void TrainVectorBase::DoExecute() +{ + typedef int LabelPixelType; + typedef itk::FixedArray<LabelPixelType, 1> LabelSampleType; + typedef itk::Statistics::ListSample<LabelSampleType> LabelListSampleType; + + FeaturesInfo featuresInfo( GetChoiceNames( "feat" ), GetChoiceNames( "cfield" ), GetSelectedItems( "feat" ), + GetSelectedItems( "cfield" ) ); + + // Check input parameters + if( featuresInfo.m_SelectedIdx.empty() ) + { + otbAppLogFATAL( << "No features have been selected to train the classifier on!" ); + } + + // Todo only Log warning and set CFieldName to 0, 1, 2, 3... (default behavior) + if( featuresInfo.m_SelectedCFieldIdx.empty() ) + { + otbAppLogFATAL( << "No field has been selected for data labelling!" 
); + } + + StatisticsMeasurement measurement = ComputeStatistics( featuresInfo.m_NbFeatures ); + ExtractSamples(measurement, featuresInfo); + + this->Train( trainingListSamples.listSample, trainingListSamples.labeledListSample, GetParameterString( "io.out" ) ); + + predictedList = TargetListSampleType::New(); + this->Classify( classificationListSamples.listSample, predictedList, GetParameterString( "io.out" ) ); + + DoTrainExecute(); +} + + +void TrainVectorBase::ExtractSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo) +{ + trainingListSamples = ExtractTrainingListSamples(measurement, featuresInfo); + validationListSamples = ExtractValidationListSamples(measurement, featuresInfo); + classificationListSamples = ExtractClassificationListSamples(measurement, featuresInfo); +} + +TrainVectorBase::ListSamples +TrainVectorBase::ExtractTrainingListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo) +{ + return ExtractListSamples( "io.vd", "layer", measurement, featuresInfo ); +} + +TrainVectorBase::ListSamples +TrainVectorBase::ExtractValidationListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo) +{ + return ExtractListSamples( "valid.vd", "valid.layer", measurement, featuresInfo ); +} + + +TrainVectorBase::ListSamples +TrainVectorBase::ExtractClassificationListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo) +{ + ListSamples performanceSample; + + //Test the input validation set size + if( validationListSamples.labeledListSample->Size() != 0 ) + { + performanceSample.listSample = validationListSamples.listSample; + performanceSample.labeledListSample = validationListSamples.labeledListSample; + } + else + { + otbAppLogWARNING( + "The validation set is empty. The performance estimation is done using the input training set in this case." 
); + performanceSample.listSample = trainingListSamples.listSample; + performanceSample.labeledListSample = trainingListSamples.labeledListSample; + } + + return performanceSample; +} + + +TrainVectorBase::StatisticsMeasurement +TrainVectorBase::ComputeStatistics(unsigned int nbFeatures) +{ + StatisticsMeasurement measurement = StatisticsMeasurement(); + if( HasValue( "io.stats" ) && IsParameterEnabled( "io.stats" ) ) + { + StatisticsReader::Pointer statisticsReader = StatisticsReader::New(); + std::string XMLfile = GetParameterString( "io.stats" ); + statisticsReader->SetFileName( XMLfile.c_str() ); + measurement.meanMeasurementVector = statisticsReader->GetStatisticVectorByName( "mean" ); + measurement.stddevMeasurementVector = statisticsReader->GetStatisticVectorByName( "stddev" ); + } + else + { + measurement.meanMeasurementVector.SetSize( nbFeatures ); + measurement.meanMeasurementVector.Fill( 0. ); + measurement.stddevMeasurementVector.SetSize( nbFeatures ); + measurement.stddevMeasurementVector.Fill( 1. 
); + } + return measurement; +} + + +TrainVectorBase::ListSamples +TrainVectorBase::ExtractListSamples(std::string parameterName, std::string parameterLayer, + const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo) +{ + ListSamples listSamples; + if( HasValue( parameterName ) && IsParameterEnabled( parameterName ) ) + { + ListSampleType::Pointer input = ListSampleType::New(); + TargetListSampleType::Pointer target = TargetListSampleType::New(); + input->SetMeasurementVectorSize( featuresInfo.m_NbFeatures ); + + std::vector<std::string> validFileList = this->GetParameterStringList( parameterName ); + for( unsigned int k = 0; k < validFileList.size(); k++ ) + { + otbAppLogINFO( "Reading validation vector file " << k + 1 << "/" << validFileList.size() ); + ogr::DataSource::Pointer source = ogr::DataSource::New( validFileList[k], ogr::DataSource::Modes::Read ); + ogr::Layer layer = source->GetLayer( static_cast<size_t>(this->GetParameterInt( parameterLayer )) ); + ogr::Feature feature = layer.ogr().GetNextFeature(); + bool goesOn = feature.addr() != 0; + if( !goesOn ) + { + otbAppLogWARNING( "The layer " << GetParameterInt( parameterLayer ) << " of " << validFileList[k] + << " is empty, input is skipped." 
); + continue; + } + + // Check all needed fields are present : + // - check class field + int cFieldIndex = feature.ogr().GetFieldIndex( featuresInfo.m_SelectedCFieldName.c_str() ); + if( cFieldIndex < 0 ) + otbAppLogFATAL( "The field name for class label (" << featuresInfo.m_SelectedCFieldName + << ") has not been found in the vector file " + << validFileList[k] ); + // - check feature fields + std::vector<int> featureFieldIndex( featuresInfo.m_NbFeatures, -1 ); + for( unsigned int i = 0; i < featuresInfo.m_NbFeatures; i++ ) + { + featureFieldIndex[i] = feature.ogr().GetFieldIndex( featuresInfo.m_SelectedNames[i].c_str() ); + if( featureFieldIndex[i] < 0 ) + otbAppLogFATAL( "The field name for feature " << featuresInfo.m_SelectedNames[i] + << " has not been found in the vector file " + << validFileList[k] ); + } + + while( goesOn ) + { + if( feature.ogr().IsFieldSet( cFieldIndex ) ) + { + MeasurementType mv; + mv.SetSize( featuresInfo.m_NbFeatures ); + for( unsigned int idx = 0; idx < featuresInfo.m_NbFeatures; ++idx ) + mv[idx] = feature.ogr().GetFieldAsDouble( featureFieldIndex[idx] ); + + input->PushBack( mv ); + target->PushBack( feature.ogr().GetFieldAsInteger( cFieldIndex ) ); + } + feature = layer.ogr().GetNextFeature(); + goesOn = feature.addr() != 0; + } + } + + ShiftScaleFilterType::Pointer shiftScaleFilter = ShiftScaleFilterType::New(); + shiftScaleFilter->SetInput( input ); + shiftScaleFilter->SetShifts( measurement.meanMeasurementVector ); + shiftScaleFilter->SetScales( measurement.stddevMeasurementVector ); + shiftScaleFilter->Update(); + + listSamples.listSample = shiftScaleFilter->GetOutput(); + listSamples.labeledListSample = target; + } + + return listSamples; +} + + +} +} + +#endif + + -- GitLab