Skip to content
Snippets Groups Projects
Commit 21587e94 authored by Ludovic Hussonnois's avatar Ludovic Hussonnois
Browse files

REFAC: Correctly rename TrainVectorBase class members.

parent 1b92ee2f
No related branches found
No related tags found
No related merge requests found
...@@ -65,9 +65,9 @@ protected: ...@@ -65,9 +65,9 @@ protected:
// Enforce the need of class field name in supervised mode // Enforce the need of class field name in supervised mode
if (GetClassifierCategory() == Supervised) if (GetClassifierCategory() == Supervised)
{ {
m_featuresInfo.SetClassFieldNames( GetChoiceNames( "cfield" ), GetSelectedItems( "cfield" ) ); m_FeaturesInfo.SetClassFieldNames( GetChoiceNames( "cfield" ), GetSelectedItems( "cfield" ) );
if( m_featuresInfo.m_SelectedCFieldIdx.empty() ) if( m_FeaturesInfo.m_SelectedCFieldIdx.empty() )
{ {
otbAppLogFATAL( << "No field has been selected for data labelling!" ); otbAppLogFATAL( << "No field has been selected for data labelling!" );
} }
...@@ -77,8 +77,8 @@ protected: ...@@ -77,8 +77,8 @@ protected:
if (GetClassifierCategory() == Supervised) if (GetClassifierCategory() == Supervised)
{ {
ConfusionMatrixCalculatorType::Pointer confMatCalc = ComputeConfusionMatrix( m_predictedList, ConfusionMatrixCalculatorType::Pointer confMatCalc = ComputeConfusionMatrix( m_PredictedList,
m_classificationSamplesWithLabel.labeledListSample ); m_ClassificationSamplesWithLabel.labeledListSample );
WriteConfusionMatrix( confMatCalc ); WriteConfusionMatrix( confMatCalc );
} }
else else
......
...@@ -175,10 +175,10 @@ protected: ...@@ -175,10 +175,10 @@ protected:
*/ */
ShiftScaleParameters ComputeStatistics(unsigned int nbFeatures); ShiftScaleParameters ComputeStatistics(unsigned int nbFeatures);
SamplesWithLabel m_trainingSamplesWithLabel; SamplesWithLabel m_TrainingSamplesWithLabel;
SamplesWithLabel m_classificationSamplesWithLabel; SamplesWithLabel m_ClassificationSamplesWithLabel;
TargetListSampleType::Pointer m_predictedList; TargetListSampleType::Pointer m_PredictedList;
FeaturesInfo m_featuresInfo; FeaturesInfo m_FeaturesInfo;
void DoInit() ITK_OVERRIDE; void DoInit() ITK_OVERRIDE;
void DoUpdateParameters() ITK_OVERRIDE; void DoUpdateParameters() ITK_OVERRIDE;
......
...@@ -143,28 +143,28 @@ void TrainVectorBase::DoUpdateParameters() ...@@ -143,28 +143,28 @@ void TrainVectorBase::DoUpdateParameters()
void TrainVectorBase::DoExecute() void TrainVectorBase::DoExecute()
{ {
m_featuresInfo.SetFieldNames( GetChoiceNames( "feat" ), GetSelectedItems( "feat" )); m_FeaturesInfo.SetFieldNames( GetChoiceNames( "feat" ), GetSelectedItems( "feat" ));
// Check input parameters // Check input parameters
if( m_featuresInfo.m_SelectedIdx.empty() ) if( m_FeaturesInfo.m_SelectedIdx.empty() )
{ {
otbAppLogFATAL( << "No features have been selected to train the classifier on!" ); otbAppLogFATAL( << "No features have been selected to train the classifier on!" );
} }
ShiftScaleParameters measurement = ComputeStatistics( m_featuresInfo.m_NbFeatures ); ShiftScaleParameters measurement = ComputeStatistics( m_FeaturesInfo.m_NbFeatures );
ExtractAllSamples( measurement ); ExtractAllSamples( measurement );
this->Train( m_trainingSamplesWithLabel.listSample, m_trainingSamplesWithLabel.labeledListSample, GetParameterString( "io.out" ) ); this->Train( m_TrainingSamplesWithLabel.listSample, m_TrainingSamplesWithLabel.labeledListSample, GetParameterString( "io.out" ) );
m_predictedList = TargetListSampleType::New(); m_PredictedList = TargetListSampleType::New();
this->Classify( m_classificationSamplesWithLabel.listSample, m_predictedList, GetParameterString( "io.out" ) ); this->Classify( m_ClassificationSamplesWithLabel.listSample, m_PredictedList, GetParameterString( "io.out" ) );
} }
void TrainVectorBase::ExtractAllSamples(const ShiftScaleParameters &measurement) void TrainVectorBase::ExtractAllSamples(const ShiftScaleParameters &measurement)
{ {
m_trainingSamplesWithLabel = ExtractTrainingSamplesWithLabel(measurement); m_TrainingSamplesWithLabel = ExtractTrainingSamplesWithLabel(measurement);
m_classificationSamplesWithLabel = ExtractClassificationSamplesWithLabel(measurement); m_ClassificationSamplesWithLabel = ExtractClassificationSamplesWithLabel(measurement);
} }
TrainVectorBase::SamplesWithLabel TrainVectorBase::SamplesWithLabel
...@@ -190,15 +190,15 @@ TrainVectorBase::ExtractClassificationSamplesWithLabel(const ShiftScaleParameter ...@@ -190,15 +190,15 @@ TrainVectorBase::ExtractClassificationSamplesWithLabel(const ShiftScaleParameter
{ {
otbAppLogWARNING( otbAppLogWARNING(
"The validation set is empty. The performance estimation is done using the input training set in this case." ); "The validation set is empty. The performance estimation is done using the input training set in this case." );
tmpSamplesWithLabel.listSample = m_trainingSamplesWithLabel.listSample; tmpSamplesWithLabel.listSample = m_TrainingSamplesWithLabel.listSample;
tmpSamplesWithLabel.labeledListSample = m_trainingSamplesWithLabel.labeledListSample; tmpSamplesWithLabel.labeledListSample = m_TrainingSamplesWithLabel.labeledListSample;
} }
return tmpSamplesWithLabel; return tmpSamplesWithLabel;
} }
else else
{ {
return m_trainingSamplesWithLabel; return m_TrainingSamplesWithLabel;
} }
} }
...@@ -235,7 +235,7 @@ TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string ...@@ -235,7 +235,7 @@ TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string
{ {
ListSampleType::Pointer input = ListSampleType::New(); ListSampleType::Pointer input = ListSampleType::New();
TargetListSampleType::Pointer target = TargetListSampleType::New(); TargetListSampleType::Pointer target = TargetListSampleType::New();
input->SetMeasurementVectorSize( m_featuresInfo.m_NbFeatures ); input->SetMeasurementVectorSize( m_FeaturesInfo.m_NbFeatures );
std::vector<std::string> fileList = this->GetParameterStringList( parameterName ); std::vector<std::string> fileList = this->GetParameterStringList( parameterName );
for( unsigned int k = 0; k < fileList.size(); k++ ) for( unsigned int k = 0; k < fileList.size(); k++ )
...@@ -254,21 +254,21 @@ TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string ...@@ -254,21 +254,21 @@ TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string
// Check all needed fields are present : // Check all needed fields are present :
// - check class field if we use supervised classification or if class field name is not empty // - check class field if we use supervised classification or if class field name is not empty
int cFieldIndex = feature.ogr().GetFieldIndex( m_featuresInfo.m_SelectedCFieldName.c_str() ); int cFieldIndex = feature.ogr().GetFieldIndex( m_FeaturesInfo.m_SelectedCFieldName.c_str() );
if( cFieldIndex < 0 && !m_featuresInfo.m_SelectedCFieldName.empty()) if( cFieldIndex < 0 && !m_FeaturesInfo.m_SelectedCFieldName.empty())
{ {
otbAppLogFATAL( "The field name for class label (" << m_featuresInfo.m_SelectedCFieldName otbAppLogFATAL( "The field name for class label (" << m_FeaturesInfo.m_SelectedCFieldName
<< ") has not been found in the vector file " << ") has not been found in the vector file "
<< fileList[k] ); << fileList[k] );
} }
// - check feature fields // - check feature fields
std::vector<int> featureFieldIndex( m_featuresInfo.m_NbFeatures, -1 ); std::vector<int> featureFieldIndex( m_FeaturesInfo.m_NbFeatures, -1 );
for( unsigned int i = 0; i < m_featuresInfo.m_NbFeatures; i++ ) for( unsigned int i = 0; i < m_FeaturesInfo.m_NbFeatures; i++ )
{ {
featureFieldIndex[i] = feature.ogr().GetFieldIndex( m_featuresInfo.m_SelectedNames[i].c_str() ); featureFieldIndex[i] = feature.ogr().GetFieldIndex( m_FeaturesInfo.m_SelectedNames[i].c_str() );
if( featureFieldIndex[i] < 0 ) if( featureFieldIndex[i] < 0 )
otbAppLogFATAL( "The field name for feature " << m_featuresInfo.m_SelectedNames[i] otbAppLogFATAL( "The field name for feature " << m_FeaturesInfo.m_SelectedNames[i]
<< " has not been found in the vector file " << " has not been found in the vector file "
<< fileList[k] ); << fileList[k] );
} }
...@@ -278,8 +278,8 @@ TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string ...@@ -278,8 +278,8 @@ TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string
{ {
// Retrieve all the features for each field in the ogr layer. // Retrieve all the features for each field in the ogr layer.
MeasurementType mv; MeasurementType mv;
mv.SetSize( m_featuresInfo.m_NbFeatures ); mv.SetSize( m_FeaturesInfo.m_NbFeatures );
for( unsigned int idx = 0; idx < m_featuresInfo.m_NbFeatures; ++idx ) for( unsigned int idx = 0; idx < m_FeaturesInfo.m_NbFeatures; ++idx )
mv[idx] = feature.ogr().GetFieldAsDouble( featureFieldIndex[idx] ); mv[idx] = feature.ogr().GetFieldAsDouble( featureFieldIndex[idx] );
input->PushBack( mv ); input->PushBack( mv );
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment