Commit a809c122 authored by Jordi Inglada's avatar Jordi Inglada

ENH: add option to choose between margin and confidence in RF classification

parent 7ee4d51a
......@@ -121,6 +121,9 @@ public:
itkGetMacro(TerminationCriteria, int);
itkSetMacro(TerminationCriteria, int);
itkGetMacro(CoputeMargin, bool);
itkSetMacro(ComputeMargin, bool);
/** Returns a matrix containing variable importance */
VariableImportanceMatrixType GetVariableImportance();
......@@ -206,6 +209,10 @@ private:
float m_ForestAccuracy;
/** The type of the termination criteria */
int m_TerminationCriteria;
/** Wether to compute margin (difference in probability between the
* 2 most voted classes) instead of confidence (probability of the most
* voted class) in prediction*/
bool m_ComputeMargin;
};
} // end namespace otb
......
......@@ -39,7 +39,8 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue>
m_MaxNumberOfVariables(0),
m_MaxNumberOfTrees(100),
m_ForestAccuracy(0.01),
m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS)
m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS),
m_ComputeMargin(false)
{
this->m_ConfidenceIndex = true;
this->m_IsRegressionSupported = true;
......@@ -91,7 +92,7 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue>
m_MaxNumberOfTrees, // max number of trees in the forest
m_ForestAccuracy, // forest accuracy
m_TerminationCriteria // termination criteria
);
);
cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U );
var_type.setTo(cv::Scalar(CV_VAR_NUMERICAL) ); // all inputs are numerical
......@@ -125,7 +126,10 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue>
if (quality != NULL)
{
(*quality) = m_RFModel->predict_confidence(sample);
if(m_ComputeMargin)
(*quality) = m_RFModel->predict_margin(sample);
else
(*quality) = m_RFModel->predict_confidence(sample);
}
return target[0];
......@@ -176,12 +180,12 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue>
//if (line.find(m_RFModel->getName()) != std::string::npos)
if (line.find(CV_TYPE_NAME_ML_RTREES) != std::string::npos)
{
//std::cout<<"Reading a "<<CV_TYPE_NAME_ML_RTREES<<" model"<<std::endl;
return true;
//std::cout<<"Reading a "<<CV_TYPE_NAME_ML_RTREES<<" model"<<std::endl;
return true;
}
}
ifs.close();
return false;
}
ifs.close();
return false;
}
template <class TInputValue, class TOutputValue>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment