diff --git a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h index ec62b761358634daa8e268c1488770bb801d3543..6f8a38d5fbe42c88af198a95388022fde3b45175 100644 --- a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h @@ -24,8 +24,9 @@ #include "itkFixedArray.h" #include "otbMachineLearningModel.h" #include "itkVariableSizeMatrix.h" +#include "otbCvRTrees.h" -class CvRTrees; +class CvRTreesWrapper; namespace otb { @@ -53,7 +54,7 @@ public: //opencv typedef - typedef CvRTrees RFType; + typedef CvRTreesWrapper RFType; /** Run-time type information (and related methods). */ itkNewMacro(Self); @@ -145,7 +146,7 @@ private: RandomForestsMachineLearningModel(const Self &); //purposely not implemented void operator =(const Self&); //purposely not implemented - CvRTrees * m_RFModel; + CvRTreesWrapper * m_RFModel; /** The depth of the tree. A low value will likely underfit and conversely a * high value will likely overfit. The optimal value can be obtained using cross * validation or other suitable methods. */ @@ -189,7 +190,7 @@ private: * first category. */ std::vector<float> m_Priors; /** If true then variable importance will be calculated and then it can be - * retrieved by CvRTrees::get_var_importance(). */ + * retrieved by CvRTreesWrapper::get_var_importance(). */ bool m_CalculateVariableImportance; /** The size of the randomly selected subset of features at each tree node and * that are used to find the best split(s). If you set it to 0 then the size will diff --git a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx index 78642f1212ac9d75dc75d3bd9e9482d73bc948dc..67d21f8f493b120af39fe07f65b213af754420d6 100644 --- a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx +++ b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx @@ -29,17 +29,17 @@ namespace otb template <class TInputValue, class TOutputValue> RandomForestsMachineLearningModel<TInputValue,TOutputValue> ::RandomForestsMachineLearningModel() : - m_RFModel (new CvRTrees), - m_MaxDepth(5), - m_MinSampleCount(10), - m_RegressionAccuracy(0.01), - m_ComputeSurrogateSplit(false), - m_MaxNumberOfCategories(10), - m_CalculateVariableImportance(false), - m_MaxNumberOfVariables(0), - m_MaxNumberOfTrees(100), - m_ForestAccuracy(0.01), - m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS) + m_RFModel (new CvRTreesWrapper), + m_MaxDepth(5), + m_MinSampleCount(10), + m_RegressionAccuracy(0.01), + m_ComputeSurrogateSplit(false), + m_MaxNumberOfCategories(10), + m_CalculateVariableImportance(false), + m_MaxNumberOfVariables(0), + m_MaxNumberOfTrees(100), + m_ForestAccuracy(0.01), + m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS) { this->m_ConfidenceIndex = true; this->m_IsRegressionSupported = true; @@ -125,7 +125,7 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> if (quality != NULL) { - (*quality) = m_RFModel->predict_prob(sample); + (*quality) = m_RFModel->predict_confidence(sample); } return target[0]; @@ -158,23 +158,23 @@ bool RandomForestsMachineLearningModel<TInputValue,TOutputValue> ::CanReadFile(const std::string & file) { - std::ifstream ifs; - ifs.open(file.c_str()); + std::ifstream ifs; + ifs.open(file.c_str()); - if(!ifs) - { - std::cerr<<"Could not read file "<<file<<std::endl; - return false; - } + if(!ifs) + { + std::cerr<<"Could not read file "<<file<<std::endl; + return false; + } - while (!ifs.eof()) - { - std::string line; - std::getline(ifs, line); + while (!ifs.eof()) + { + std::string line; + std::getline(ifs, line); - //if (line.find(m_RFModel->getName()) != std::string::npos) - if (line.find(CV_TYPE_NAME_ML_RTREES) != std::string::npos) + //if (line.find(m_RFModel->getName()) != std::string::npos) + if (line.find(CV_TYPE_NAME_ML_RTREES) != std::string::npos) { //std::cout<<"Reading a "<<CV_TYPE_NAME_ML_RTREES<<" model"<<std::endl; return true;