Skip to content
Snippets Groups Projects
Commit dd29f666 authored by Guillaume Pasero's avatar Guillaume Pasero
Browse files

ENH: new base application for classif and regression

parent 7c130ac4
No related branches found
No related tags found
No related merge requests found
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbLearningApplicationBase_h
#define __otbLearningApplicationBase_h
#include "otbConfigure.h"
#include "otbWrapperApplicationFactory.h"
#include <iostream>
// ListSample
#include "otbListSampleGenerator.h"
#include "itkVariableLengthVector.h"
//Estimator
#include "otbMachineLearningModelFactory.h"
#ifdef OTB_USE_OPENCV
# include "otbKNearestNeighborsMachineLearningModel.h"
# include "otbRandomForestsMachineLearningModel.h"
# include "otbSVMMachineLearningModel.h"
# include "otbBoostMachineLearningModel.h"
# include "otbDecisionTreeMachineLearningModel.h"
# include "otbGradientBoostedTreeMachineLearningModel.h"
# include "otbNormalBayesMachineLearningModel.h"
# include "otbNeuralNetworkMachineLearningModel.h"
#endif
#ifdef OTB_USE_LIBSVM
#include "otbLibSVMMachineLearningModel.h"
#endif
// Statistic XML Reader
#include "otbStatisticsXMLFileReader.h"
#include "itkTimeProbe.h"
#include "otbStandardFilterWatcher.h"
// Normalize the samples
#include "otbShiftScaleSampleListFilter.h"
// List sample concatenation
#include "otbConcatenateSampleListFilter.h"
// Balancing ListSample
#include "otbListSampleToBalancedListSampleFilter.h"
namespace otb
{
namespace Wrapper
{
template <class TInputValue, class TOutputValue>
class LearningApplicationBase: public Application
{
public:
/** Standard class typedefs. */
typedef LearningApplicationBase Self;
typedef Application Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Standard macro */
itkNewMacro(Self)
itkTypeMacro(LearningApplicationBase, otb::Application)
typedef TInputValue InputValueType;
typedef TOutputValue OutputValueType;
typedef otb::VectorImage<InputValueType> SampleImageType;
typedef typename SampleImageType::PixelType PixelType;
// Machine Learning models
typedef otb::MachineLearningModelFactory<
InputValueType, OutputValueType> ModelFactoryType;
typedef typename ModelFactoryType::MachineLearningModelTypePointer ModelPointerType;
typedef typename ModelFactoryType::MachineLearningModel ModelType;
typedef typename ModelType::InputSampleType SampleType;
typedef typename ModelType::InputListSampleType ListSampleType;
typedef typename ModelType::TargetSampleType TargetSampleType;
typedef typename ModelType::TargetListSampleType TargetListSampleType;
// SampleList manipulation
typedef otb::ListSampleGenerator<SampleImageType, VectorDataType> ListSampleGeneratorType;
typedef ListSampleGeneratorType::LabelType LabelType;
typedef ListSampleGeneratorType::ListLabelType LabelListSampleType;
typedef otb::Statistics::ConcatenateSampleListFilter<ListSampleType> ConcatenateListSampleFilterType;
typedef otb::Statistics::ConcatenateSampleListFilter<LabelListSampleType> ConcatenateLabelListSampleFilterType;
// Statistic XML file Reader
typedef otb::StatisticsXMLFileReader<SampleType> StatisticsReader;
// Enhance List Sample typedef otb::Statistics::ListSampleToBalancedListSampleFilter<ListSampleType, LabelListSampleType> BalancingListSampleFilterType;
typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType;
#ifdef OTB_USE_OPENCV
typedef otb::RandomForestsMachineLearningModel<InputValueType, OutputValueType> RandomForestType;
typedef otb::KNearestNeighborsMachineLearningModel<InputValueType, OutputValueType> KNNType;
typedef otb::SVMMachineLearningModel<InputValueType, OutputValueType> SVMType;
typedef otb::BoostMachineLearningModel<InputValueType, OutputValueType> BoostType;
typedef otb::DecisionTreeMachineLearningModel<InputValueType, OutputValueType> DecisionTreeType;
typedef otb::GradientBoostedTreeMachineLearningModel<InputValueType, OutputValueType> GradientBoostedTreeType;
typedef otb::NeuralNetworkMachineLearningModel<InputValueType, OutputValueType> NeuralNetworkType;
typedef otb::NormalBayesMachineLearningModel<InputValueType, OutputValueType> NormalBayesType;
#endif
#ifdef OTB_USE_LIBSVM
typedef otb::LibSVMMachineLearningModel<InputValueType, OutputValueType> LibSVMType;
#endif
protected:
using Superclass::AddParameter;
friend void InitSVMParams(LearningApplicationBase & app);
private:
void DoInit();
#ifdef OTB_USE_LIBSVM
void InitLibSVMParams();
#endif
#ifdef OTB_USE_OPENCV
void InitBoostParams();
void InitSVMParams();
void InitDecisionTreeParams();
void InitGradientBoostedTreeParams();
void InitNeuralNetworkParams();
void InitNormalBayesParams();
void InitRandomForestsParams();
void InitKNNParams();
#endif
#ifdef OTB_USE_LIBSVM
void TrainLibSVM(ListSampleType::Pointer trainingListSample, TargetListSampleType::Pointer trainingLabeledListSample);
#endif
#ifdef OTB_USE_OPENCV
void TrainBoost(ListSampleType::Pointer trainingListSample, TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath);
void TrainSVM(ListSampleType::Pointer trainingListSample, TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath);
void TrainDecisionTree(ListSampleType::Pointer trainingListSample, TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath);
void TrainGradientBoostedTree(ListSampleType::Pointer trainingListSample, TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath);
void TrainNeuralNetwork(ListSampleType::Pointer trainingListSample, TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath);
void TrainNormalBayes(ListSampleType::Pointer trainingListSample, TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath);
void TrainRandomForests(ListSampleType::Pointer trainingListSample, TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath);
void TrainKNN(ListSampleType::Pointer trainingListSample, TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath);
#endif
void Train(ListSampleType::Pointer trainingListSample, TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath, std::string modelName);
void Classify(ListSampleType::Pointer validationListSample, TargetListSampleType::Pointer predictedList, std::string modelPath);
bool m_RegressionFlag;
};
}
}
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbLearningApplicationBase.txx"
#ifdef OTB_USE_OPENCV
#include "otbTrainBoost.txx"
#include "otbTrainDecisionTree.txx"
#include "otbTrainGradientBoostedTree.txx"
#include "otbTrainKNN.txx"
#include "otbTrainNeuralNetwork.txx"
#include "otbTrainNormalBayes.txx"
#include "otbTrainRandomForests.txx"
#include "otbTrainSVM.cxx"
#endif
#ifdef OTB_USE_LIBSVM
#include "otbTrainLibSVM.txx"
#endif
#endif
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbLearningApplicationBase_txx
#define __otbLearningApplicationBase_txx
#include "otbLearningApplicationBase.h"
namespace otb
{
namespace Wrapper
{
template <class TInputValue, class TOutputValue>
LearningApplicationBase<TInputValue,TOutputValue>
::LearningApplicationBase() : m_RegressionFlag(false)
{
}
template <class TInputValue, class TOutputValue>
void
LearningApplicationBase<TInputValue,TOutputValue>
::DoInit()
{
AddDocTag(Tags::Learning);
AddParameter(ParameterType_Choice, "classifier", "Classifier to use for the training");
SetParameterDescription("classifier", "Choice of the classifier to use for the training.");
//Group LibSVM
#ifdef OTB_USE_LIBSVM
InitLibSVMParams();
#endif
#ifdef OTB_USE_OPENCV
InitSVMParams();
InitBoostParams();
InitDecisionTreeParams();
InitGradientBoostedTreeParams();
InitNeuralNetworkParams();
InitNormalBayesParams();
InitRandomForestsParams();
InitKNNParams();
#endif
AddRANDParameter();
}
template <class TInputValue, class TOutputValue>
void
LearningApplicationBase<TInputValue,TOutputValue>
::Classify(ListSampleType::Pointer validationListSample,
TargetListSampleType::Pointer predictedList,
std::string modelPath)
{
//Classification
ModelPointerType model = ModelFactoryType::CreateMachineLearningModel(modelPath,
ModelFactoryType::ReadMode);
if (model.IsNull())
{
otbAppLogFATAL(<< "Error when loading model " << modelPath);
}
model->Load(modelPath);
model->SetInputListSample(validationListSample);
model->SetTargetListSample(predictedList);
model->PredictAll();
}
template <class TInputValue, class TOutputValue>
void
LearningApplicationBase<TInputValue,TOutputValue>
::Train(ListSampleType::Pointer trainingListSample,
TargetListSampleType::Pointer trainingLabeledListSample,
std::string modelPath,
std::string modelName)
{
if (modelName == "libsvm")
{
#ifdef OTB_USE_LIBSVM
TrainLibSVM(trainingListSample, trainingLabeledListSample, modelPath);
#else
otbAppLogFATAL("Module LIBSVM is not installed. You should consider turning OTB_USE_LIBSVM on during cmake configuration.");
#endif
}
else if (modelName == "svm")
{
#ifdef OTB_USE_OPENCV
TrainSVM(trainingListSample, trainingLabeledListSample, modelPath);
#else
otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
#endif
}
else if (modelName == "boost")
{
#ifdef OTB_USE_OPENCV
TrainBoost(trainingListSample, trainingLabeledListSample, modelPath);
#else
otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
#endif
}
else if (modelName == "dt")
{
#ifdef OTB_USE_OPENCV
TrainDecisionTree(trainingListSample, trainingLabeledListSample, modelPath);
#else
otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
#endif
}
else if (modelName == "gbt")
{
#ifdef OTB_USE_OPENCV
TrainGradientBoostedTree(trainingListSample, trainingLabeledListSample, modelPath);
#else
otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
#endif
}
else if (modelName == "ann")
{
#ifdef OTB_USE_OPENCV
TrainNeuralNetwork(trainingListSample, trainingLabeledListSample, modelPath);
#else
otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
#endif
}
else if (modelName == "bayes")
{
#ifdef OTB_USE_OPENCV
TrainNormalBayes(trainingListSample, trainingLabeledListSample, modelPath);
#else
otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
#endif
}
else if (modelName == "rf")
{
#ifdef OTB_USE_OPENCV
TrainRandomForests(trainingListSample, trainingLabeledListSample, modelPath);
#else
otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
#endif
}
else if (modelName == "knn")
{
#ifdef OTB_USE_OPENCV
TrainKNN(trainingListSample, trainingLabeledListSample, modelPath);
#else
otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
#endif
}
}
}
}
#endif
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment