Skip to content
Snippets Groups Projects
Commit 6977cc45 authored by Ludovic Hussonnois's avatar Ludovic Hussonnois
Browse files

ENH: Update TrainVectorBase and TrainVector Classifier/Clustering.

Add the ability to specify Supervised or Unsupervised classification
parent 63d69d6c
No related branches found
No related tags found
No related merge requests found
......@@ -45,6 +45,11 @@ public:
typedef ConfusionMatrixCalculatorType::MapOfIndicesType MapOfIndicesType;
typedef ConfusionMatrixCalculatorType::ClassLabelType ClassLabelType;
protected :
TrainVectorClassifier() : TrainVectorBase()
{
m_ClassifierCategory = Supervised;
}
private:
void DoTrainInit()
......@@ -78,7 +83,18 @@ private:
// Nothing to do here
}
void DoTrainExecute()
/**
 * Hook executed before training: records the selected class field in
 * featuresInfo and enforces that one has been chosen in supervised mode.
 */
void DoBeforeTrainExecute()
{
  // Validate the selection *before* handing it to SetClassFieldNames(): that
  // method reads the first selected index, so it must never receive an empty
  // selection (calling front() on an empty vector is undefined behaviour).
  if( GetSelectedItems( "cfield" ).empty() )
    {
    // Make sure no stale class field survives from a previous execution.
    featuresInfo.m_SelectedCFieldIdx.clear();
    featuresInfo.m_SelectedCFieldName.clear();

    // Enforce the need of class field name in supervised mode
    if( m_ClassifierCategory == Supervised )
      {
      otbAppLogFATAL( << "No field has been selected for data labelling!" );
      }
    return;
    }

  featuresInfo.SetClassFieldNames( GetChoiceNames( "cfield" ), GetSelectedItems( "cfield" ) );
}
void DoAfterTrainExecute()
{
ConfusionMatrixCalculatorType::Pointer confMatCalc = ComputeConfusionmatrix( predictedList,
classificationListSamples.labeledListSample );
......@@ -86,6 +102,28 @@ private:
}
/**
 * Select the samples on which the classification performance is estimated.
 * The validation samples ("valid.vd" / "valid.layer") are used when there are
 * any; otherwise a warning is logged and the training samples are used.
 * \param measurement statistics measurement forwarded to ExtractListSamples
 * \return the feature samples and their labels used for performance estimation
 */
