Skip to content
Snippets Groups Projects
Commit 9b617cfa authored by Ludovic Hussonnois's avatar Ludovic Hussonnois
Browse files

REFAC: Apply RFC 85 review

parent c5efcd60
No related branches found
No related tags found
No related merge requests found
......@@ -49,44 +49,42 @@ public:
typedef ConfusionMatrixCalculatorType::MapOfIndicesType MapOfIndicesType;
typedef ConfusionMatrixCalculatorType::ClassLabelType ClassLabelType;
private:
void DoTrainInit()
protected:
void DoInit()
{
// Nothing to do here
TrainVectorBase::DoInit();
}
void DoTrainUpdateParameters()
void DoUpdateParameters()
{
// Nothing to do here
TrainVectorBase::DoUpdateParameters();
}
void DoBeforeTrainExecute()
void DoExecute()
{
// Enforce the need of class field name in supervised mode
if (GetClassifierCategory() == Supervised)
{
featuresInfo.SetClassFieldNames( GetChoiceNames( "cfield" ), GetSelectedItems( "cfield" ) );
m_featuresInfo.SetClassFieldNames( GetChoiceNames( "cfield" ), GetSelectedItems( "cfield" ) );
if( featuresInfo.m_SelectedCFieldIdx.empty() )
if( m_featuresInfo.m_SelectedCFieldIdx.empty() )
{
otbAppLogFATAL( << "No field has been selected for data labelling!" );
}
}
}
void DoAfterTrainExecute()
{
TrainVectorBase::DoExecute();
if (GetClassifierCategory() == Supervised)
{
ConfusionMatrixCalculatorType::Pointer confMatCalc = ComputeConfusionMatrix( predictedList,
classificationListSamples.labeledListSample );
WriteConfusionMatrix( confMatCalc );
}
else
{
// TODO Compute Contingency Table
}
if (GetClassifierCategory() == Supervised)
{
ConfusionMatrixCalculatorType::Pointer confMatCalc = ComputeConfusionMatrix( m_predictedList,
m_classificationSamplesWithLabel.labeledListSample );
WriteConfusionMatrix( confMatCalc );
}
else
{
// TODO Compute Contingency Table
}
}
......
......@@ -35,7 +35,7 @@ namespace Wrapper
{
/** \class TrainImagesBase
* \brief Base class for the TrainImagesBaseClassifier and Clustering
* \brief Base class for the TrainImagesClassifier
*
* This class intends to hold common input/output parameters and
* composite application connection for both supervised and unsupervised
......
......@@ -35,11 +35,11 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams()
"See complete documentation here "
"\\url{http://image.diku.dk/shark/sphinx_pages/build/html/rest_sources/tutorials/algorithms/kmeans.html}.\n " );
//MaxNumberOfIterations
AddParameter( ParameterType_Int, "classifier.sharkkm.nbmaxiter",
AddParameter( ParameterType_Int, "classifier.sharkkm.maxiter",
"Maximum number of iteration for the kmeans algorithm." );
SetParameterInt( "classifier.sharkkm.nbmaxiter", 10 );
SetMinimumParameterIntValue( "classifier.sharkkm.nbmaxiter", 0 );
SetParameterDescription( "classifier.sharkkm.nbmaxiter",
SetParameterInt( "classifier.sharkkm.maxiter", 10 );
SetMinimumParameterIntValue( "classifier.sharkkm.maxiter", 0 );
SetParameterDescription( "classifier.sharkkm.maxiter",
"The maximum number of iteration for the kmeans algorithm. 0=unlimited" );
//MaxNumberOfIterations
......@@ -55,7 +55,7 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans(
typename ListSampleType::Pointer trainingListSample,
typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath)
{
unsigned int nbMaxIter = static_cast<unsigned int>(abs( GetParameterInt( "classifier.sharkkm.nbmaxiter" ) ));
unsigned int nbMaxIter = static_cast<unsigned int>(abs( GetParameterInt( "classifier.sharkkm.maxiter" ) ));
unsigned int k = static_cast<unsigned int>(abs( GetParameterInt( "classifier.sharkkm.k" ) ));
typename SharkKMeansType::Pointer classifier = SharkKMeansType::New();
......
......@@ -74,7 +74,7 @@ public:
protected:
/** Class used to store statistics Measurment (mean/stddev) */
class StatisticsMeasurement
class ShiftScaleParameters
{
public:
MeasurementType meanMeasurementVector;
......@@ -82,12 +82,12 @@ protected:
};
/** Class used to store a list of sample and the corresponding label */
class ListSamples
class SamplesWithLabel
{
public:
ListSampleType::Pointer listSample;
TargetListSampleType::Pointer labeledListSample;
ListSamples()
SamplesWithLabel()
{
listSample = ListSampleType::New();
labeledListSample = TargetListSampleType::New();
......@@ -137,7 +137,7 @@ protected:
* \param measurement statics measurement (mean/stddev)
* \param featuresInfo information about the features
*/
virtual void ExtractAllSamples(const StatisticsMeasurement &measurement);
virtual void ExtractAllSamples(const ShiftScaleParameters &measurement);
/**
* Extract the training sample list
......@@ -145,7 +145,7 @@ protected:
* \param featuresInfo information about the features
* \return sample list used for training
*/
virtual ListSamples ExtractTrainingListSamples(const StatisticsMeasurement &measurement);
virtual SamplesWithLabel ExtractTrainingSamplesWithLabel(const ShiftScaleParameters &measurement);
/**
* Extract classification the sample list
......@@ -153,7 +153,7 @@ protected:
* \param featuresInfo information about the features
* \return sample list used for classification
*/
virtual ListSamples ExtractClassificationListSamples(const StatisticsMeasurement &measurement);
virtual SamplesWithLabel ExtractClassificationSamplesWithLabel(const ShiftScaleParameters &measurement);
/** Extract samples from input file for corresponding field name
......@@ -164,8 +164,8 @@ protected:
* \param nbFeatures the number of features.
* \return the list of samples and their corresponding labels.
*/
ListSamples
ExtractListSamples(std::string parameterName, std::string parameterLayer, const StatisticsMeasurement &measurement);
SamplesWithLabel
ExtractSamplesWithLabel(std::string parameterName, std::string parameterLayer, const ShiftScaleParameters &measurement);
/**
......@@ -173,18 +173,12 @@ protected:
* Otherwise mean is set to 0 and standard deviation to 1 for each Features.
* \param nbFeatures
*/
StatisticsMeasurement ComputeStatistics(unsigned int nbFeatures);
ShiftScaleParameters ComputeStatistics(unsigned int nbFeatures);
ListSamples trainingListSamples;
ListSamples classificationListSamples;
TargetListSampleType::Pointer predictedList;
FeaturesInfo featuresInfo;
private:
virtual void DoTrainInit() = 0;
virtual void DoBeforeTrainExecute() = 0;
virtual void DoAfterTrainExecute() = 0;
virtual void DoTrainUpdateParameters() = 0;
SamplesWithLabel m_trainingSamplesWithLabel;
SamplesWithLabel m_classificationSamplesWithLabel;
TargetListSampleType::Pointer m_predictedList;
FeaturesInfo m_featuresInfo;
void DoInit() ITK_OVERRIDE;
void DoUpdateParameters() ITK_OVERRIDE;
......@@ -200,4 +194,3 @@ private:
#endif
#endif
......@@ -102,7 +102,7 @@ void TrainVectorBase::DoInit()
AddRANDParameter();
DoTrainInit();
DoInit();
}
void TrainVectorBase::DoUpdateParameters()
......@@ -142,79 +142,75 @@ void TrainVectorBase::DoUpdateParameters()
}
}
DoTrainUpdateParameters();
DoUpdateParameters();
}
void TrainVectorBase::DoExecute()
{
DoBeforeTrainExecute();
featuresInfo.SetFieldNames( GetChoiceNames( "feat" ), GetSelectedItems( "feat" ));
m_featuresInfo.SetFieldNames( GetChoiceNames( "feat" ), GetSelectedItems( "feat" ));
// Check input parameters
if( featuresInfo.m_SelectedIdx.empty() )
if( m_featuresInfo.m_SelectedIdx.empty() )
{
otbAppLogFATAL( << "No features have been selected to train the classifier on!" );
}
StatisticsMeasurement measurement = ComputeStatistics( featuresInfo.m_NbFeatures );
ShiftScaleParameters measurement = ComputeStatistics( m_featuresInfo.m_NbFeatures );
ExtractAllSamples( measurement );
this->Train( trainingListSamples.listSample, trainingListSamples.labeledListSample, GetParameterString( "io.out" ) );
predictedList = TargetListSampleType::New();
this->Classify( classificationListSamples.listSample, predictedList, GetParameterString( "io.out" ) );
this->Train( m_trainingSamplesWithLabel.listSample, m_trainingSamplesWithLabel.labeledListSample, GetParameterString( "io.out" ) );
DoAfterTrainExecute();
m_predictedList = TargetListSampleType::New();
this->Classify( m_classificationSamplesWithLabel.listSample, m_predictedList, GetParameterString( "io.out" ) );
}
void TrainVectorBase::ExtractAllSamples(const StatisticsMeasurement &measurement)
void TrainVectorBase::ExtractAllSamples(const ShiftScaleParameters &measurement)
{
trainingListSamples = ExtractTrainingListSamples(measurement);
classificationListSamples = ExtractClassificationListSamples(measurement);
m_trainingSamplesWithLabel = ExtractTrainingSamplesWithLabel(measurement);
m_classificationSamplesWithLabel = ExtractClassificationSamplesWithLabel(measurement);
}
TrainVectorBase::ListSamples
TrainVectorBase::ExtractTrainingListSamples(const StatisticsMeasurement &measurement)
TrainVectorBase::SamplesWithLabel
TrainVectorBase::ExtractTrainingSamplesWithLabel(const ShiftScaleParameters &measurement)
{
return ExtractListSamples( "io.vd", "layer", measurement);
return ExtractSamplesWithLabel( "io.vd", "layer", measurement);
}
TrainVectorBase::ListSamples
TrainVectorBase::ExtractClassificationListSamples(const StatisticsMeasurement &measurement)
TrainVectorBase::SamplesWithLabel
TrainVectorBase::ExtractClassificationSamplesWithLabel(const ShiftScaleParameters &measurement)
{
if(GetClassifierCategory() == Supervised)
{
ListSamples tmpListSamples;
ListSamples validationListSamples = ExtractListSamples( "valid.vd", "valid.layer", measurement );
SamplesWithLabel tmpSamplesWithLabel;
SamplesWithLabel validationSamplesWithLabel = ExtractSamplesWithLabel( "valid.vd", "valid.layer", measurement );
//Test the input validation set size
if( validationListSamples.labeledListSample->Size() != 0 )
if( validationSamplesWithLabel.labeledListSample->Size() != 0 )
{
tmpListSamples.listSample = validationListSamples.listSample;
tmpListSamples.labeledListSample = validationListSamples.labeledListSample;
tmpSamplesWithLabel.listSample = validationSamplesWithLabel.listSample;
tmpSamplesWithLabel.labeledListSample = validationSamplesWithLabel.labeledListSample;
}
else
{
otbAppLogWARNING(
"The validation set is empty. The performance estimation is done using the input training set in this case." );
tmpListSamples.listSample = trainingListSamples.listSample;
tmpListSamples.labeledListSample = trainingListSamples.labeledListSample;
tmpSamplesWithLabel.listSample = m_trainingSamplesWithLabel.listSample;
tmpSamplesWithLabel.labeledListSample = m_trainingSamplesWithLabel.labeledListSample;
}
return tmpListSamples;
return tmpSamplesWithLabel;
}
else
{
return trainingListSamples;
return m_trainingSamplesWithLabel;
}
}
TrainVectorBase::StatisticsMeasurement
TrainVectorBase::ShiftScaleParameters
TrainVectorBase::ComputeStatistics(unsigned int nbFeatures)
{
StatisticsMeasurement measurement = StatisticsMeasurement();
ShiftScaleParameters measurement = ShiftScaleParameters();
if( HasValue( "io.stats" ) && IsParameterEnabled( "io.stats" ) )
{
StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
......@@ -234,51 +230,51 @@ TrainVectorBase::ComputeStatistics(unsigned int nbFeatures)
}
TrainVectorBase::ListSamples
TrainVectorBase::ExtractListSamples(std::string parameterName, std::string parameterLayer,
const StatisticsMeasurement &measurement)
TrainVectorBase::SamplesWithLabel
TrainVectorBase::ExtractSamplesWithLabel(std::string parameterName, std::string parameterLayer,
const ShiftScaleParameters &measurement)
{
ListSamples listSamples;
SamplesWithLabel samplesWithLabel;
if( HasValue( parameterName ) && IsParameterEnabled( parameterName ) )
{
ListSampleType::Pointer input = ListSampleType::New();
TargetListSampleType::Pointer target = TargetListSampleType::New();
input->SetMeasurementVectorSize( featuresInfo.m_NbFeatures );
input->SetMeasurementVectorSize( m_featuresInfo.m_NbFeatures );
std::vector<std::string> validFileList = this->GetParameterStringList( parameterName );
for( unsigned int k = 0; k < validFileList.size(); k++ )
std::vector<std::string> fileList = this->GetParameterStringList( parameterName );
for( unsigned int k = 0; k < fileList.size(); k++ )
{
otbAppLogINFO( "Reading validation vector file " << k + 1 << "/" << validFileList.size() );
ogr::DataSource::Pointer source = ogr::DataSource::New( validFileList[k], ogr::DataSource::Modes::Read );
otbAppLogINFO( "Reading vector file " << k + 1 << "/" << fileList.size() );
ogr::DataSource::Pointer source = ogr::DataSource::New( fileList[k], ogr::DataSource::Modes::Read );
ogr::Layer layer = source->GetLayer( static_cast<size_t>(this->GetParameterInt( parameterLayer )) );
ogr::Feature feature = layer.ogr().GetNextFeature();
bool goesOn = feature.addr() != 0;
if( !goesOn )
{
otbAppLogWARNING( "The layer " << GetParameterInt( parameterLayer ) << " of " << validFileList[k]
otbAppLogWARNING( "The layer " << GetParameterInt( parameterLayer ) << " of " << fileList[k]
<< " is empty, input is skipped." );
continue;
}
// Check all needed fields are present :
// - check class field if we use supervised classification or if class field name is not empty
int cFieldIndex = feature.ogr().GetFieldIndex( featuresInfo.m_SelectedCFieldName.c_str() );
if( cFieldIndex < 0 && !featuresInfo.m_SelectedCFieldName.empty())
int cFieldIndex = feature.ogr().GetFieldIndex( m_featuresInfo.m_SelectedCFieldName.c_str() );
if( cFieldIndex < 0 && !m_featuresInfo.m_SelectedCFieldName.empty())
{
otbAppLogFATAL( "The field name for class label (" << featuresInfo.m_SelectedCFieldName
otbAppLogFATAL( "The field name for class label (" << m_featuresInfo.m_SelectedCFieldName
<< ") has not been found in the vector file "
<< validFileList[k] );
<< fileList[k] );
}
// - check feature fields
std::vector<int> featureFieldIndex( featuresInfo.m_NbFeatures, -1 );
for( unsigned int i = 0; i < featuresInfo.m_NbFeatures; i++ )
std::vector<int> featureFieldIndex( m_featuresInfo.m_NbFeatures, -1 );
for( unsigned int i = 0; i < m_featuresInfo.m_NbFeatures; i++ )
{
featureFieldIndex[i] = feature.ogr().GetFieldIndex( featuresInfo.m_SelectedNames[i].c_str() );
featureFieldIndex[i] = feature.ogr().GetFieldIndex( m_featuresInfo.m_SelectedNames[i].c_str() );
if( featureFieldIndex[i] < 0 )
otbAppLogFATAL( "The field name for feature " << featuresInfo.m_SelectedNames[i]
otbAppLogFATAL( "The field name for feature " << m_featuresInfo.m_SelectedNames[i]
<< " has not been found in the vector file "
<< validFileList[k] );
<< fileList[k] );
}
......@@ -286,8 +282,8 @@ TrainVectorBase::ExtractListSamples(std::string parameterName, std::string param
{
// Retrieve all the features for each field in the ogr layer.
MeasurementType mv;
mv.SetSize( featuresInfo.m_NbFeatures );
for( unsigned int idx = 0; idx < featuresInfo.m_NbFeatures; ++idx )
mv.SetSize( m_featuresInfo.m_NbFeatures );
for( unsigned int idx = 0; idx < m_featuresInfo.m_NbFeatures; ++idx )
mv[idx] = feature.ogr().GetFieldAsDouble( featureFieldIndex[idx] );
input->PushBack( mv );
......@@ -310,11 +306,11 @@ TrainVectorBase::ExtractListSamples(std::string parameterName, std::string param
shiftScaleFilter->SetScales( measurement.stddevMeasurementVector );
shiftScaleFilter->Update();
listSamples.listSample = shiftScaleFilter->GetOutput();
listSamples.labeledListSample = target;
samplesWithLabel.listSample = shiftScaleFilter->GetOutput();
samplesWithLabel.labeledListSample = target;
}
return listSamples;
return samplesWithLabel;
}
......@@ -322,5 +318,3 @@ TrainVectorBase::ExtractListSamples(std::string parameterName, std::string param
}
#endif
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment