From 9b617cfa5426d6543b31033c1967847ab7480906 Mon Sep 17 00:00:00 2001 From: Ludovic Hussonnois <ludovic.hussonnois@c-s.fr> Date: Wed, 5 Apr 2017 17:01:05 +0200 Subject: [PATCH] REFAC: Apply RFC 85 review --- .../app/otbTrainVectorClassifier.cxx | 40 +++---- .../include/otbTrainImagesBase.h | 2 +- .../include/otbTrainSharkKMeans.txx | 10 +- .../include/otbTrainVectorBase.h | 33 +++--- .../include/otbTrainVectorBase.txx | 108 +++++++++--------- 5 files changed, 89 insertions(+), 104 deletions(-) diff --git a/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx b/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx index c9926bd37f..935de9401c 100644 --- a/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx @@ -49,44 +49,42 @@ public: typedef ConfusionMatrixCalculatorType::MapOfIndicesType MapOfIndicesType; typedef ConfusionMatrixCalculatorType::ClassLabelType ClassLabelType; -private: - void DoTrainInit() +protected: + void DoInit() { - // Nothing to do here + TrainVectorBase::DoInit(); } - void DoTrainUpdateParameters() + void DoUpdateParameters() { - // Nothing to do here + TrainVectorBase::DoUpdateParameters(); } - void DoBeforeTrainExecute() + void DoExecute() { // Enforce the need of class field name in supervised mode if (GetClassifierCategory() == Supervised) { - featuresInfo.SetClassFieldNames( GetChoiceNames( "cfield" ), GetSelectedItems( "cfield" ) ); + m_featuresInfo.SetClassFieldNames( GetChoiceNames( "cfield" ), GetSelectedItems( "cfield" ) ); - if( featuresInfo.m_SelectedCFieldIdx.empty() ) + if( m_featuresInfo.m_SelectedCFieldIdx.empty() ) { otbAppLogFATAL( << "No field has been selected for data labelling!" ); } } - } - void DoAfterTrainExecute() - { + TrainVectorBase::DoExecute(); - if (GetClassifierCategory() == Supervised) - { - ConfusionMatrixCalculatorType::Pointer confMatCalc = ComputeConfusionMatrix( predictedList, - classificationListSamples.labeledListSample ); - WriteConfusionMatrix( confMatCalc ); - } - else - { - // TODO Compute Contingency Table - } + if (GetClassifierCategory() == Supervised) + { + ConfusionMatrixCalculatorType::Pointer confMatCalc = ComputeConfusionMatrix( m_predictedList, + m_classificationSamplesWithLabel.labeledListSample ); + WriteConfusionMatrix( confMatCalc ); + } + else + { + // TODO Compute Contingency Table + } } diff --git a/Modules/Applications/AppClassification/include/otbTrainImagesBase.h b/Modules/Applications/AppClassification/include/otbTrainImagesBase.h index af17e02e5d..a01eb6c430 100644 --- a/Modules/Applications/AppClassification/include/otbTrainImagesBase.h +++ b/Modules/Applications/AppClassification/include/otbTrainImagesBase.h @@ -35,7 +35,7 @@ namespace Wrapper { /** \class TrainImagesBase - * \brief Base class for the TrainImagesBaseClassifier and Clustering + * \brief Base class for the TrainImagesClassifier * * This class intends to hold common input/output parameters and * composite application connection for both supervised and unsupervised diff --git a/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.txx b/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.txx index 0dba5c6ab7..e3cf63af98 100644 --- a/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.txx +++ b/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.txx @@ -35,11 +35,11 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams() "See complete documentation here " "\\url{http://image.diku.dk/shark/sphinx_pages/build/html/rest_sources/tutorials/algorithms/kmeans.html}.\n " ); //MaxNumberOfIterations - AddParameter( ParameterType_Int, "classifier.sharkkm.nbmaxiter", + AddParameter( ParameterType_Int, "classifier.sharkkm.maxiter", "Maximum number of iteration for the kmeans algorithm." ); - SetParameterInt( "classifier.sharkkm.nbmaxiter", 10 ); - SetMinimumParameterIntValue( "classifier.sharkkm.nbmaxiter", 0 ); - SetParameterDescription( "classifier.sharkkm.nbmaxiter", + SetParameterInt( "classifier.sharkkm.maxiter", 10 ); + SetMinimumParameterIntValue( "classifier.sharkkm.maxiter", 0 ); + SetParameterDescription( "classifier.sharkkm.maxiter", "The maximum number of iteration for the kmeans algorithm. 0=unlimited" ); //MaxNumberOfIterations @@ -55,7 +55,7 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans( typename ListSampleType::Pointer trainingListSample, typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath) { - unsigned int nbMaxIter = static_cast<unsigned int>(abs( GetParameterInt( "classifier.sharkkm.nbmaxiter" ) )); + unsigned int nbMaxIter = static_cast<unsigned int>(abs( GetParameterInt( "classifier.sharkkm.maxiter" ) )); unsigned int k = static_cast<unsigned int>(abs( GetParameterInt( "classifier.sharkkm.k" ) )); typename SharkKMeansType::Pointer classifier = SharkKMeansType::New(); diff --git a/Modules/Applications/AppClassification/include/otbTrainVectorBase.h b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h index bdd844ca1d..f01a99737b 100644 --- a/Modules/Applications/AppClassification/include/otbTrainVectorBase.h +++ b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h @@ -74,7 +74,7 @@ public: protected: /** Class used to store statistics Measurment (mean/stddev) */ - class StatisticsMeasurement + class ShiftScaleParameters { public: MeasurementType meanMeasurementVector; @@ -82,12 +82,12 @@ protected: }; /** Class used to store a list of sample and the corresponding label */ - class ListSamples + class SamplesWithLabel { public: ListSampleType::Pointer listSample; TargetListSampleType::Pointer labeledListSample; - ListSamples() + SamplesWithLabel() { listSample = ListSampleType::New(); labeledListSample = TargetListSampleType::New(); @@ -137,7 +137,7 @@ protected: * \param measurement statics measurement (mean/stddev) * \param featuresInfo information about the features */ - virtual void ExtractAllSamples(const StatisticsMeasurement &measurement); + virtual void ExtractAllSamples(const ShiftScaleParameters &measurement); /** * Extract the training sample list @@ -145,7 +145,7 @@ protected: * \param featuresInfo information about the features * \return sample list used for training */ - virtual ListSamples ExtractTrainingListSamples(const StatisticsMeasurement &measurement); + virtual SamplesWithLabel ExtractTrainingSamplesWithLabel(const ShiftScaleParameters &measurement); /** * Extract classification the sample list @@ -153,7 +153,7 @@ protected: * \param featuresInfo information about the features * \return sample list used for classification */ - virtual ListSamples ExtractClassificationListSamples(const StatisticsMeasurement &measurement); + virtual SamplesWithLabel ExtractClassificationSamplesWithLabel(const ShiftScaleParameters &measurement); /** Extract samples from input file for corresponding field name @@ -164,8 +164,8 @@ protected: * \param nbFeatures the number of features. * \return the list of samples and their corresponding labels. */ - ListSamples - ExtractListSamples(std::string parameterName, std::string parameterLayer, const StatisticsMeasurement &measurement); + SamplesWithLabel + ExtractSamplesWithLabel(std::string parameterName, std::string parameterLayer, const ShiftScaleParameters &measurement); /** @@ -173,18 +173,12 @@ protected: * Otherwise mean is set to 0 and standard deviation to 1 for each Features. * \param nbFeatures */ - StatisticsMeasurement ComputeStatistics(unsigned int nbFeatures); + ShiftScaleParameters ComputeStatistics(unsigned int nbFeatures); - ListSamples trainingListSamples; - ListSamples classificationListSamples; - TargetListSampleType::Pointer predictedList; - FeaturesInfo featuresInfo; - -private: - virtual void DoTrainInit() = 0; - virtual void DoBeforeTrainExecute() = 0; - virtual void DoAfterTrainExecute() = 0; - virtual void DoTrainUpdateParameters() = 0; + SamplesWithLabel m_trainingSamplesWithLabel; + SamplesWithLabel m_classificationSamplesWithLabel; + TargetListSampleType::Pointer m_predictedList; + FeaturesInfo m_featuresInfo; void DoInit() ITK_OVERRIDE; void DoUpdateParameters() ITK_OVERRIDE; @@ -200,4 +194,3 @@ private: #endif #endif - diff --git a/Modules/Applications/AppClassification/include/otbTrainVectorBase.txx b/Modules/Applications/AppClassification/include/otbTrainVectorBase.txx index 82d4f7028a..52263be4aa 100644 --- a/Modules/Applications/AppClassification/include/otbTrainVectorBase.txx +++ b/Modules/Applications/AppClassification/include/otbTrainVectorBase.txx @@ -102,7 +102,7 @@ void TrainVectorBase::DoInit() AddRANDParameter(); - DoTrainInit(); + DoInit(); } void TrainVectorBase::DoUpdateParameters() @@ -142,79 +142,75 @@ void TrainVectorBase::DoUpdateParameters() } } - DoTrainUpdateParameters(); + DoUpdateParameters(); } void TrainVectorBase::DoExecute() { - DoBeforeTrainExecute(); - - featuresInfo.SetFieldNames( GetChoiceNames( "feat" ), GetSelectedItems( "feat" )); + m_featuresInfo.SetFieldNames( GetChoiceNames( "feat" ), GetSelectedItems( "feat" )); // Check input parameters - if( featuresInfo.m_SelectedIdx.empty() ) + if( m_featuresInfo.m_SelectedIdx.empty() ) { otbAppLogFATAL( << "No features have been selected to train the classifier on!" ); } - StatisticsMeasurement measurement = ComputeStatistics( featuresInfo.m_NbFeatures ); + ShiftScaleParameters measurement = ComputeStatistics( m_featuresInfo.m_NbFeatures ); ExtractAllSamples( measurement ); - this->Train( trainingListSamples.listSample, trainingListSamples.labeledListSample, GetParameterString( "io.out" ) ); - - predictedList = TargetListSampleType::New(); - this->Classify( classificationListSamples.listSample, predictedList, GetParameterString( "io.out" ) ); + this->Train( m_trainingSamplesWithLabel.listSample, m_trainingSamplesWithLabel.labeledListSample, GetParameterString( "io.out" ) ); - DoAfterTrainExecute(); + m_predictedList = TargetListSampleType::New(); + this->Classify( m_classificationSamplesWithLabel.listSample, m_predictedList, GetParameterString( "io.out" ) ); } -void TrainVectorBase::ExtractAllSamples(const StatisticsMeasurement &measurement) +void TrainVectorBase::ExtractAllSamples(const ShiftScaleParameters &measurement) { - trainingListSamples = ExtractTrainingListSamples(measurement); - classificationListSamples = ExtractClassificationListSamples(measurement); + m_trainingSamplesWithLabel = ExtractTrainingSamplesWithLabel(measurement); + m_classificationSamplesWithLabel = ExtractClassificationSamplesWithLabel(measurement); } -TrainVectorBase::ListSamples -TrainVectorBase::ExtractTrainingListSamples(const StatisticsMeasurement &measurement) +TrainVectorBase::SamplesWithLabel +TrainVectorBase::ExtractTrainingSamplesWithLabel(const ShiftScaleParameters &measurement) { - return ExtractListSamples( "io.vd", "layer", measurement); + return ExtractSamplesWithLabel( "io.vd", "layer", measurement); } -TrainVectorBase::ListSamples -TrainVectorBase::ExtractClassificationListSamples(const StatisticsMeasurement &measurement) +TrainVectorBase::SamplesWithLabel +TrainVectorBase::ExtractClassificationSamplesWithLabel(const ShiftScaleParameters &measurement) { if(GetClassifierCategory() == Supervised) { - ListSamples tmpListSamples; - ListSamples validationListSamples = ExtractListSamples( "valid.vd", "valid.layer", measurement ); + SamplesWithLabel tmpSamplesWithLabel; + SamplesWithLabel validationSamplesWithLabel = ExtractSamplesWithLabel( "valid.vd", "valid.layer", measurement ); //Test the input validation set size - if( validationListSamples.labeledListSample->Size() != 0 ) + if( validationSamplesWithLabel.labeledListSample->Size() != 0 ) { - tmpListSamples.listSample = validationListSamples.listSample; - tmpListSamples.labeledListSample = validationListSamples.labeledListSample; + tmpSamplesWithLabel.listSample = validationSamplesWithLabel.listSample; + tmpSamplesWithLabel.labeledListSample = validationSamplesWithLabel.labeledListSample; } else { otbAppLogWARNING( "The validation set is empty. The performance estimation is done using the input training set in this case." ); - tmpListSamples.listSample = trainingListSamples.listSample; - tmpListSamples.labeledListSample = trainingListSamples.labeledListSample; + tmpSamplesWithLabel.listSample = m_trainingSamplesWithLabel.listSample; + tmpSamplesWithLabel.labeledListSample = m_trainingSamplesWithLabel.labeledListSample; } - return tmpListSamples; + return tmpSamplesWithLabel; } else { - return trainingListSamples; + return m_trainingSamplesWithLabel; } } -TrainVectorBase::StatisticsMeasurement +TrainVectorBase::ShiftScaleParameters TrainVectorBase::ComputeStatistics(unsigned int nbFeatures) { - StatisticsMeasurement measurement = StatisticsMeasurement(); + ShiftScaleParameters measurement = ShiftScaleParameters(); if( HasValue( "io.stats" ) && IsParameterEnabled( "io.stats" ) ) { StatisticsReader::Pointer statisticsReader = StatisticsReader::New(); @@ -234,51 +230,51 @@ TrainVectorBase::ComputeStatistics(unsigned int nbFeatures) } -TrainVectorBase::ListSamples -TrainVectorBase::ExtractListSamples(std::string parameterName, std::string parameterLayer, - const StatisticsMeasurement &measurement) +TrainVectorBase::SamplesWithLabel +TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string parameterLayer, + const ShiftScaleParameters &measurement) { - ListSamples listSamples; + SamplesWithLabel samplesWithLabel; if( HasValue( parameterName ) && IsParameterEnabled( parameterName ) ) { ListSampleType::Pointer input = ListSampleType::New(); TargetListSampleType::Pointer target = TargetListSampleType::New(); - input->SetMeasurementVectorSize( featuresInfo.m_NbFeatures ); + input->SetMeasurementVectorSize( m_featuresInfo.m_NbFeatures ); - std::vector<std::string> validFileList = this->GetParameterStringList( parameterName ); - for( unsigned int k = 0; k < validFileList.size(); k++ ) + std::vector<std::string> fileList = this->GetParameterStringList( parameterName ); + for( unsigned int k = 0; k < fileList.size(); k++ ) { - otbAppLogINFO( "Reading validation vector file " << k + 1 << "/" << validFileList.size() ); - ogr::DataSource::Pointer source = ogr::DataSource::New( validFileList[k], ogr::DataSource::Modes::Read ); + otbAppLogINFO( "Reading vector file " << k + 1 << "/" << fileList.size() ); + ogr::DataSource::Pointer source = ogr::DataSource::New( fileList[k], ogr::DataSource::Modes::Read ); ogr::Layer layer = source->GetLayer( static_cast<size_t>(this->GetParameterInt( parameterLayer )) ); ogr::Feature feature = layer.ogr().GetNextFeature(); bool goesOn = feature.addr() != 0; if( !goesOn ) { - otbAppLogWARNING( "The layer " << GetParameterInt( parameterLayer ) << " of " << validFileList[k] + otbAppLogWARNING( "The layer " << GetParameterInt( parameterLayer ) << " of " << fileList[k] << " is empty, input is skipped." ); continue; } // Check all needed fields are present : // - check class field if we use supervised classification or if class field name is not empty - int cFieldIndex = feature.ogr().GetFieldIndex( featuresInfo.m_SelectedCFieldName.c_str() ); - if( cFieldIndex < 0 && !featuresInfo.m_SelectedCFieldName.empty()) + int cFieldIndex = feature.ogr().GetFieldIndex( m_featuresInfo.m_SelectedCFieldName.c_str() ); + if( cFieldIndex < 0 && !m_featuresInfo.m_SelectedCFieldName.empty()) { - otbAppLogFATAL( "The field name for class label (" << featuresInfo.m_SelectedCFieldName + otbAppLogFATAL( "The field name for class label (" << m_featuresInfo.m_SelectedCFieldName << ") has not been found in the vector file " - << validFileList[k] ); + << fileList[k] ); } // - check feature fields - std::vector<int> featureFieldIndex( featuresInfo.m_NbFeatures, -1 ); - for( unsigned int i = 0; i < featuresInfo.m_NbFeatures; i++ ) + std::vector<int> featureFieldIndex( m_featuresInfo.m_NbFeatures, -1 ); + for( unsigned int i = 0; i < m_featuresInfo.m_NbFeatures; i++ ) { - featureFieldIndex[i] = feature.ogr().GetFieldIndex( featuresInfo.m_SelectedNames[i].c_str() ); + featureFieldIndex[i] = feature.ogr().GetFieldIndex( m_featuresInfo.m_SelectedNames[i].c_str() ); if( featureFieldIndex[i] < 0 ) - otbAppLogFATAL( "The field name for feature " << featuresInfo.m_SelectedNames[i] + otbAppLogFATAL( "The field name for feature " << m_featuresInfo.m_SelectedNames[i] << " has not been found in the vector file " - << validFileList[k] ); + << fileList[k] ); } @@ -286,8 +282,8 @@ TrainVectorBase::ExtractListSamples(std::string parameterName, std::string param { // Retrieve all the features for each field in the ogr layer. MeasurementType mv; - mv.SetSize( featuresInfo.m_NbFeatures ); - for( unsigned int idx = 0; idx < featuresInfo.m_NbFeatures; ++idx ) + mv.SetSize( m_featuresInfo.m_NbFeatures ); + for( unsigned int idx = 0; idx < m_featuresInfo.m_NbFeatures; ++idx ) mv[idx] = feature.ogr().GetFieldAsDouble( featureFieldIndex[idx] ); input->PushBack( mv ); @@ -310,11 +306,11 @@ TrainVectorBase::ExtractListSamples(std::string parameterName, std::string param shiftScaleFilter->SetScales( measurement.stddevMeasurementVector ); shiftScaleFilter->Update(); - listSamples.listSample = shiftScaleFilter->GetOutput(); - listSamples.labeledListSample = target; + samplesWithLabel.listSample = shiftScaleFilter->GetOutput(); + samplesWithLabel.labeledListSample = target; } - return listSamples; + return samplesWithLabel; } @@ -322,5 +318,3 @@ TrainVectorBase::ExtractListSamples(std::string parameterName, std::string param } #endif - - -- GitLab