diff --git a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h index 58726884b11d83f0dd371ebda6e862f2df8d0b0e..f5457406f86468eaef2956018622401cc416c337 100644 --- a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h @@ -121,6 +121,9 @@ public: itkGetMacro(TerminationCriteria, int); itkSetMacro(TerminationCriteria, int); + itkGetMacro(CoputeMargin, bool); + itkSetMacro(ComputeMargin, bool); + /** Returns a matrix containing variable importance */ VariableImportanceMatrixType GetVariableImportance(); @@ -206,6 +209,10 @@ private: float m_ForestAccuracy; /** The type of the termination criteria */ int m_TerminationCriteria; + /** Wether to compute margin (difference in probability between the + * 2 most voted classes) instead of confidence (probability of the most + * voted class) in prediction*/ + bool m_ComputeMargin; }; } // end namespace otb diff --git a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx index 67d21f8f493b120af39fe07f65b213af754420d6..aa0f054aa7a13940a2cd710e2adf326f01b9f3f5 100644 --- a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx +++ b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx @@ -39,7 +39,8 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> m_MaxNumberOfVariables(0), m_MaxNumberOfTrees(100), m_ForestAccuracy(0.01), - m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS) + m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS), + m_ComputeMargin(false) { this->m_ConfidenceIndex = true; this->m_IsRegressionSupported = true; @@ -91,7 +92,7 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> m_MaxNumberOfTrees, // max number of trees in the forest m_ForestAccuracy, // forest accuracy m_TerminationCriteria // termination criteria - ); + ); cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U ); var_type.setTo(cv::Scalar(CV_VAR_NUMERICAL) ); // all inputs are numerical @@ -125,7 +126,10 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> if (quality != NULL) { - (*quality) = m_RFModel->predict_confidence(sample); + if(m_ComputeMargin) + (*quality) = m_RFModel->predict_margin(sample); + else + (*quality) = m_RFModel->predict_confidence(sample); } return target[0]; @@ -176,12 +180,12 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> //if (line.find(m_RFModel->getName()) != std::string::npos) if (line.find(CV_TYPE_NAME_ML_RTREES) != std::string::npos) { - //std::cout<<"Reading a "<<CV_TYPE_NAME_ML_RTREES<<" model"<<std::endl; - return true; + //std::cout<<"Reading a "<<CV_TYPE_NAME_ML_RTREES<<" model"<<std::endl; + return true; } - } - ifs.close(); - return false; + } + ifs.close(); + return false; } template <class TInputValue, class TOutputValue>