Skip to content
Snippets Groups Projects
Commit da8fcaff authored by Jordi Inglada's avatar Jordi Inglada
Browse files

MRG

parents 230059ac 58bcfd76
No related branches found
No related tags found
No related merge requests found
......@@ -161,6 +161,9 @@ public:
/* The type of the termination criteria */
itkGetMacro(TerminationCriteria, int);
itkSetMacro(TerminationCriteria, int);
/* Perform regression instead of classification */
itkGetMacro(RegressionMode, bool);
itkSetMacro(RegressionMode, bool);
/** Returns a matrix containing variable importance */
VariableImportanceMatrixType GetVariableImportance();
......@@ -199,6 +202,7 @@ private:
int m_MaxNumberOfTrees;
float m_ForestAccuracy;
int m_TerminationCriteria;
bool m_RegressionMode;
};
} // end namespace otb
......
......@@ -40,7 +40,8 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue>
m_MaxNumberOfVariables(0),
m_MaxNumberOfTrees(100),
m_ForestAccuracy(0.01),
m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS)
m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS),
m_RegressionMode(false)
{
}
......@@ -95,11 +96,14 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue>
cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U );
var_type.setTo(cv::Scalar(CV_VAR_NUMERICAL) ); // all inputs are numerical
var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_CATEGORICAL;
if(m_RegressionMode)
var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_NUMERICAL;
else
var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_CATEGORICAL;
//train the RT model
m_RFModel->train(samples, CV_ROW_SAMPLE, labels,
cv::Mat(), cv::Mat(), var_type, cv::Mat(), params);
cv::Mat(), cv::Mat(), var_type, cv::Mat(), params);
}
template <class TInputValue, class TOutputValue>
......@@ -113,13 +117,13 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue>
otb::SampleToMat<InputSampleType>(value,sample);
double result = m_RFModel->predict(sample);
double result = m_RFModel->predict(sample);
TargetSampleType target;
TargetSampleType target;
target[0] = static_cast<TOutputValue>(result);
target[0] = static_cast<TOutputValue>(result);
return target[0];
return target[0];
}
template <class TInputValue, class TOutputValue>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment