Commit efac39f3 authored by Guillaume Pasero's avatar Guillaume Pasero

ENH: refactor TrainImagesClassifier to use LearningApplicationBase

parent 86927797
......@@ -14,15 +14,95 @@
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbLearningApplicationBase.h"
#include "otbWrapperApplicationFactory.h"
#include "otbTrainImagesClassifier.h"
#include "otbListSampleGenerator.h"
// Statistic XML Reader
#include "otbStatisticsXMLFileReader.h"
// Validation
#include "otbConfusionMatrixCalculator.h"
#include "itkTimeProbe.h"
#include "otbStandardFilterWatcher.h"
// Normalize the samples
#include "otbShiftScaleSampleListFilter.h"
// List sample concatenation
#include "otbConcatenateSampleListFilter.h"
// Balancing ListSample
#include "otbListSampleToBalancedListSampleFilter.h"
// VectorData projection filter
// Extract a ROI of the vectordata
#include "otbVectorDataIntoImageProjectionFilter.h"
// Elevation handler
#include "otbWrapperElevationParametersHandler.h"
namespace otb
{
namespace Wrapper
{
void TrainImagesClassifier::DoInit()
class TrainImagesClassifier: public LearningApplicationBase<float,int>
{
public:
/** Standard class typedefs. */
typedef TrainImagesClassifier Self;
typedef LearningApplicationBase<float,int> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Standard macro */
itkNewMacro(Self)
itkTypeMacro(TrainImagesClassifier, otb::Wrapper::LearningApplicationBase)
typedef Superclass::SampleType SampleType;
typedef Superclass::ListSampleType ListSampleType;
typedef Superclass::TargetSampleType TargetSampleType;
typedef Superclass::TargetListSampleType TargetListSampleType;
typedef Superclass::SampleImageType SampleImageType;
typedef SampleImageType::PixelType PixelType;
// SampleList manipulation
typedef otb::ListSampleGenerator<SampleImageType, VectorDataType> ListSampleGeneratorType;
typedef otb::Statistics::ConcatenateSampleListFilter<ListSampleType> ConcatenateListSampleFilterType;
typedef otb::Statistics::ConcatenateSampleListFilter<TargetListSampleType> ConcatenateLabelListSampleFilterType;
// Statistic XML file Reader
typedef otb::StatisticsXMLFileReader<SampleType> StatisticsReader;
// Enhance List Sample typedef otb::Statistics::ListSampleToBalancedListSampleFilter<ListSampleType, LabelListSampleType> BalancingListSampleFilterType;
typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType;
// Estimate performance on validation sample
typedef otb::ConfusionMatrixCalculator<TargetListSampleType, TargetListSampleType> ConfusionMatrixCalculatorType;
typedef ConfusionMatrixCalculatorType::ConfusionMatrixType ConfusionMatrixType;
typedef ConfusionMatrixCalculatorType::MapOfIndicesType MapOfIndicesType;
typedef ConfusionMatrixCalculatorType::ClassLabelType ClassLabelType;
// VectorData projection filter
typedef otb::VectorDataProjectionFilter<VectorDataType, VectorDataType> VectorDataProjectionFilterType;
// Extract ROI
typedef otb::VectorDataIntoImageProjectionFilter<VectorDataType, SampleImageType> VectorDataReprojectionType;
protected:
//using Superclass::AddParameter;
//friend void InitSVMParams(TrainImagesClassifier & app);
private:
void DoInit()
{
SetName("TrainImagesClassifier");
SetDescription(
......@@ -46,8 +126,6 @@ void TrainImagesClassifier::DoInit()
SetDocAuthors("OTB-Team");
SetDocSeeAlso("OpenCV documentation for machine learning http://docs.opencv.org/modules/ml/doc/ml.html ");
AddDocTag(Tags::Learning);
//Group IO
AddParameter(ParameterType_Group, "io", "Input and output data");
SetParameterDescription("io", "This group of parameters allows to set input and output data.");
......@@ -101,24 +179,7 @@ void TrainImagesClassifier::DoInit()
SetParameterDescription("sample.vfn", "Name of the field used to discriminate class labels in the input vector data files.");
SetParameterString("sample.vfn", "Class");
AddParameter(ParameterType_Choice, "classifier", "Classifier to use for the training");
SetParameterDescription("classifier", "Choice of the classifier to use for the training.");
//Group LibSVM
#ifdef OTB_USE_LIBSVM
InitLibSVMParams();
#endif
#ifdef OTB_USE_OPENCV
InitSVMParams();
InitBoostParams();
InitDecisionTreeParams();
InitGradientBoostedTreeParams();
InitNeuralNetworkParams();
InitNormalBayesParams();
InitRandomForestsParams();
InitKNNParams();
#endif
Superclass::DoInit();
AddRANDParameter();
// Doc example parameter settings
......@@ -136,15 +197,14 @@ void TrainImagesClassifier::DoInit()
SetDocExampleParameterValue("classifier.libsvm.opt", "false");
SetDocExampleParameterValue("io.out", "svmModelQB1.txt");
SetDocExampleParameterValue("io.confmatout", "svmConfusionMatrixQB1.csv");
}
}
void TrainImagesClassifier::DoUpdateParameters()
void DoUpdateParameters()
{
// Nothing to do here : all parameters are independent
}
void TrainImagesClassifier::LogConfusionMatrix(ConfusionMatrixCalculatorType* confMatCalc)
void LogConfusionMatrix(ConfusionMatrixCalculatorType* confMatCalc)
{
ConfusionMatrixCalculatorType::ConfusionMatrixType matrix = confMatCalc->GetConfusionMatrix();
......@@ -215,24 +275,7 @@ void TrainImagesClassifier::LogConfusionMatrix(ConfusionMatrixCalculatorType* co
otbAppLogINFO("Confusion matrix (rows = reference labels, columns = produced labels):\n" << os.str());
}
void TrainImagesClassifier::Classify(ListSampleType::Pointer validationListSample, LabelListSampleType::Pointer predictedList)
{
//Classification
ModelPointerType model = MachineLearningModelFactoryType::CreateMachineLearningModel(GetParameterString("io.out"),
MachineLearningModelFactoryType::ReadMode);
if (model.IsNull())
{
otbAppLogFATAL(<< "Error when loading model " << GetParameterString("io.out"));
}
model->Load(GetParameterString("io.out"));
model->SetInputListSample(validationListSample);
model->SetTargetListSample(predictedList);
model->PredictAll();
}
void TrainImagesClassifier::DoExecute()
void DoExecute()
{
GetLogger()->Debug("Entering DoExecute\n");
//Create training and validation for list samples and label list samples
......@@ -243,8 +286,8 @@ void TrainImagesClassifier::DoExecute()
ConcatenateLabelListSampleFilterType::New();
ConcatenateListSampleFilterType::Pointer concatenateValidationSamples = ConcatenateListSampleFilterType::New();
MeasurementType meanMeasurementVector;
MeasurementType stddevMeasurementVector;
SampleType meanMeasurementVector;
SampleType stddevMeasurementVector;
//--------------------------
// Load measurements from images
......@@ -358,7 +401,7 @@ void TrainImagesClassifier::DoExecute()
}
ListSampleType::Pointer listSample;
LabelListSampleType::Pointer labelListSample;
TargetListSampleType::Pointer labelListSample;
//--------------------------
// Balancing training sample (if needed)
// if (IsParameterEnabled("sample.b"))
......@@ -384,9 +427,9 @@ void TrainImagesClassifier::DoExecute()
//--------------------------
// Split the data set into training/validation set
ListSampleType::Pointer trainingListSample = listSample;
LabelListSampleType::Pointer trainingLabeledListSample = labelListSample;
TargetListSampleType::Pointer trainingLabeledListSample = labelListSample;
LabelListSampleType::Pointer validationLabeledListSample = concatenateValidationLabels->GetOutput();
TargetListSampleType::Pointer validationLabeledListSample = concatenateValidationLabels->GetOutput();
otbAppLogINFO("Size of training set: " << trainingListSample->Size());
otbAppLogINFO("Size of validation set: " << validationListSample->Size());
otbAppLogINFO("Size of labeled training set: " << trainingLabeledListSample->Size());
......@@ -395,88 +438,14 @@ void TrainImagesClassifier::DoExecute()
//--------------------------
// Estimate model
//--------------------------
LabelListSampleType::Pointer predictedList = LabelListSampleType::New();
const std::string classifierType = GetParameterString("classifier");
if (classifierType == "libsvm")
{
#ifdef OTB_USE_LIBSVM
TrainLibSVM(trainingListSample, trainingLabeledListSample);
#else
otbAppLogFATAL("Module LIBSVM is not installed. You should consider turning OTB_USE_LIBSVM on during cmake configuration.");
#endif
}
else if (classifierType == "svm")
{
#ifdef OTB_USE_OPENCV
TrainSVM(trainingListSample, trainingLabeledListSample);
#else
otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
#endif
}
else if (classifierType == "boost")
{
#ifdef OTB_USE_OPENCV
TrainBoost(trainingListSample, trainingLabeledListSample);
#else
otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
#endif
}
else if (classifierType == "dt")
{
#ifdef OTB_USE_OPENCV
TrainDecisionTree(trainingListSample, trainingLabeledListSample);
#else
otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
#endif
}
else if (classifierType == "gbt")
{
#ifdef OTB_USE_OPENCV
TrainGradientBoostedTree(trainingListSample, trainingLabeledListSample);
#else
otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
#endif
}
else if (classifierType == "ann")
{
#ifdef OTB_USE_OPENCV
TrainNeuralNetwork(trainingListSample, trainingLabeledListSample);
#else
otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
#endif
}
else if (classifierType == "bayes")
{
#ifdef OTB_USE_OPENCV
TrainNormalBayes(trainingListSample, trainingLabeledListSample);
#else
otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
#endif
}
else if (classifierType == "rf")
{
#ifdef OTB_USE_OPENCV
TrainRandomForests(trainingListSample, trainingLabeledListSample);
#else
otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
#endif
}
else if (classifierType == "knn")
{
#ifdef OTB_USE_OPENCV
TrainKNN(trainingListSample, trainingLabeledListSample);
#else
otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
#endif
}
this->Train(trainingListSample,trainingLabeledListSample,GetParameterString("io.out"));
//--------------------------
// Performances estimation
//--------------------------
TargetListSampleType::Pointer predictedList = TargetListSampleType::New();
ListSampleType::Pointer performanceListSample=ListSampleType::New();
LabelListSampleType::Pointer performanceLabeledListSample=LabelListSampleType::New();
TargetListSampleType::Pointer performanceLabeledListSample=TargetListSampleType::New();
//Test the input validation set size
if(validationLabeledListSample->Size() != 0)
......@@ -491,7 +460,7 @@ void TrainImagesClassifier::DoExecute()
performanceLabeledListSample = trainingLabeledListSample;
}
Classify(performanceListSample, predictedList);
this->Classify(performanceListSample, predictedList, GetParameterString("io.out"));
ConfusionMatrixCalculatorType::Pointer confMatCalc = ConfusionMatrixCalculatorType::New();
......@@ -605,11 +574,12 @@ void TrainImagesClassifier::DoExecute()
} // END if (this->HasValue("io.confmatout"))
// TODO: implement hyperplane distance classifier and performance validation (cf. object detection) ?
}
VectorDataReprojectionType::Pointer vdreproj;
};
}
}
} // end namespace Wrapper
} // end namespace otb
OTB_APPLICATION_EXPORT(otb::Wrapper::TrainImagesClassifier)
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbConfigure.h"
#include "otbWrapperApplicationFactory.h"
#include <iostream>
//Image
#include "otbVectorImage.h"
#include "otbVectorData.h"
#include "otbListSampleGenerator.h"
// ListSample
#include "itkVariableLengthVector.h"
//Estimator
#include "otbMachineLearningModelFactory.h"
#ifdef OTB_USE_OPENCV
# include "otbKNearestNeighborsMachineLearningModel.h"
# include "otbRandomForestsMachineLearningModel.h"
# include "otbSVMMachineLearningModel.h"
# include "otbBoostMachineLearningModel.h"
# include "otbDecisionTreeMachineLearningModel.h"
# include "otbGradientBoostedTreeMachineLearningModel.h"
# include "otbNormalBayesMachineLearningModel.h"
# include "otbNeuralNetworkMachineLearningModel.h"
#endif
#ifdef OTB_USE_LIBSVM
#include "otbLibSVMMachineLearningModel.h"
#endif
// Statistic XML Reader
#include "otbStatisticsXMLFileReader.h"
// Validation
#include "otbConfusionMatrixCalculator.h"
#include "itkTimeProbe.h"
#include "otbStandardFilterWatcher.h"
// Normalize the samples
#include "otbShiftScaleSampleListFilter.h"
// List sample concatenation
#include "otbConcatenateSampleListFilter.h"
// Balancing ListSample
#include "otbListSampleToBalancedListSampleFilter.h"
// VectorData projection filter
// Extract a ROI of the vectordata
#include "otbVectorDataIntoImageProjectionFilter.h"
// Elevation handler
#include "otbWrapperElevationParametersHandler.h"
namespace otb
{
namespace Wrapper
{
class TrainImagesClassifier: public Application
{
public:
/** Standard class typedefs. */
typedef TrainImagesClassifier Self;
typedef Application Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Standard macro */
itkNewMacro(Self)
itkTypeMacro(TrainImagesClassifier, otb::Application)
typedef FloatVectorImageType::PixelType PixelType;
typedef FloatVectorImageType::InternalPixelType InternalPixelType;
// Training vectordata
typedef itk::VariableLengthVector<InternalPixelType> MeasurementType;
// SampleList manipulation
typedef otb::ListSampleGenerator<FloatVectorImageType, VectorDataType> ListSampleGeneratorType;
typedef ListSampleGeneratorType::ListSampleType ListSampleType;
typedef ListSampleGeneratorType::LabelType LabelType;
typedef ListSampleGeneratorType::ListLabelType LabelListSampleType;
typedef otb::Statistics::ConcatenateSampleListFilter<ListSampleType> ConcatenateListSampleFilterType;
typedef otb::Statistics::ConcatenateSampleListFilter<LabelListSampleType> ConcatenateLabelListSampleFilterType;
// Statistic XML file Reader
typedef otb::StatisticsXMLFileReader<MeasurementType> StatisticsReader;
// Enhance List Sample typedef otb::Statistics::ListSampleToBalancedListSampleFilter<ListSampleType, LabelListSampleType> BalancingListSampleFilterType;
typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType;
// Machine Learning models
typedef otb::MachineLearningModelFactory<InternalPixelType, ListSampleGeneratorType::ClassLabelType> MachineLearningModelFactoryType;
typedef MachineLearningModelFactoryType::MachineLearningModelTypePointer ModelPointerType;
#ifdef OTB_USE_OPENCV
typedef otb::RandomForestsMachineLearningModel<InternalPixelType, ListSampleGeneratorType::ClassLabelType> RandomForestType;
typedef otb::KNearestNeighborsMachineLearningModel<InternalPixelType, ListSampleGeneratorType::ClassLabelType> KNNType;
typedef otb::SVMMachineLearningModel<InternalPixelType, ListSampleGeneratorType::ClassLabelType> SVMType;
typedef otb::BoostMachineLearningModel<InternalPixelType, ListSampleGeneratorType::ClassLabelType> BoostType;
typedef otb::DecisionTreeMachineLearningModel<InternalPixelType, ListSampleGeneratorType::ClassLabelType> DecisionTreeType;
typedef otb::GradientBoostedTreeMachineLearningModel<InternalPixelType, ListSampleGeneratorType::ClassLabelType> GradientBoostedTreeType;
typedef otb::NeuralNetworkMachineLearningModel<InternalPixelType, ListSampleGeneratorType::ClassLabelType> NeuralNetworkType;
typedef otb::NormalBayesMachineLearningModel<InternalPixelType, ListSampleGeneratorType::ClassLabelType> NormalBayesType;
#endif
#ifdef OTB_USE_LIBSVM
typedef otb::LibSVMMachineLearningModel<InternalPixelType, ListSampleGeneratorType::ClassLabelType> LibSVMType;
#endif
// Estimate performance on validation sample
typedef otb::ConfusionMatrixCalculator<LabelListSampleType, LabelListSampleType> ConfusionMatrixCalculatorType;
typedef ConfusionMatrixCalculatorType::ConfusionMatrixType ConfusionMatrixType;
typedef ConfusionMatrixCalculatorType::MapOfIndicesType MapOfIndicesType;
typedef ConfusionMatrixCalculatorType::ClassLabelType ClassLabelType;
// VectorData projection filter
typedef otb::VectorDataProjectionFilter<VectorDataType, VectorDataType> VectorDataProjectionFilterType;
// Extract ROI
typedef otb::VectorDataIntoImageProjectionFilter<VectorDataType, FloatVectorImageType> VectorDataReprojectionType;
protected:
using Superclass::AddParameter;
friend void InitSVMParams(TrainImagesClassifier & app);
private:
void DoInit();
void DoUpdateParameters();
void LogConfusionMatrix(ConfusionMatrixCalculatorType* confMatCalc);
#ifdef OTB_USE_LIBSVM
void InitLibSVMParams();
#endif
#ifdef OTB_USE_OPENCV
void InitBoostParams();
void InitSVMParams();
void InitDecisionTreeParams();
void InitGradientBoostedTreeParams();
void InitNeuralNetworkParams();
void InitNormalBayesParams();
void InitRandomForestsParams();
void InitKNNParams();
#endif
#ifdef OTB_USE_LIBSVM
void TrainLibSVM(ListSampleType::Pointer trainingListSample, LabelListSampleType::Pointer trainingLabeledListSample);
#endif
#ifdef OTB_USE_OPENCV
void TrainBoost(ListSampleType::Pointer trainingListSample, LabelListSampleType::Pointer trainingLabeledListSample);
void TrainSVM(ListSampleType::Pointer trainingListSample, LabelListSampleType::Pointer trainingLabeledListSample);
void TrainDecisionTree(ListSampleType::Pointer trainingListSample, LabelListSampleType::Pointer trainingLabeledListSample);
void TrainGradientBoostedTree(ListSampleType::Pointer trainingListSample, LabelListSampleType::Pointer trainingLabeledListSample);
void TrainNeuralNetwork(ListSampleType::Pointer trainingListSample, LabelListSampleType::Pointer trainingLabeledListSample);
void TrainNormalBayes(ListSampleType::Pointer trainingListSample, LabelListSampleType::Pointer trainingLabeledListSample);
void TrainRandomForests(ListSampleType::Pointer trainingListSample, LabelListSampleType::Pointer trainingLabeledListSample);
void TrainKNN(ListSampleType::Pointer trainingListSample, LabelListSampleType::Pointer trainingLabeledListSample);
#endif
void Classify(ListSampleType::Pointer validationListSample, LabelListSampleType::Pointer predictedList);
void DoExecute();
VectorDataReprojectionType::Pointer vdreproj;
};
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment