From 58bcfd76ac279b587b979208240e041ff336e6ac Mon Sep 17 00:00:00 2001
From: Jordi Inglada <jordi.inglada@cesbio.cnes.fr>
Date: Thu, 11 Dec 2014 15:58:49 +0100
Subject: [PATCH] ENH: add regression mode to Random Forests

---
 .../otbRandomForestsMachineLearningModel.h     |  4 ++++
 .../otbRandomForestsMachineLearningModel.txx   | 18 +++++++++++-------
 2 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.h b/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.h
index 2906f18c0a..917202e95c 100644
--- a/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.h
+++ b/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.h
@@ -161,6 +161,9 @@ public:
   /* The type of the termination criteria */
   itkGetMacro(TerminationCriteria, int);
   itkSetMacro(TerminationCriteria, int);
+  /* Perform regression instead of classification */
+  itkGetMacro(RegressionMode, bool);
+  itkSetMacro(RegressionMode, bool);
 
   /** Returns a matrix containing variable importance */
   VariableImportanceMatrixType GetVariableImportance();
@@ -199,6 +202,7 @@ private:
   int m_MaxNumberOfTrees;
   float m_ForestAccuracy;
   int m_TerminationCriteria;
+  bool m_RegressionMode;
 };
 } // end namespace otb
 
diff --git a/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.txx b/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.txx
index 001096c81b..78caff9866 100644
--- a/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.txx
+++ b/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.txx
@@ -40,7 +40,8 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue>
  m_MaxNumberOfVariables(0),
  m_MaxNumberOfTrees(100),
  m_ForestAccuracy(0.01),
- m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS)
+ m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS),
+ m_RegressionMode(false)
 {
 }
 
@@ -95,11 +96,14 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue>
   cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U );
   var_type.setTo(cv::Scalar(CV_VAR_NUMERICAL) ); // all inputs are numerical
 
-  var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_CATEGORICAL;
+  if(m_RegressionMode)
+    var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_NUMERICAL;
+  else
+    var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_CATEGORICAL;
 
   //train the RT model
   m_RFModel->train(samples, CV_ROW_SAMPLE, labels,
-              cv::Mat(), cv::Mat(), var_type, cv::Mat(), params);
+                   cv::Mat(), cv::Mat(), var_type, cv::Mat(), params);
 }
 
 template <class TInputValue, class TOutputValue>
@@ -113,13 +117,13 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue>
 
   otb::SampleToMat<InputSampleType>(value,sample);
 
-    double result = m_RFModel->predict(sample);
+  double result = m_RFModel->predict(sample);
 
-    TargetSampleType target;
+  TargetSampleType target;
 
-    target[0] = static_cast<TOutputValue>(result);
+  target[0] = static_cast<TOutputValue>(result);
 
-    return target[0];
+  return target[0];
 }
 
 template <class TInputValue, class TOutputValue>
-- 
GitLab