Skip to content
Snippets Groups Projects
Commit 58bcfd76 authored by Jordi Inglada's avatar Jordi Inglada
Browse files

ENH: add regression mode to Random Forests

parent 3922f575
No related branches found
No related tags found
No related merge requests found
...@@ -161,6 +161,9 @@ public: ...@@ -161,6 +161,9 @@ public:
/* The type of the termination criteria */ /* The type of the termination criteria */
itkGetMacro(TerminationCriteria, int); itkGetMacro(TerminationCriteria, int);
itkSetMacro(TerminationCriteria, int); itkSetMacro(TerminationCriteria, int);
/* Perform regression instead of classification */
itkGetMacro(RegressionMode, bool);
itkSetMacro(RegressionMode, bool);
/** Returns a matrix containing variable importance */ /** Returns a matrix containing variable importance */
VariableImportanceMatrixType GetVariableImportance(); VariableImportanceMatrixType GetVariableImportance();
...@@ -199,6 +202,7 @@ private: ...@@ -199,6 +202,7 @@ private:
int m_MaxNumberOfTrees; int m_MaxNumberOfTrees;
float m_ForestAccuracy; float m_ForestAccuracy;
int m_TerminationCriteria; int m_TerminationCriteria;
bool m_RegressionMode;
}; };
} // end namespace otb } // end namespace otb
......
...@@ -40,7 +40,8 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> ...@@ -40,7 +40,8 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue>
m_MaxNumberOfVariables(0), m_MaxNumberOfVariables(0),
m_MaxNumberOfTrees(100), m_MaxNumberOfTrees(100),
m_ForestAccuracy(0.01), m_ForestAccuracy(0.01),
m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS) m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS),
m_RegressionMode(false)
{ {
} }
...@@ -95,11 +96,14 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> ...@@ -95,11 +96,14 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue>
cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U ); cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U );
var_type.setTo(cv::Scalar(CV_VAR_NUMERICAL) ); // all inputs are numerical var_type.setTo(cv::Scalar(CV_VAR_NUMERICAL) ); // all inputs are numerical
var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_CATEGORICAL; if(m_RegressionMode)
var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_NUMERICAL;
else
var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_CATEGORICAL;
//train the RT model //train the RT model
m_RFModel->train(samples, CV_ROW_SAMPLE, labels, m_RFModel->train(samples, CV_ROW_SAMPLE, labels,
cv::Mat(), cv::Mat(), var_type, cv::Mat(), params); cv::Mat(), cv::Mat(), var_type, cv::Mat(), params);
} }
template <class TInputValue, class TOutputValue> template <class TInputValue, class TOutputValue>
...@@ -113,13 +117,13 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> ...@@ -113,13 +117,13 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue>
otb::SampleToMat<InputSampleType>(value,sample); otb::SampleToMat<InputSampleType>(value,sample);
double result = m_RFModel->predict(sample); double result = m_RFModel->predict(sample);
TargetSampleType target; TargetSampleType target;
target[0] = static_cast<TOutputValue>(result); target[0] = static_cast<TOutputValue>(result);
return target[0]; return target[0];
} }
template <class TInputValue, class TOutputValue> template <class TInputValue, class TOutputValue>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment