diff --git a/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.h b/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.h index 2906f18c0a9b7f7ba46faba3a9450b28e3e9b01e..917202e95c446092e6115de37c17f0db150e1694 100644 --- a/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.h +++ b/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.h @@ -161,6 +161,9 @@ public: /* The type of the termination criteria */ itkGetMacro(TerminationCriteria, int); itkSetMacro(TerminationCriteria, int); + /* Perform regression instead of classification */ + itkGetMacro(RegressionMode, bool); + itkSetMacro(RegressionMode, bool); /** Returns a matrix containing variable importance */ VariableImportanceMatrixType GetVariableImportance(); @@ -199,6 +202,7 @@ private: int m_MaxNumberOfTrees; float m_ForestAccuracy; int m_TerminationCriteria; + bool m_RegressionMode; }; } // end namespace otb diff --git a/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.txx b/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.txx index 001096c81b3a56efc05217f9f207cda7f74572d6..78caff9866386f427e291b7d4fa478085563c106 100644 --- a/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.txx +++ b/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.txx @@ -40,7 +40,8 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> m_MaxNumberOfVariables(0), m_MaxNumberOfTrees(100), m_ForestAccuracy(0.01), - m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS) + m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS), + m_RegressionMode(false) { } @@ -95,11 +96,14 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U ); var_type.setTo(cv::Scalar(CV_VAR_NUMERICAL) ); // all inputs are numerical - var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_CATEGORICAL; + if(m_RegressionMode) + var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_NUMERICAL; + else + var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_CATEGORICAL; //train the RT model m_RFModel->train(samples, CV_ROW_SAMPLE, labels, - cv::Mat(), cv::Mat(), var_type, cv::Mat(), params); + cv::Mat(), cv::Mat(), var_type, cv::Mat(), params); } template <class TInputValue, class TOutputValue> @@ -113,13 +117,13 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> otb::SampleToMat<InputSampleType>(value,sample); - double result = m_RFModel->predict(sample); + double result = m_RFModel->predict(sample); - TargetSampleType target; + TargetSampleType target; - target[0] = static_cast<TOutputValue>(result); + target[0] = static_cast<TOutputValue>(result); - return target[0]; + return target[0]; } template <class TInputValue, class TOutputValue>