diff --git a/Code/UtilitiesAdapters/OpenCV/otbDecisionTreeMachineLearningModel.h b/Code/UtilitiesAdapters/OpenCV/otbDecisionTreeMachineLearningModel.h new file mode 100644 index 0000000000000000000000000000000000000000..5abebfbf4999f37b1517fe83d3c9226a272d98b6 --- /dev/null +++ b/Code/UtilitiesAdapters/OpenCV/otbDecisionTreeMachineLearningModel.h @@ -0,0 +1,208 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef __otbDecisionTreeMachineLearningModel_h +#define __otbDecisionTreeMachineLearningModel_h + +#include "itkLightObject.h" +#include "itkVariableLengthVector.h" +#include "itkFixedArray.h" +#include "itkListSample.h" +#include "otbMachineLearningModel.h" + + +class CvDTree; + +namespace otb +{ +template <class TInputValue, class TTargetValue> +class ITK_EXPORT DecisionTreeMachineLearningModel + : public MachineLearningModel <TInputValue, TTargetValue> +{ +public: + /** Standard class typedefs. */ + typedef DecisionTreeMachineLearningModel Self; + typedef MachineLearningModel<TInputValue, TTargetValue> Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + + // Input related typedefs + typedef TInputValue InputValueType; + typedef itk::VariableLengthVector<InputValueType> InputSampleType; + typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType; + + // Target related typedefs + typedef TTargetValue TargetValueType; + typedef itk::FixedArray<TargetValueType,1> TargetSampleType; + typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType; + + /** Run-time type information (and related methods). */ + itkNewMacro(Self); + itkTypeMacro(DecisionTreeMachineLearningModel, itk::MachineLearningModel); + + /** Setters/Getters to the maximum depth of the tree. + * Default is INT_MAX + * \see http://docs.opencv.org/modules/ml/doc/decision_trees.html#CvDTreeParams::CvDTreeParams%28%29 + */ + itkGetMacro(MaxDepth, int); + itkSetMacro(MaxDepth, int); + + /** Setters/Getters to the minimum number of sample in each node. + * Default is 10 + * \see http://docs.opencv.org/modules/ml/doc/decision_trees.html#CvDTreeParams::CvDTreeParams%28%29 + */ + itkGetMacro(MinSampleCount, int); + itkSetMacro(MinSampleCount, int); + + /** Termination Criteria for regression tree. + * If all absolute differences between an estimated value in a node + * and values of train samples in this node are less than this parameter + * then the node will not be split. + * Default is 0.01 + * \see http://docs.opencv.org/modules/ml/doc/decision_trees.html#CvDTreeParams::CvDTreeParams%28%29 + */ + itkGetMacro(RegressionAccuracy, double); + itkSetMacro(RegressionAccuracy, double); + + /** If true then surrogate splits will be built. + * These splits allow to work with missing data and compute variable importance correctly. + * Default is true + * \see http://docs.opencv.org/modules/ml/doc/decision_trees.html#CvDTreeParams::CvDTreeParams%28%29 + */ + itkGetMacro(UseSurrogates, bool); + itkSetMacro(UseSurrogates, bool); + + /** Cluster possible values of a categorical variable into K \leq max_categories clusters to find + * a suboptimal split. If a discrete variable, on which the training procedure tries to make a split, + * takes more than max_categories values, the precise best subset estimation may take a very long time + * because the algorithm is exponential. Instead, many decision trees engines (including ML) try to find + * sub-optimal split in this case by clustering all the samples into max_categories clusters + * that is some categories are merged together. The clustering is applied only in n>2-class classification problems + * for categorical variables with N > max_categories possible values. In case of regression and 2-class classification + * the optimal split can be found efficiently without employing clustering, thus the parameter is not used in these cases. + * Default is 10 + * \see http://docs.opencv.org/modules/ml/doc/decision_trees.html#CvDTreeParams::CvDTreeParams%28%29 + */ + itkGetMacro(MaxCategories, int); + itkSetMacro(MaxCategories, int); + + /** If cv_folds > 1 then prune a tree with K-fold cross-validation where K is equal to cv_folds. + * Default is 10 + * \see http://docs.opencv.org/modules/ml/doc/decision_trees.html#CvDTreeParams::CvDTreeParams%28%29 + */ + itkGetMacro(CVFolds, int); + itkSetMacro(CVFolds, int); + + /** If true then a pruning will be harsher. This will make a tree more compact and + * more resistant to the training data noise but a bit less accurate. + * Default is true + * \see http://docs.opencv.org/modules/ml/doc/decision_trees.html#CvDTreeParams::CvDTreeParams%28%29 + */ + itkGetMacro(Use1seRule, bool); + itkSetMacro(Use1seRule, bool); + + /** If true then pruned branches are physically removed from the tree. + * Otherwise they are retained and it is possible to get results + * from the original unpruned (or pruned less aggressively) tree by decreasing CvDTree::pruned_tree_idx parameter. + * Default is true + * \see http://docs.opencv.org/modules/ml/doc/decision_trees.html#CvDTreeParams::CvDTreeParams%28%29 + */ + itkGetMacro(TruncatePrunedTree, bool); + itkSetMacro(TruncatePrunedTree, bool); + + + /* The array of a priori class probabilities, sorted by the class label + * value. The parameter can be used to tune the decision tree preferences toward + * a certain class. For example, if you want to detect some rare anomaly + * occurrence, the training base will likely contain much more normal cases than + * anomalies, so a very good classification performance will be achieved just by + * considering every case as normal. To avoid this, the priors can be specified, + * where the anomaly probability is artificially increased (up to 0.5 or even + * greater), so the weight of the misclassified anomalies becomes much bigger, + * and the tree is adjusted properly. You can also think about this parameter as + * weights of prediction categories which determine relative weights that you + * give to misclassification. That is, if the weight of the first category is 1 + * and the weight of the second category is 10, then each mistake in predicting + * the second category is equivalent to making 10 mistakes in predicting the + first category. */ + + std::vector<float> GetPriors() const + { + return m_Priors; + } + + /** Setters/Getters to IsRegression flag + * Default is False + */ + itkGetMacro(IsRegression, bool); + itkSetMacro(IsRegression, bool); + + /** Train the machine learning model */ + virtual void Train(); + + /** Predict values using the model */ + virtual TargetSampleType Predict(const InputSampleType & input) const; + + /** Save the model to file */ + virtual void Save(const std::string & filename, const std::string & name=""); + + /** Load the model from file */ + virtual void Load(const std::string & filename, const std::string & name=""); + + /** Determine the file type. Returns true if this ImageIO can read the + * file specified. */ + virtual bool CanReadFile(const std::string &); + + /** Determine the file type. Returns true if this ImageIO can write the + * file specified. */ + virtual bool CanWriteFile(const std::string &); + +protected: + /** Constructor */ + DecisionTreeMachineLearningModel(); + + /** Destructor */ + virtual ~DecisionTreeMachineLearningModel(); + + /** PrintSelf method */ + void PrintSelf(std::ostream& os, itk::Indent indent) const; + +private: + DecisionTreeMachineLearningModel(const Self &); //purposely not implemented + void operator =(const Self&); //purposely not implemented + + CvDTree * m_DTreeModel; + + int m_MaxDepth; + int m_MinSampleCount; + double m_RegressionAccuracy; + bool m_UseSurrogates; + int m_MaxCategories; + int m_CVFolds; + bool m_Use1seRule; + bool m_IsRegression; + bool m_TruncatePrunedTree; + std::vector<float> m_Priors; + +}; +} // end namespace otb + +#ifndef OTB_MANUAL_INSTANTIATION +#include "otbDecisionTreeMachineLearningModel.txx" +#endif + +#endif diff --git a/Code/UtilitiesAdapters/OpenCV/otbDecisionTreeMachineLearningModel.txx b/Code/UtilitiesAdapters/OpenCV/otbDecisionTreeMachineLearningModel.txx new file mode 100644 index 0000000000000000000000000000000000000000..4b2af6c6c4cd22af115224e2dfd571443f91662c --- /dev/null +++ b/Code/UtilitiesAdapters/OpenCV/otbDecisionTreeMachineLearningModel.txx @@ -0,0 +1,166 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef __otbDecisionTreeMachineLearningModel_txx +#define __otbDecisionTreeMachineLearningModel_txx + +#include "otbDecisionTreeMachineLearningModel.h" +#include "otbOpenCVUtils.h" + +#include <opencv2/opencv.hpp> + +namespace otb +{ + +template <class TInputValue, class TOutputValue> +DecisionTreeMachineLearningModel<TInputValue,TOutputValue> +::DecisionTreeMachineLearningModel() : + m_MaxDepth(INT_MAX), m_MinSampleCount(10), m_RegressionAccuracy(0.01), + m_UseSurrogates(true), m_MaxCategories(10), m_CVFolds(10), + m_Use1seRule(true), m_IsRegression(false), m_TruncatePrunedTree(true) +{ + m_DTreeModel = new CvDTree; +} + + +template <class TInputValue, class TOutputValue> +DecisionTreeMachineLearningModel<TInputValue,TOutputValue> +::~DecisionTreeMachineLearningModel() +{ + delete m_DTreeModel; +} + +/** Train the machine learning model */ +template <class TInputValue, class TOutputValue> +void +DecisionTreeMachineLearningModel<TInputValue,TOutputValue> +::Train() +{ + //convert listsample to opencv matrix + cv::Mat samples; + otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples); + + cv::Mat labels; + otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(),labels); + + float * priors = m_Priors.empty() ? 0 : &m_Priors.front(); + + CvDTreeParams params = CvDTreeParams(m_MaxDepth, m_MinSampleCount, m_RegressionAccuracy, + m_UseSurrogates, m_MaxCategories, m_CVFolds, m_Use1seRule, m_TruncatePrunedTree, priors); + + //train the Decision Tree model + cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U ); + var_type.setTo(cv::Scalar(CV_VAR_NUMERICAL) ); // all inputs are numerical + + if (!m_IsRegression) //Classification + var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_CATEGORICAL; + + m_DTreeModel->train(samples,CV_ROW_SAMPLE,labels,cv::Mat(),cv::Mat(),var_type,cv::Mat(),params); +} + +template <class TInputValue, class TOutputValue> +typename DecisionTreeMachineLearningModel<TInputValue,TOutputValue> +::TargetSampleType +DecisionTreeMachineLearningModel<TInputValue,TOutputValue> +::Predict(const InputSampleType & input) const +{ + //convert listsample to Mat + cv::Mat sample; + + otb::SampleToMat<InputSampleType>(input,sample); + + double result = m_DTreeModel->predict(sample, cv::Mat(), false)->value; + + TargetSampleType target; + + target[0] = static_cast<TOutputValue>(result); + + return target; +} + +template <class TInputValue, class TOutputValue> +void +DecisionTreeMachineLearningModel<TInputValue,TOutputValue> +::Save(const std::string & filename, const std::string & name) +{ + if (name == "") + m_DTreeModel->save(filename.c_str(), 0); + else + m_DTreeModel->save(filename.c_str(), name.c_str()); +} + +template <class TInputValue, class TOutputValue> +void +DecisionTreeMachineLearningModel<TInputValue,TOutputValue> +::Load(const std::string & filename, const std::string & name) +{ + if (name == "") + m_DTreeModel->load(filename.c_str(), 0); + else + m_DTreeModel->load(filename.c_str(), name.c_str()); +} + +template <class TInputValue, class TOutputValue> +bool +DecisionTreeMachineLearningModel<TInputValue,TOutputValue> +::CanReadFile(const std::string & file) +{ + std::ifstream ifs; + ifs.open(file.c_str()); + + if(!ifs) + { + std::cerr<<"Could not read file "<<file<<std::endl; + return false; + } + + while (!ifs.eof()) + { + std::string line; + std::getline(ifs, line); + + //if (line.find(m_SVMModel->getName()) != std::string::npos) + if (line.find(CV_TYPE_NAME_ML_TREE) != std::string::npos) + { + std::cout<<"Reading a "<<CV_TYPE_NAME_ML_TREE<<" model !!!"<<std::endl; + return true; + } + } + ifs.close(); + return false; +} + +template <class TInputValue, class TOutputValue> +bool +DecisionTreeMachineLearningModel<TInputValue,TOutputValue> +::CanWriteFile(const std::string & file) +{ + return false; +} + +template <class TInputValue, class TOutputValue> +void +DecisionTreeMachineLearningModel<TInputValue,TOutputValue> +::PrintSelf(std::ostream& os, itk::Indent indent) const +{ + // Call superclass implementation + Superclass::PrintSelf(os,indent); +} + +} //end namespace otb + +#endif diff --git a/Code/UtilitiesAdapters/OpenCV/otbDecisionTreeMachineLearningModelFactory.h b/Code/UtilitiesAdapters/OpenCV/otbDecisionTreeMachineLearningModelFactory.h new file mode 100644 index 0000000000000000000000000000000000000000..c04b7096c856bd97d7a9728b91d2acb050e56356 --- /dev/null +++ b/Code/UtilitiesAdapters/OpenCV/otbDecisionTreeMachineLearningModelFactory.h @@ -0,0 +1,72 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef __otbDecisionTreeMachineLearningModelFactory_h +#define __otbDecisionTreeMachineLearningModelFactory_h + +#include "itkObjectFactoryBase.h" +#include "itkImageIOBase.h" + +namespace otb +{ +/** \class DecisionTreeMachineLearningModelFactory + * \brief Creation d'un instance d'un objet SVMMachineLearningModel utilisant les object factory. + */ +template <class TInputValue, class TTargetValue> +class ITK_EXPORT DecisionTreeMachineLearningModelFactory : public itk::ObjectFactoryBase +{ +public: + /** Standard class typedefs. */ + typedef DecisionTreeMachineLearningModelFactory Self; + typedef itk::ObjectFactoryBase Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + + /** Class methods used to interface with the registered factories. */ + virtual const char* GetITKSourceVersion(void) const; + virtual const char* GetDescription(void) const; + + /** Method for class instantiation. */ + itkFactorylessNewMacro(Self); + + /** Run-time type information (and related methods). */ + itkTypeMacro(DecisionTreeMachineLearningModelFactory, itk::ObjectFactoryBase); + + /** Register one factory of this type */ + static void RegisterOneFactory(void) + { + DecisionTreeMachineLearningModelFactory::Pointer Factory = DecisionTreeMachineLearningModelFactory::New(); + itk::ObjectFactoryBase::RegisterFactory(Factory); + } + +protected: + DecisionTreeMachineLearningModelFactory(); + virtual ~DecisionTreeMachineLearningModelFactory(); + +private: + DecisionTreeMachineLearningModelFactory(const Self &); //purposely not implemented + void operator =(const Self&); //purposely not implemented + +}; + +} // end namespace otb + +#ifndef OTB_MANUAL_INSTANTIATION +#include "otbDecisionTreeMachineLearningModelFactory.txx" +#endif + +#endif diff --git a/Code/UtilitiesAdapters/OpenCV/otbDecisionTreeMachineLearningModelFactory.txx b/Code/UtilitiesAdapters/OpenCV/otbDecisionTreeMachineLearningModelFactory.txx new file mode 100644 index 0000000000000000000000000000000000000000..682b0a9f79e8582e96be7d2ab2aa60855448d7df --- /dev/null +++ b/Code/UtilitiesAdapters/OpenCV/otbDecisionTreeMachineLearningModelFactory.txx @@ -0,0 +1,64 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#include "otbDecisionTreeMachineLearningModelFactory.h" + +#include "itkCreateObjectFunction.h" +#include "otbDecisionTreeMachineLearningModel.h" +#include "itkVersion.h" + +namespace otb +{ + +template <class TInputValue, class TOutputValue> +DecisionTreeMachineLearningModelFactory<TInputValue,TOutputValue> +::DecisionTreeMachineLearningModelFactory() +{ + + static std::string classOverride = std::string("otbMachineLearningModel"); + static std::string subclass = std::string("otbDecisionTreeMachineLearningModel"); + + this->RegisterOverride(classOverride.c_str(), + subclass.c_str(), + "Decision Tree ML Model", + 1, + itk::CreateObjectFunction<DecisionTreeMachineLearningModel<TInputValue,TOutputValue> >::New()); +} + +template <class TInputValue, class TOutputValue> +DecisionTreeMachineLearningModelFactory<TInputValue,TOutputValue> +::~DecisionTreeMachineLearningModelFactory() +{ +} + +template <class TInputValue, class TOutputValue> +const char* +DecisionTreeMachineLearningModelFactory<TInputValue,TOutputValue> +::GetITKSourceVersion(void) const +{ + return ITK_SOURCE_VERSION; +} + +template <class TInputValue, class TOutputValue> +const char* +DecisionTreeMachineLearningModelFactory<TInputValue,TOutputValue> +::GetDescription() const +{ + return "Decision Tree machine learning model factory"; +} + +} // end namespace otb diff --git a/Code/UtilitiesAdapters/OpenCV/otbMachineLearningModelFactory.txx b/Code/UtilitiesAdapters/OpenCV/otbMachineLearningModelFactory.txx index 003bf6cd1dc1a55317151cf0f8abf4f29996c255..616a16f6c9d3537a560bfb947bba3a678ca3ca6f 100644 --- a/Code/UtilitiesAdapters/OpenCV/otbMachineLearningModelFactory.txx +++ b/Code/UtilitiesAdapters/OpenCV/otbMachineLearningModelFactory.txx @@ -26,6 +26,7 @@ #include "otbBoostMachineLearningModelFactory.h" #include "otbNeuralNetworkMachineLearningModelFactory.h" #include "otbNormalBayesMachineLearningModelFactory.h" +#include "otbDecisionTreeMachineLearningModelFactory.h" namespace otb @@ -101,6 +102,7 @@ MachineLearningModelFactory<TInputValue,TOutputValue> itk::ObjectFactoryBase::RegisterFactory(BoostMachineLearningModelFactory<TInputValue,TOutputValue>::New()); itk::ObjectFactoryBase::RegisterFactory(NeuralNetworkMachineLearningModelFactory<TInputValue,TOutputValue>::New()); itk::ObjectFactoryBase::RegisterFactory(NormalBayesMachineLearningModelFactory<TInputValue,TOutputValue>::New()); + itk::ObjectFactoryBase::RegisterFactory(DecisionTreeMachineLearningModelFactory<TInputValue,TOutputValue>::New()); firstTime = false; } diff --git a/Testing/Code/Learning/CMakeLists.txt b/Testing/Code/Learning/CMakeLists.txt index 023b8ccc19db8385b80a09393c7e1b143e5a110c..43885d30cb381c88f54d7c8db1ea8ffca37deeaa 100644 --- a/Testing/Code/Learning/CMakeLists.txt +++ b/Testing/Code/Learning/CMakeLists.txt @@ -750,14 +750,29 @@ IF(OTB_USE_OPENCV) ${TEMP}/ANNMachineLearningModel.txt ) + ADD_TEST(leTuNormalBayesMachineLearningModelNew ${LEARNING_TESTS6} + otbNormalBayesMachineLearningModelNew) + ADD_TEST(leTvNormalBayesMachineLearningModel ${LEARNING_TESTS6} - #--compare-ascii ${NOTOL} - #${BASELINE_FILES}/NormalBayesMachineLearningModel.txt - #${TEMP}/NormalBayesMachineLearningModel.txt + --compare-ascii ${NOTOL} + ${BASELINE_FILES}/NormalBayesMachineLearningModel.txt + ${TEMP}/NormalBayesMachineLearningModel.txt otbNormalBayesMachineLearningModel ${INPUTDATA}/letter.scale ${TEMP}/NormalBayesMachineLearningModel.txt ) + + ADD_TEST(leTuDecisionTreeMachineLearningModelNew ${LEARNING_TESTS6} + otbDecisionTreeMachineLearningModelNew) + + ADD_TEST(leTvDecisionTreeMachineLearningModel ${LEARNING_TESTS6} + --compare-ascii ${NOTOL} + ${BASELINE_FILES}/DecisionTreeMachineLearningModel.txt + ${TEMP}/DecisionTreeMachineLearningModel.txt + otbDecisionTreeMachineLearningModel + ${INPUTDATA}/letter.scale + ${TEMP}/DecisionTreeMachineLearningModel.txt + ) ADD_TEST(leTuImageClassificationFilterNew ${LEARNING_TESTS6} otbImageClassificationFilterNew) @@ -813,7 +828,11 @@ IF(OTB_USE_OPENCV) otbNormalBayesMachineLearningModelCanRead ${INPUTDATA}/NormalBayesMachineLearningModel.txt ) - + + ADD_TEST(leTuDecisionTreeMachineLearningModelCanRead ${LEARNING_TESTS6} + otbDecisionTreeMachineLearningModelCanRead + ${INPUTDATA}/DecisionTreeMachineLearningModel.txt + ) ENDIF(OTB_USE_OPENCV) diff --git a/Testing/Code/Learning/otbLearningTests6.cxx b/Testing/Code/Learning/otbLearningTests6.cxx index 5594144620a399cecb655168ca6e61ad774423b7..4c324d43b2191156b70fbde9f005831c1b740d8d 100644 --- a/Testing/Code/Learning/otbLearningTests6.cxx +++ b/Testing/Code/Learning/otbLearningTests6.cxx @@ -38,6 +38,8 @@ void RegisterTests() REGISTER_TEST(otbANNMachineLearningModel); REGISTER_TEST(otbNormalBayesMachineLearningModelNew); REGISTER_TEST(otbNormalBayesMachineLearningModel); + REGISTER_TEST(otbDecisionTreeMachineLearningModelNew); + REGISTER_TEST(otbDecisionTreeMachineLearningModel); REGISTER_TEST(otbImageClassificationFilterNew); REGISTER_TEST(otbImageClassificationFilter); REGISTER_TEST(otbLibSVMMachineLearningModelCanRead); diff --git a/Testing/Code/Learning/otbMachineLearningModelCanRead.cxx b/Testing/Code/Learning/otbMachineLearningModelCanRead.cxx index 578fab35924458a8bf67a4d4bd2f7750298e8c35..04de2b31b04294d147d20f89a218e3bcb3190f15 100644 --- a/Testing/Code/Learning/otbMachineLearningModelCanRead.cxx +++ b/Testing/Code/Learning/otbMachineLearningModelCanRead.cxx @@ -16,6 +16,7 @@ =========================================================================*/ +#include <iostream> #include "otbMachineLearningModel.h" #include "otbLibSVMMachineLearningModel.h" #include "otbSVMMachineLearningModel.h" @@ -23,7 +24,7 @@ #include "otbBoostMachineLearningModel.h" #include "otbNeuralNetworkMachineLearningModel.h" #include "otbNormalBayesMachineLearningModel.h" -#include <iostream> +#include "otbDecisionTreeMachineLearningModel.h" typedef otb::MachineLearningModel<float,short> MachineLearningModelType; typedef MachineLearningModelType::InputValueType InputValueType; @@ -189,3 +190,29 @@ int otbNormalBayesMachineLearningModelCanRead(int argc, char* argv[]) return EXIT_SUCCESS; } +int otbDecisionTreeMachineLearningModelCanRead(int argc, char* argv[]) +{ + if (argc != 2) + { + std::cerr << "Usage: " << argv[0] + << "<model>" << std::endl; + std::cerr << "Called here with " << argc << " arguments\n"; + for (int i = 1; i < argc; ++i) + { + std::cerr << " - " << argv[i] << "\n"; + } + return EXIT_FAILURE; + } + std::string filename(argv[1]); + typedef otb::DecisionTreeMachineLearningModel<InputValueType, TargetValueType> DecisionTreeType; + DecisionTreeType::Pointer classifier = DecisionTreeType::New(); + bool lCanRead = classifier->CanReadFile(filename); + if (lCanRead == false) + { + std::cerr << "Erreur otb::DecisionTreeMachineLearningModel : impossible to open the file " << filename << "." << std::endl; + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; +} + diff --git a/Testing/Code/Learning/otbTrainMachineLearningModel.cxx b/Testing/Code/Learning/otbTrainMachineLearningModel.cxx index 3f11e9991ecc082de62d0a0bf102ada754047d98..0aa9e3be0517cbf3f3ea7cf22aeafa4bf50da0f0 100644 --- a/Testing/Code/Learning/otbTrainMachineLearningModel.cxx +++ b/Testing/Code/Learning/otbTrainMachineLearningModel.cxx @@ -28,6 +28,7 @@ #include "otbBoostMachineLearningModel.h" #include "otbNeuralNetworkMachineLearningModel.h" #include "otbNormalBayesMachineLearningModel.h" +#include "otbDecisionTreeMachineLearningModel.h" #include "otbConfusionMatrixCalculator.h" @@ -531,6 +532,59 @@ int otbNormalBayesMachineLearningModel(int argc, char * argv[]) return EXIT_SUCCESS; } +int otbDecisionTreeMachineLearningModelNew(int argc, char * argv[]) +{ + typedef otb::DecisionTreeMachineLearningModel<InputValueType,TargetValueType> DecisionTreeType; + DecisionTreeType::Pointer classifier = DecisionTreeType::New(); + return EXIT_SUCCESS; +} + +int otbDecisionTreeMachineLearningModel(int argc, char * argv[]) +{ + if (argc != 3 ) + { + std::cout<<"Wrong number of arguments "<<std::endl; + std::cout<<"Usage : sample file, output file "<<std::endl; + return EXIT_FAILURE; + } + + typedef otb::DecisionTreeMachineLearningModel<InputValueType, TargetValueType> DecisionTreeType; + + InputListSampleType::Pointer samples = InputListSampleType::New(); + TargetListSampleType::Pointer labels = TargetListSampleType::New(); + TargetListSampleType::Pointer predicted = TargetListSampleType::New(); + + if(!ReadDataFile(argv[1],samples,labels)) + { + std::cout<<"Failed to read samples file "<<argv[1]<<std::endl; + return EXIT_FAILURE; + } + + DecisionTreeType::Pointer classifier = DecisionTreeType::New(); + classifier->SetInputListSample(samples); + classifier->SetTargetListSample(labels); + classifier->Train(); + + classifier->SetTargetListSample(predicted); + classifier->PredictAll(); + + classifier->Save(argv[2]); + + ConfusionMatrixCalculatorType::Pointer cmCalculator = ConfusionMatrixCalculatorType::New(); + + cmCalculator->SetProducedLabels(predicted); + cmCalculator->SetReferenceLabels(labels); + cmCalculator->Compute(); + + std::cout<<"Confusion matrix: "<<std::endl; + std::cout<<cmCalculator->GetConfusionMatrix()<<std::endl; + std::cout<<"Kappa: "<<cmCalculator->GetKappaIndex()<<std::endl; + std::cout<<"Overall Accuracy: "<<cmCalculator->GetOverallAccuracy()<<std::endl; + + return EXIT_SUCCESS; +} + +