ListSamples ExtractClassificationListSamples(const StatisticsMeasurement &measurement)
{
  ListSamples validationSet = ExtractListSamples( "valid.vd", "valid.layer", measurement );

  // Decide once which set feeds the performance estimation.
  const bool hasValidationSamples = ( validationSet.labeledListSample->Size() != 0 );
  if( !hasValidationSamples )
    {
    otbAppLogWARNING(
      "The validation set is empty. The performance estimation is done using the input training set in this case." );
    }
  ListSamples &chosenSet = hasValidationSamples ? validationSet : trainingListSamples;

  // Only the sample list and its labels are carried over, as before.
  ListSamples result;
  result.listSample = chosenSet.listSample;
  result.labeledListSample = chosenSet.labeledListSample;
  return result;
}
ConfusionMatrixCalculatorType::Pointer
ComputeConfusionmatrix(const TargetListSampleType::Pointer &predictedListSample,
const TargetListSampleType::Pointer &performanceLabeledListSample)
......@@ -285,7 +323,6 @@ private:
otbAppLogINFO( "Confusion matrix (rows = reference labels, columns = produced labels):\n" << os.str() );
}
};
}
}
......
......@@ -16,18 +16,15 @@
=========================================================================*/
#include "otbTrainVectorBase.h"
// Validation
#include "otbConfusionMatrixCalculator.h"
namespace otb
{
namespace Wrapper
{
class TrainVectorClassifier : public TrainVectorBase
class TrainVectorClustering : public TrainVectorBase
{
public:
typedef TrainVectorClassifier Self;
typedef TrainVectorClustering Self;
typedef TrainVectorBase Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
......@@ -39,10 +36,31 @@ public:
typedef Superclass::ListSampleType ListSampleType;
typedef Superclass::TargetListSampleType TargetListSampleType;
protected :
TrainVectorClustering() : TrainVectorBase()
{
m_ClassifierCategory = Unsupervised;
}
private:
/**
 * Application-specific initialisation: sets the application name, its
 * description and documentation fields, and registers the example parameter
 * values shown in the generated documentation.
 */
void DoTrainInit()
{
// Identity and short description of the application.
// NOTE(review): the description texts below still speak of training "a classifier"
// and the example output is named "svmModel.svm" -- confirm this wording is intended
// for the clustering application.
SetName( "TrainVectorClustering" );
SetDescription( "Train a classifier based on labeled or unlabeled geometries and a list of features to consider." );
SetDocName( "Train Vector Clustering" );
SetDocLongDescription( "This application trains a classifier based on "
"labeled or unlabeled geometries and a list of features to consider for classification." );
// Limitations and "see also" are left as a single blank.
SetDocLimitations( " " );
SetDocAuthors( "OTB Team" );
SetDocSeeAlso( " " );
// Doc example parameter settings
SetDocExampleParameterValue( "io.vd", "vectorData.shp" );
SetDocExampleParameterValue( "io.stats", "meanVar.xml" );
SetDocExampleParameterValue( "io.out", "svmModel.svm" );
SetDocExampleParameterValue( "feat", "perimeter area width" );
}
void DoTrainUpdateParameters()
......@@ -50,7 +68,12 @@ private:
// Nothing to do here
}
void DoTrainExecute()
/**
 * Hook called at the very start of TrainVectorBase::DoExecute(), before the
 * features are read and the samples extracted.
 */
void DoBeforeTrainExecute()
{
// Nothing to do here: this application performs no extra step before training.
}
/**
 * Hook called at the end of TrainVectorBase::DoExecute(), once the model has
 * been trained and the classification samples have been predicted.
 */
void DoAfterTrainExecute()
{
// Nothing to do here: this application performs no extra step after training.
}
......
......@@ -48,11 +48,13 @@ bool IsNotAlphaNum(char c)
class TrainVectorBase : public LearningApplicationBase<float, int>
{
public:
/** Standard class typedefs. */
typedef TrainVectorBase Self;
typedef LearningApplicationBase<float, int> Superclass;
typedef itk::SmartPointer <Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Standard macro */
itkTypeMacro(Self, Superclass)
typedef Superclass::SampleType SampleType;
......@@ -96,29 +98,31 @@ protected:
class FeaturesInfo
{
public:
/** Index for class field */
std::vector<int> m_SelectedCFieldIdx;
/** Selected Index */
std::vector<int> m_SelectedIdx;
/** Index for class field */
std::vector<int> m_SelectedCFieldIdx;
/** Selected class field name */
std::string m_SelectedCFieldName;
/** Selected names */
std::vector <std::string> m_SelectedNames;
unsigned int m_NbFeatures;
FeaturesInfo(std::vector <std::string> fieldNames, std::vector <std::string> cFieldNames,
std::vector<int> selectedIdx, std::vector<int> selectedCFieldIdx)
: m_SelectedIdx( selectedIdx ), m_SelectedCFieldIdx( selectedCFieldIdx )
/**
 * Store the selected feature indices, their count, and the matching names.
 * \param fieldNames every available feature field name
 * \param selectedIdx indices into fieldNames of the selected features
 */
void SetFieldNames(std::vector <std::string> fieldNames, std::vector<int> selectedIdx)
{
  m_SelectedIdx = selectedIdx;
  m_NbFeatures = static_cast<unsigned int>( m_SelectedIdx.size() );

  // Resolve each selected index to its field name, preserving selection order.
  std::vector<std::string> resolvedNames;
  resolvedNames.reserve( m_NbFeatures );
  for( std::vector<int>::const_iterator idxIt = m_SelectedIdx.begin(); idxIt != m_SelectedIdx.end(); ++idxIt )
    {
    resolvedNames.push_back( fieldNames[*idxIt] );
    }
  m_SelectedNames.swap( resolvedNames );
}
/**
 * Store the selected class-field indices and the name of the class field.
 * An empty selection is accepted: the class field name is then cleared.
 * \param cFieldNames every available class field name
 * \param selectedCFieldIdx indices into cFieldNames of the selected class fields
 */
void SetClassFieldNames(std::vector<std::string> cFieldNames, std::vector<int> selectedCFieldIdx)
{
  m_SelectedCFieldIdx = selectedCFieldIdx;

  // Calling front() on an empty vector is undefined behaviour, so an empty
  // selection is handled explicitly and leaves the class field name empty.
  if( selectedCFieldIdx.empty() )
    {
    m_SelectedCFieldName.clear();
    return;
    }

  // Handle only one class field name, if several are provided only the first one is used.
  m_SelectedCFieldName = cFieldNames[selectedCFieldIdx.front()];
}
};
......@@ -126,12 +130,11 @@ protected:
protected:
/**
* Function which extract and store all samples for Training, Classification and Validation.
* Extract and store all the samples used for training and classification.
* \param measurement statistics measurement (mean/stddev)
* \param featuresInfo information about the features
* \return sample list used for training
*/
virtual void ExtractSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo);
virtual void ExtractAllSamples(const StatisticsMeasurement &measurement);
/**
* Extract the training sample list
......@@ -139,60 +142,50 @@ protected:
* \param featuresInfo information about the features
* \return sample list used for training
*/
virtual ListSamples ExtractTrainingListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo);
virtual ListSamples ExtractTrainingListSamples(const StatisticsMeasurement &measurement);
/**
* Extract the validation sample list
* \param measurement statics measurement (mean/stddev)
* \param featuresInfo information about the features
* \return sample list used for validation
*/
virtual ListSamples ExtractValidationListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo);
/**
* Extract the sample list classification
* Extract the sample list used for classification
* \param measurement statistics measurement (mean/stddev)
* \param featuresInfo information about the features
* \return sample list used for classification
*/
virtual ListSamples ExtractClassificationListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo);
virtual ListSamples ExtractClassificationListSamples(const StatisticsMeasurement &measurement);
/** Extract samples from input file for corresponding field name
*
* \param parameterName the name of the input file option in the input application parameters
* \param parameterLayer the name of the layer option in the input application parameters
* \param measurement statics measurement (mean/stddev)
* \param nbFeatures the number of features.
* \return the list of samples and their corresponding labels.
*/
ListSamples
ExtractListSamples(std::string parameterName, std::string parameterLayer, const StatisticsMeasurement &measurement);
/**
* Retrieve statistics mean and standard deviation if input statistics are provided.
* Otherwise mean is set to 0 and standard deviation to 1 for each Features.
* \param nbFeatures
*/
StatisticsMeasurement ComputeStatistics(unsigned int nbFeatures);
ListSamples trainingListSamples;
ListSamples validationListSamples;
ListSamples classificationListSamples;
TargetListSampleType::Pointer predictedList;
FeaturesInfo featuresInfo;
private:
virtual void DoTrainInit() = 0;
virtual void DoTrainExecute() = 0;
virtual void DoBeforeTrainExecute() = 0;
virtual void DoAfterTrainExecute() = 0;
virtual void DoTrainUpdateParameters() = 0;
void DoInit();
void DoUpdateParameters();
void DoExecute();
/** Extract samples from input file for corresponding field name
*
* \param parameterName the name of the input file option in the input application parameters
* \param parameterLayer the name of the layer option in the input application parameters
* \param measurement statics measurement (mean/stddev)
* \param nbFeatures the number of features.
* \return the list of samples and their corresponding labels.
*/
ListSamples ExtractListSamples(std::string parameterName, std::string parameterLayer,
const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo);
ListSamples ExtractClassificationListSamples(ListSamples &validationListSamples, ListSamples &trainingListSamples);
/**
* Retrieve statistics mean and standard deviation if input statistics are provided.
* Otherwise mean is set to 0 and standard deviation to 1 for each Features.
* \param nbFeatures
*/
StatisticsMeasurement ComputeStatistics(unsigned int nbFeatures);
void DoInit() ITK_OVERRIDE;
void DoUpdateParameters() ITK_OVERRIDE;
void DoExecute() ITK_OVERRIDE;
};
......
......@@ -56,7 +56,10 @@ void TrainVectorBase::DoInit()
MandatoryOff( "layer" );
SetDefaultParameterInt( "layer", 0 );
//Can be in both Supervised and Unsupervised ?
AddParameter(ParameterType_ListView, "feat", "Field names for training features.");
SetParameterDescription("feat","List of field names in the input vector data to be used as features for training.");
// Add validation data used to compute confusion matrix or contingency table
AddParameter( ParameterType_Group, "valid", "Validation data" );
SetParameterDescription( "valid", "This group of parameters defines validation data." );
......@@ -70,14 +73,13 @@ void TrainVectorBase::DoInit()
MandatoryOff( "valid.layer" );
SetDefaultParameterInt( "valid.layer", 0 );
AddParameter(ParameterType_ListView, "feat", "Field names for training features.");
SetParameterDescription("feat","List of field names in the input vector data to be used as features for training.");
// Add class field if we used validation
AddParameter(ParameterType_ListView,"cfield","Field containing the class id for supervision");
SetParameterDescription("cfield","Field containing the class id for supervision. "
"Only geometries with this field available will be taken into account.");
SetListViewSingleSelectionMode("cfield",true);
// Add parameters for the classifier choice
Superclass::DoInit();
......@@ -92,7 +94,7 @@ void TrainVectorBase::DoUpdateParameters()
{
std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" );
ogr::DataSource::Pointer ogrDS = ogr::DataSource::New( vectorFileList[0], ogr::DataSource::Modes::Read );
ogr::Layer layer = ogrDS->GetLayer( this->GetParameterInt( "layer" ) );
ogr::Layer layer = ogrDS->GetLayer( static_cast<size_t>( this->GetParameterInt( "layer" ) ) );
ogr::Feature feature = layer.ogr().GetNextFeature();
ClearChoices( "feat" );
......@@ -109,12 +111,12 @@ void TrainVectorBase::DoUpdateParameters()
if( fieldType == OFTInteger || ogr::version_proxy::IsOFTInteger64( fieldType ) || fieldType == OFTReal )
{
std::string tmpKey = "feat." + key.substr( 0, end - key.begin() );
std::string tmpKey = "feat." + key.substr( 0, static_cast<unsigned long>( end - key.begin() ) );
AddChoice( tmpKey, item );
}
if( fieldType == OFTString || fieldType == OFTInteger || ogr::version_proxy::IsOFTInteger64( fieldType ) )
{
std::string tmpKey = "cfield." + key.substr( 0, end - key.begin() );
std::string tmpKey = "cfield." + key.substr( 0, static_cast<unsigned long>( end - key.begin() ) );
AddChoice( tmpKey, item );
}
}
......@@ -125,12 +127,9 @@ void TrainVectorBase::DoUpdateParameters()
void TrainVectorBase::DoExecute()
{
typedef int LabelPixelType;
typedef itk::FixedArray<LabelPixelType, 1> LabelSampleType;
typedef itk::Statistics::ListSample<LabelSampleType> LabelListSampleType;
DoBeforeTrainExecute();
FeaturesInfo featuresInfo( GetChoiceNames( "feat" ), GetChoiceNames( "cfield" ), GetSelectedItems( "feat" ),
GetSelectedItems( "cfield" ) );
featuresInfo.SetFieldNames( GetChoiceNames( "feat" ), GetSelectedItems( "feat" ));
// Check input parameters
if( featuresInfo.m_SelectedIdx.empty() )
......@@ -138,64 +137,34 @@ void TrainVectorBase::DoExecute()
otbAppLogFATAL( << "No features have been selected to train the classifier on!" );
}
// Todo only Log warning and set CFieldName to 0, 1, 2, 3... (default behavior)
if( featuresInfo.m_SelectedCFieldIdx.empty() )
{
otbAppLogFATAL( << "No field has been selected for data labelling!" );
}
StatisticsMeasurement measurement = ComputeStatistics( featuresInfo.m_NbFeatures );
ExtractSamples(measurement, featuresInfo);
ExtractAllSamples( measurement );
this->Train( trainingListSamples.listSample, trainingListSamples.labeledListSample, GetParameterString( "io.out" ) );
predictedList = TargetListSampleType::New();
this->Classify( classificationListSamples.listSample, predictedList, GetParameterString( "io.out" ) );
DoTrainExecute();
DoAfterTrainExecute();
}
void TrainVectorBase::ExtractSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo)
void TrainVectorBase::ExtractAllSamples(const StatisticsMeasurement &measurement)
{
trainingListSamples = ExtractTrainingListSamples(measurement, featuresInfo);
validationListSamples = ExtractValidationListSamples(measurement, featuresInfo);
classificationListSamples = ExtractClassificationListSamples(measurement, featuresInfo);
trainingListSamples = ExtractTrainingListSamples(measurement);
classificationListSamples = ExtractClassificationListSamples(measurement);
}
TrainVectorBase::ListSamples
TrainVectorBase::ExtractTrainingListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo)
TrainVectorBase::ExtractTrainingListSamples(const StatisticsMeasurement &measurement)
{
return ExtractListSamples( "io.vd", "layer", measurement, featuresInfo );
return ExtractListSamples( "io.vd", "layer", measurement);
}
TrainVectorBase::ListSamples
TrainVectorBase::ExtractValidationListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo)
{
return ExtractListSamples( "valid.vd", "valid.layer", measurement, featuresInfo );
}
TrainVectorBase::ListSamples
TrainVectorBase::ExtractClassificationListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo)
TrainVectorBase::ExtractClassificationListSamples(const StatisticsMeasurement &itkNotUsed(measurement))
{
ListSamples performanceSample;
//Test the input validation set size
if( validationListSamples.labeledListSample->Size() != 0 )
{
performanceSample.listSample = validationListSamples.listSample;
performanceSample.labeledListSample = validationListSamples.labeledListSample;
}
else
{
otbAppLogWARNING(
"The validation set is empty. The performance estimation is done using the input training set in this case." );
performanceSample.listSample = trainingListSamples.listSample;
performanceSample.labeledListSample = trainingListSamples.labeledListSample;
}
return performanceSample;
return trainingListSamples;
}
......@@ -224,7 +193,7 @@ TrainVectorBase::ComputeStatistics(unsigned int nbFeatures)
TrainVectorBase::ListSamples
TrainVectorBase::ExtractListSamples(std::string parameterName, std::string parameterLayer,
const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo)
const StatisticsMeasurement &measurement)
{
ListSamples listSamples;
if( HasValue( parameterName ) && IsParameterEnabled( parameterName ) )
......@@ -249,12 +218,15 @@ TrainVectorBase::ExtractListSamples(std::string parameterName, std::string param
}
// Check all needed fields are present :
// - check class field
// - check class field if we use supervised classification or if class field name is not empty
int cFieldIndex = feature.ogr().GetFieldIndex( featuresInfo.m_SelectedCFieldName.c_str() );
if( cFieldIndex < 0 )
if( cFieldIndex < 0 && !featuresInfo.m_SelectedCFieldName.empty())
{
otbAppLogFATAL( "The field name for class label (" << featuresInfo.m_SelectedCFieldName
<< ") has not been found in the vector file "
<< validFileList[k] );
}
// - check feature fields
std::vector<int> featureFieldIndex( featuresInfo.m_NbFeatures, -1 );
for( unsigned int i = 0; i < featuresInfo.m_NbFeatures; i++ )
......@@ -266,18 +238,22 @@ TrainVectorBase::ExtractListSamples(std::string parameterName, std::string param
<< validFileList[k] );
}
while( goesOn )
{
if( feature.ogr().IsFieldSet( cFieldIndex ) )
{
MeasurementType mv;
mv.SetSize( featuresInfo.m_NbFeatures );
for( unsigned int idx = 0; idx < featuresInfo.m_NbFeatures; ++idx )
mv[idx] = feature.ogr().GetFieldAsDouble( featureFieldIndex[idx] );
// Retrieve all the features for each field in the ogr layer.
MeasurementType mv;
mv.SetSize( featuresInfo.m_NbFeatures );
for( unsigned int idx = 0; idx < featuresInfo.m_NbFeatures; ++idx )
mv[idx] = feature.ogr().GetFieldAsDouble( featureFieldIndex[idx] );
input->PushBack( mv );
input->PushBack( mv );
if( feature.ogr().IsFieldSet( cFieldIndex ) )
target->PushBack( feature.ogr().GetFieldAsInteger( cFieldIndex ) );
}
else
target->PushBack( 0 );
feature = layer.ogr().GetNextFeature();
goesOn = feature.addr() != 0;
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment