diff --git a/Modules/Applications/AppClassification/include/otbLearningApplicationBase.h b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.h new file mode 100644 index 0000000000000000000000000000000000000000..d675d6a674b8e4187cf405a085bab892d847aafd --- /dev/null +++ b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.h @@ -0,0 +1,196 @@ +/*========================================================================= + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + + =========================================================================*/ +#ifndef __otbLearningApplicationBase_h +#define __otbLearningApplicationBase_h + +#include "otbConfigure.h" + +#include "otbWrapperApplicationFactory.h" + +#include <iostream> + +// ListSample +#include "otbListSampleGenerator.h" +#include "itkVariableLengthVector.h" + +//Estimator +#include "otbMachineLearningModelFactory.h" + +#ifdef OTB_USE_OPENCV +# include "otbKNearestNeighborsMachineLearningModel.h" +# include "otbRandomForestsMachineLearningModel.h" +# include "otbSVMMachineLearningModel.h" +# include "otbBoostMachineLearningModel.h" +# include "otbDecisionTreeMachineLearningModel.h" +# include "otbGradientBoostedTreeMachineLearningModel.h" +# include "otbNormalBayesMachineLearningModel.h" +# include "otbNeuralNetworkMachineLearningModel.h" +#endif + +#ifdef OTB_USE_LIBSVM +#include "otbLibSVMMachineLearningModel.h" +#endif + +// Statistic XML Reader +#include "otbStatisticsXMLFileReader.h" + +#include "itkTimeProbe.h" +#include "otbStandardFilterWatcher.h" + +// Normalize the samples +#include "otbShiftScaleSampleListFilter.h" + +// List sample concatenation +#include "otbConcatenateSampleListFilter.h" + +// Balancing ListSample +#include "otbListSampleToBalancedListSampleFilter.h" + +namespace otb +{ +namespace Wrapper +{ + +template <class TInputValue, class TOutputValue> +class LearningApplicationBase: public Application +{ +public: + /** Standard class typedefs. */ + typedef LearningApplicationBase Self; + typedef Application Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + + /** Standard macro */ + itkNewMacro(Self) + + itkTypeMacro(LearningApplicationBase, otb::Application) + + typedef TInputValue InputValueType; + typedef TOutputValue OutputValueType; + + typedef otb::VectorImage<InputValueType> SampleImageType; + typedef typename SampleImageType::PixelType PixelType; + + // Machine Learning models + typedef otb::MachineLearningModelFactory< + InputValueType, OutputValueType> ModelFactoryType; + typedef typename ModelFactoryType::MachineLearningModelTypePointer ModelPointerType; + typedef typename ModelFactoryType::MachineLearningModel ModelType; + + typedef typename ModelType::InputSampleType SampleType; + typedef typename ModelType::InputListSampleType ListSampleType; + + typedef typename ModelType::TargetSampleType TargetSampleType; + typedef typename ModelType::TargetListSampleType TargetListSampleType; + + // SampleList manipulation + typedef otb::ListSampleGenerator<SampleImageType, VectorDataType> ListSampleGeneratorType; + + typedef ListSampleGeneratorType::LabelType LabelType; + typedef ListSampleGeneratorType::ListLabelType LabelListSampleType; + typedef otb::Statistics::ConcatenateSampleListFilter<ListSampleType> ConcatenateListSampleFilterType; + typedef otb::Statistics::ConcatenateSampleListFilter<LabelListSampleType> ConcatenateLabelListSampleFilterType; + + // Statistic XML file Reader + typedef otb::StatisticsXMLFileReader<SampleType> StatisticsReader; + + // Enhance List Sample typedef otb::Statistics::ListSampleToBalancedListSampleFilter<ListSampleType, LabelListSampleType> BalancingListSampleFilterType; + typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType; + + +#ifdef OTB_USE_OPENCV + typedef otb::RandomForestsMachineLearningModel<InputValueType, OutputValueType> RandomForestType; + typedef otb::KNearestNeighborsMachineLearningModel<InputValueType, OutputValueType> KNNType; + typedef otb::SVMMachineLearningModel<InputValueType, OutputValueType> SVMType; + typedef otb::BoostMachineLearningModel<InputValueType, OutputValueType> BoostType; + typedef otb::DecisionTreeMachineLearningModel<InputValueType, OutputValueType> DecisionTreeType; + typedef otb::GradientBoostedTreeMachineLearningModel<InputValueType, OutputValueType> GradientBoostedTreeType; + typedef otb::NeuralNetworkMachineLearningModel<InputValueType, OutputValueType> NeuralNetworkType; + typedef otb::NormalBayesMachineLearningModel<InputValueType, OutputValueType> NormalBayesType; +#endif + +#ifdef OTB_USE_LIBSVM + typedef otb::LibSVMMachineLearningModel<InputValueType, OutputValueType> LibSVMType; +#endif + +protected: + using Superclass::AddParameter; + friend void InitSVMParams(LearningApplicationBase & app); + +private: + void DoInit(); + +#ifdef OTB_USE_LIBSVM + void InitLibSVMParams(); +#endif + +#ifdef OTB_USE_OPENCV + void InitBoostParams(); + void InitSVMParams(); + void InitDecisionTreeParams(); + void InitGradientBoostedTreeParams(); + void InitNeuralNetworkParams(); + void InitNormalBayesParams(); + void InitRandomForestsParams(); + void InitKNNParams(); +#endif + +#ifdef OTB_USE_LIBSVM + void TrainLibSVM(ListSampleType::Pointer trainingListSample, TargetListSampleType::Pointer trainingLabeledListSample); +#endif + +#ifdef OTB_USE_OPENCV + void TrainBoost(ListSampleType::Pointer trainingListSample, TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath); + void TrainSVM(ListSampleType::Pointer trainingListSample, TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath); + void TrainDecisionTree(ListSampleType::Pointer trainingListSample, TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath); + void TrainGradientBoostedTree(ListSampleType::Pointer trainingListSample, TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath); + void TrainNeuralNetwork(ListSampleType::Pointer trainingListSample, TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath); + void TrainNormalBayes(ListSampleType::Pointer trainingListSample, TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath); + void TrainRandomForests(ListSampleType::Pointer trainingListSample, TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath); + void TrainKNN(ListSampleType::Pointer trainingListSample, TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath); +#endif + + void Train(ListSampleType::Pointer trainingListSample, TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath, std::string modelName); + + void Classify(ListSampleType::Pointer validationListSample, TargetListSampleType::Pointer predictedList, std::string modelPath); + + bool m_RegressionFlag; + +}; + +} +} + +#ifndef OTB_MANUAL_INSTANTIATION +#include "otbLearningApplicationBase.txx" +#ifdef OTB_USE_OPENCV +#include "otbTrainBoost.txx" +#include "otbTrainDecisionTree.txx" +#include "otbTrainGradientBoostedTree.txx" +#include "otbTrainKNN.txx" +#include "otbTrainNeuralNetwork.txx" +#include "otbTrainNormalBayes.txx" +#include "otbTrainRandomForests.txx" +#include "otbTrainSVM.cxx" +#endif +#ifdef OTB_USE_LIBSVM +#include "otbTrainLibSVM.txx" +#endif +#endif + +#endif diff --git a/Modules/Applications/AppClassification/include/otbLearningApplicationBase.txx b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.txx new file mode 100644 index 0000000000000000000000000000000000000000..062281edc31b5c755f914d8fddc8110d8536d886 --- /dev/null +++ b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.txx @@ -0,0 +1,169 @@ +/*========================================================================= + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + + =========================================================================*/ +#ifndef __otbLearningApplicationBase_txx +#define __otbLearningApplicationBase_txx + +#include "otbLearningApplicationBase.h" + +namespace otb +{ +namespace Wrapper +{ + +template <class TInputValue, class TOutputValue> +LearningApplicationBase<TInputValue,TOutputValue> +::LearningApplicationBase() : m_RegressionFlag(false) +{ +} + +template <class TInputValue, class TOutputValue> +void +LearningApplicationBase<TInputValue,TOutputValue> +::DoInit() +{ + AddDocTag(Tags::Learning); + + AddParameter(ParameterType_Choice, "classifier", "Classifier to use for the training"); + SetParameterDescription("classifier", "Choice of the classifier to use for the training."); + + //Group LibSVM +#ifdef OTB_USE_LIBSVM + InitLibSVMParams(); +#endif + +#ifdef OTB_USE_OPENCV + InitSVMParams(); + InitBoostParams(); + InitDecisionTreeParams(); + InitGradientBoostedTreeParams(); + InitNeuralNetworkParams(); + InitNormalBayesParams(); + InitRandomForestsParams(); + InitKNNParams(); +#endif + + AddRANDParameter(); +} + +template <class TInputValue, class TOutputValue> +void +LearningApplicationBase<TInputValue,TOutputValue> +::Classify(ListSampleType::Pointer validationListSample, + TargetListSampleType::Pointer predictedList, + std::string modelPath) +{ + //Classification + ModelPointerType model = ModelFactoryType::CreateMachineLearningModel(modelPath, + ModelFactoryType::ReadMode); + + if (model.IsNull()) + { + otbAppLogFATAL(<< "Error when loading model " << modelPath); + } + + model->Load(modelPath); + model->SetInputListSample(validationListSample); + model->SetTargetListSample(predictedList); + model->PredictAll(); +} + +template <class TInputValue, class TOutputValue> +void +LearningApplicationBase<TInputValue,TOutputValue> +::Train(ListSampleType::Pointer trainingListSample, + TargetListSampleType::Pointer trainingLabeledListSample, + std::string modelPath, + std::string modelName) +{ + if (modelName == "libsvm") + { + #ifdef OTB_USE_LIBSVM + TrainLibSVM(trainingListSample, trainingLabeledListSample, modelPath); + #else + otbAppLogFATAL("Module LIBSVM is not installed. You should consider turning OTB_USE_LIBSVM on during cmake configuration."); + #endif + } + else if (modelName == "svm") + { + #ifdef OTB_USE_OPENCV + TrainSVM(trainingListSample, trainingLabeledListSample, modelPath); + #else + otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration."); + #endif + } + else if (modelName == "boost") + { + #ifdef OTB_USE_OPENCV + TrainBoost(trainingListSample, trainingLabeledListSample, modelPath); + #else + otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration."); + #endif + } + else if (modelName == "dt") + { + #ifdef OTB_USE_OPENCV + TrainDecisionTree(trainingListSample, trainingLabeledListSample, modelPath); + #else + otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration."); + #endif + } + else if (modelName == "gbt") + { + #ifdef OTB_USE_OPENCV + TrainGradientBoostedTree(trainingListSample, trainingLabeledListSample, modelPath); + #else + otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration."); + #endif + } + else if (modelName == "ann") + { + #ifdef OTB_USE_OPENCV + TrainNeuralNetwork(trainingListSample, trainingLabeledListSample, modelPath); + #else + otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration."); + #endif + } + else if (modelName == "bayes") + { + #ifdef OTB_USE_OPENCV + TrainNormalBayes(trainingListSample, trainingLabeledListSample, modelPath); + #else + otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration."); + #endif + } + else if (modelName == "rf") + { + #ifdef OTB_USE_OPENCV + TrainRandomForests(trainingListSample, trainingLabeledListSample, modelPath); + #else + otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration."); + #endif + } + else if (modelName == "knn") + { + #ifdef OTB_USE_OPENCV + TrainKNN(trainingListSample, trainingLabeledListSample, modelPath); + #else + otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration."); + #endif + } +} + +} +} + +#endif