Skip to content
Snippets Groups Projects
Commit ecd14f9f authored by Jordi Inglada's avatar Jordi Inglada
Browse files

ENH: use the random fores confidence estimation in the MachineLearningModel

parent d5482b5c
No related branches found
No related tags found
No related merge requests found
......@@ -24,8 +24,9 @@
#include "itkFixedArray.h"
#include "otbMachineLearningModel.h"
#include "itkVariableSizeMatrix.h"
#include "otbCvRTrees.h"
class CvRTrees;
class CvRTreesWrapper;
namespace otb
{
......@@ -53,7 +54,7 @@ public:
//opencv typedef
typedef CvRTrees RFType;
typedef CvRTreesWrapper RFType;
/** Run-time type information (and related methods). */
itkNewMacro(Self);
......@@ -145,7 +146,7 @@ private:
RandomForestsMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
CvRTrees * m_RFModel;
CvRTreesWrapper * m_RFModel;
/** The depth of the tree. A low value will likely underfit and conversely a
* high value will likely overfit. The optimal value can be obtained using cross
* validation or other suitable methods. */
......@@ -189,7 +190,7 @@ private:
* first category. */
std::vector<float> m_Priors;
/** If true then variable importance will be calculated and then it can be
* retrieved by CvRTrees::get_var_importance(). */
* retrieved by CvRTreesWrapper::get_var_importance(). */
bool m_CalculateVariableImportance;
/** The size of the randomly selected subset of features at each tree node and
* that are used to find the best split(s). If you set it to 0 then the size will
......
......@@ -29,17 +29,17 @@ namespace otb
template <class TInputValue, class TOutputValue>
RandomForestsMachineLearningModel<TInputValue,TOutputValue>
::RandomForestsMachineLearningModel() :
m_RFModel (new CvRTrees),
m_MaxDepth(5),
m_MinSampleCount(10),
m_RegressionAccuracy(0.01),
m_ComputeSurrogateSplit(false),
m_MaxNumberOfCategories(10),
m_CalculateVariableImportance(false),
m_MaxNumberOfVariables(0),
m_MaxNumberOfTrees(100),
m_ForestAccuracy(0.01),
m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS)
m_RFModel (new CvRTreesWrapper),
m_MaxDepth(5),
m_MinSampleCount(10),
m_RegressionAccuracy(0.01),
m_ComputeSurrogateSplit(false),
m_MaxNumberOfCategories(10),
m_CalculateVariableImportance(false),
m_MaxNumberOfVariables(0),
m_MaxNumberOfTrees(100),
m_ForestAccuracy(0.01),
m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS)
{
this->m_ConfidenceIndex = true;
this->m_IsRegressionSupported = true;
......@@ -125,7 +125,7 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue>
if (quality != NULL)
{
(*quality) = m_RFModel->predict_prob(sample);
(*quality) = m_RFModel->predict_confidence(sample);
}
return target[0];
......@@ -158,23 +158,23 @@ bool
RandomForestsMachineLearningModel<TInputValue,TOutputValue>
::CanReadFile(const std::string & file)
{
std::ifstream ifs;
ifs.open(file.c_str());
std::ifstream ifs;
ifs.open(file.c_str());
if(!ifs)
{
std::cerr<<"Could not read file "<<file<<std::endl;
return false;
}
if(!ifs)
{
std::cerr<<"Could not read file "<<file<<std::endl;
return false;
}
while (!ifs.eof())
{
std::string line;
std::getline(ifs, line);
while (!ifs.eof())
{
std::string line;
std::getline(ifs, line);
//if (line.find(m_RFModel->getName()) != std::string::npos)
if (line.find(CV_TYPE_NAME_ML_RTREES) != std::string::npos)
//if (line.find(m_RFModel->getName()) != std::string::npos)
if (line.find(CV_TYPE_NAME_ML_RTREES) != std::string::npos)
{
//std::cout<<"Reading a "<<CV_TYPE_NAME_ML_RTREES<<" model"<<std::endl;
return true;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment