Commit 21587e94 authored by Ludovic Hussonnois's avatar Ludovic Hussonnois

REFAC: Correctly rename TrainVectorBase class members.

parent 1b92ee2f
......@@ -65,9 +65,9 @@ protected:
// Enforce the need of class field name in supervised mode
if (GetClassifierCategory() == Supervised)
{
m_featuresInfo.SetClassFieldNames( GetChoiceNames( "cfield" ), GetSelectedItems( "cfield" ) );
m_FeaturesInfo.SetClassFieldNames( GetChoiceNames( "cfield" ), GetSelectedItems( "cfield" ) );
if( m_featuresInfo.m_SelectedCFieldIdx.empty() )
if( m_FeaturesInfo.m_SelectedCFieldIdx.empty() )
{
otbAppLogFATAL( << "No field has been selected for data labelling!" );
}
......@@ -77,8 +77,8 @@ protected:
if (GetClassifierCategory() == Supervised)
{
ConfusionMatrixCalculatorType::Pointer confMatCalc = ComputeConfusionMatrix( m_predictedList,
m_classificationSamplesWithLabel.labeledListSample );
ConfusionMatrixCalculatorType::Pointer confMatCalc = ComputeConfusionMatrix( m_PredictedList,
m_ClassificationSamplesWithLabel.labeledListSample );
WriteConfusionMatrix( confMatCalc );
}
else
......
......@@ -175,10 +175,10 @@ protected:
*/
ShiftScaleParameters ComputeStatistics(unsigned int nbFeatures);
SamplesWithLabel m_trainingSamplesWithLabel;
SamplesWithLabel m_classificationSamplesWithLabel;
TargetListSampleType::Pointer m_predictedList;
FeaturesInfo m_featuresInfo;
SamplesWithLabel m_TrainingSamplesWithLabel;
SamplesWithLabel m_ClassificationSamplesWithLabel;
TargetListSampleType::Pointer m_PredictedList;
FeaturesInfo m_FeaturesInfo;
void DoInit() ITK_OVERRIDE;
void DoUpdateParameters() ITK_OVERRIDE;
......
......@@ -143,28 +143,28 @@ void TrainVectorBase::DoUpdateParameters()
void TrainVectorBase::DoExecute()
{
m_featuresInfo.SetFieldNames( GetChoiceNames( "feat" ), GetSelectedItems( "feat" ));
m_FeaturesInfo.SetFieldNames( GetChoiceNames( "feat" ), GetSelectedItems( "feat" ));
// Check input parameters
if( m_featuresInfo.m_SelectedIdx.empty() )
if( m_FeaturesInfo.m_SelectedIdx.empty() )
{
otbAppLogFATAL( << "No features have been selected to train the classifier on!" );
}
ShiftScaleParameters measurement = ComputeStatistics( m_featuresInfo.m_NbFeatures );
ShiftScaleParameters measurement = ComputeStatistics( m_FeaturesInfo.m_NbFeatures );
ExtractAllSamples( measurement );
this->Train( m_trainingSamplesWithLabel.listSample, m_trainingSamplesWithLabel.labeledListSample, GetParameterString( "io.out" ) );
this->Train( m_TrainingSamplesWithLabel.listSample, m_TrainingSamplesWithLabel.labeledListSample, GetParameterString( "io.out" ) );
m_predictedList = TargetListSampleType::New();
this->Classify( m_classificationSamplesWithLabel.listSample, m_predictedList, GetParameterString( "io.out" ) );
m_PredictedList = TargetListSampleType::New();
this->Classify( m_ClassificationSamplesWithLabel.listSample, m_PredictedList, GetParameterString( "io.out" ) );
}
void TrainVectorBase::ExtractAllSamples(const ShiftScaleParameters &measurement)
{
m_trainingSamplesWithLabel = ExtractTrainingSamplesWithLabel(measurement);
m_classificationSamplesWithLabel = ExtractClassificationSamplesWithLabel(measurement);
m_TrainingSamplesWithLabel = ExtractTrainingSamplesWithLabel(measurement);
m_ClassificationSamplesWithLabel = ExtractClassificationSamplesWithLabel(measurement);
}
TrainVectorBase::SamplesWithLabel
......@@ -190,15 +190,15 @@ TrainVectorBase::ExtractClassificationSamplesWithLabel(const ShiftScaleParameter
{
otbAppLogWARNING(
"The validation set is empty. The performance estimation is done using the input training set in this case." );
tmpSamplesWithLabel.listSample = m_trainingSamplesWithLabel.listSample;
tmpSamplesWithLabel.labeledListSample = m_trainingSamplesWithLabel.labeledListSample;
tmpSamplesWithLabel.listSample = m_TrainingSamplesWithLabel.listSample;
tmpSamplesWithLabel.labeledListSample = m_TrainingSamplesWithLabel.labeledListSample;
}
return tmpSamplesWithLabel;
}
else
{
return m_trainingSamplesWithLabel;
return m_TrainingSamplesWithLabel;
}
}
......@@ -235,7 +235,7 @@ TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string
{
ListSampleType::Pointer input = ListSampleType::New();
TargetListSampleType::Pointer target = TargetListSampleType::New();
input->SetMeasurementVectorSize( m_featuresInfo.m_NbFeatures );
input->SetMeasurementVectorSize( m_FeaturesInfo.m_NbFeatures );
std::vector<std::string> fileList = this->GetParameterStringList( parameterName );
for( unsigned int k = 0; k < fileList.size(); k++ )
......@@ -254,21 +254,21 @@ TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string
// Check all needed fields are present :
// - check class field if we use supervised classification or if class field name is not empty
int cFieldIndex = feature.ogr().GetFieldIndex( m_featuresInfo.m_SelectedCFieldName.c_str() );
if( cFieldIndex < 0 && !m_featuresInfo.m_SelectedCFieldName.empty())
int cFieldIndex = feature.ogr().GetFieldIndex( m_FeaturesInfo.m_SelectedCFieldName.c_str() );
if( cFieldIndex < 0 && !m_FeaturesInfo.m_SelectedCFieldName.empty())
{
otbAppLogFATAL( "The field name for class label (" << m_featuresInfo.m_SelectedCFieldName
otbAppLogFATAL( "The field name for class label (" << m_FeaturesInfo.m_SelectedCFieldName
<< ") has not been found in the vector file "
<< fileList[k] );
}
// - check feature fields
std::vector<int> featureFieldIndex( m_featuresInfo.m_NbFeatures, -1 );
for( unsigned int i = 0; i < m_featuresInfo.m_NbFeatures; i++ )
std::vector<int> featureFieldIndex( m_FeaturesInfo.m_NbFeatures, -1 );
for( unsigned int i = 0; i < m_FeaturesInfo.m_NbFeatures; i++ )
{
featureFieldIndex[i] = feature.ogr().GetFieldIndex( m_featuresInfo.m_SelectedNames[i].c_str() );
featureFieldIndex[i] = feature.ogr().GetFieldIndex( m_FeaturesInfo.m_SelectedNames[i].c_str() );
if( featureFieldIndex[i] < 0 )
otbAppLogFATAL( "The field name for feature " << m_featuresInfo.m_SelectedNames[i]
otbAppLogFATAL( "The field name for feature " << m_FeaturesInfo.m_SelectedNames[i]
<< " has not been found in the vector file "
<< fileList[k] );
}
......@@ -278,8 +278,8 @@ TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string
{
// Retrieve all the features for each field in the ogr layer.
MeasurementType mv;
mv.SetSize( m_featuresInfo.m_NbFeatures );
for( unsigned int idx = 0; idx < m_featuresInfo.m_NbFeatures; ++idx )
mv.SetSize( m_FeaturesInfo.m_NbFeatures );
for( unsigned int idx = 0; idx < m_FeaturesInfo.m_NbFeatures; ++idx )
mv[idx] = feature.ogr().GetFieldAsDouble( featureFieldIndex[idx] );
input->PushBack( mv );
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment