From ecd14f9f926d7b54078fca6b2c0c6d57fe2b2aae Mon Sep 17 00:00:00 2001 From: Jordi Inglada <jordi.inglada@cesbio.cnes.fr> Date: Wed, 21 Oct 2015 11:46:24 +0200 Subject: [PATCH] ENH: use the random fores confidence estimation in the MachineLearningModel --- .../otbRandomForestsMachineLearningModel.h | 9 ++-- .../otbRandomForestsMachineLearningModel.txx | 50 +++++++++---------- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h index ec62b76135..6f8a38d5fb 100644 --- a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h @@ -24,8 +24,9 @@ #include "itkFixedArray.h" #include "otbMachineLearningModel.h" #include "itkVariableSizeMatrix.h" +#include "otbCvRTrees.h" -class CvRTrees; +class CvRTreesWrapper; namespace otb { @@ -53,7 +54,7 @@ public: //opencv typedef - typedef CvRTrees RFType; + typedef CvRTreesWrapper RFType; /** Run-time type information (and related methods). */ itkNewMacro(Self); @@ -145,7 +146,7 @@ private: RandomForestsMachineLearningModel(const Self &); //purposely not implemented void operator =(const Self&); //purposely not implemented - CvRTrees * m_RFModel; + CvRTreesWrapper * m_RFModel; /** The depth of the tree. A low value will likely underfit and conversely a * high value will likely overfit. The optimal value can be obtained using cross * validation or other suitable methods. */ @@ -189,7 +190,7 @@ private: * first category. */ std::vector<float> m_Priors; /** If true then variable importance will be calculated and then it can be - * retrieved by CvRTrees::get_var_importance(). */ + * retrieved by CvRTreesWrapper::get_var_importance(). */ bool m_CalculateVariableImportance; /** The size of the randomly selected subset of features at each tree node and * that are used to find the best split(s). If you set it to 0 then the size will diff --git a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx index 78642f1212..67d21f8f49 100644 --- a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx +++ b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx @@ -29,17 +29,17 @@ namespace otb template <class TInputValue, class TOutputValue> RandomForestsMachineLearningModel<TInputValue,TOutputValue> ::RandomForestsMachineLearningModel() : - m_RFModel (new CvRTrees), - m_MaxDepth(5), - m_MinSampleCount(10), - m_RegressionAccuracy(0.01), - m_ComputeSurrogateSplit(false), - m_MaxNumberOfCategories(10), - m_CalculateVariableImportance(false), - m_MaxNumberOfVariables(0), - m_MaxNumberOfTrees(100), - m_ForestAccuracy(0.01), - m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS) + m_RFModel (new CvRTreesWrapper), + m_MaxDepth(5), + m_MinSampleCount(10), + m_RegressionAccuracy(0.01), + m_ComputeSurrogateSplit(false), + m_MaxNumberOfCategories(10), + m_CalculateVariableImportance(false), + m_MaxNumberOfVariables(0), + m_MaxNumberOfTrees(100), + m_ForestAccuracy(0.01), + m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS) { this->m_ConfidenceIndex = true; this->m_IsRegressionSupported = true; @@ -125,7 +125,7 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> if (quality != NULL) { - (*quality) = m_RFModel->predict_prob(sample); + (*quality) = m_RFModel->predict_confidence(sample); } return target[0]; @@ -158,23 +158,23 @@ bool RandomForestsMachineLearningModel<TInputValue,TOutputValue> ::CanReadFile(const std::string & file) { - std::ifstream ifs; - ifs.open(file.c_str()); + std::ifstream ifs; + ifs.open(file.c_str()); - if(!ifs) - { - std::cerr<<"Could not read file "<<file<<std::endl; - return false; - } + if(!ifs) + { + std::cerr<<"Could not read file "<<file<<std::endl; + return false; + } - while (!ifs.eof()) - { - std::string line; - std::getline(ifs, line); + while (!ifs.eof()) + { + std::string line; + std::getline(ifs, line); - //if (line.find(m_RFModel->getName()) != std::string::npos) - if (line.find(CV_TYPE_NAME_ML_RTREES) != std::string::npos) + //if (line.find(m_RFModel->getName()) != std::string::npos) + if (line.find(CV_TYPE_NAME_ML_RTREES) != std::string::npos) { //std::cout<<"Reading a "<<CV_TYPE_NAME_ML_RTREES<<" model"<<std::endl; return true; -- GitLab