Commit d36baa4b authored by Manuel Grizonnet's avatar Manuel Grizonnet

Merge branch 'rfc-16-rfconfmap' into develop

parents 82894610 8469d1f2
......@@ -105,7 +105,7 @@ private:
" * KNearestNeighbors : number of neighbors with the same label\n"
" * NeuralNetwork : difference between the two highest responses\n"
" * NormalBayes : (not supported)\n"
" * RandomForest : proportion of decision trees that classified the sample to the second class (only works for 2-class models)\n"
" * RandomForest : Confidence (proportion of votes for the majority class). Margin (normalized difference of the votes of the 2 majority classes) is not available for now.\n"
" * SVM : distance to margin (only works for 2-class models)\n");
SetDefaultOutputPixelType( "confmap", ImagePixelType_double);
MandatoryOff("confmap");
......
......@@ -103,7 +103,7 @@ endif()
if(OTB_USE_OPENCV)
list(APPEND classifierList "SVM" "BOOST" "DT" "GBT" "ANN" "BAYES" "RF" "KNN")
endif()
set(classifier_with_confmap "LIBSVM" "BOOST" "KNN" "ANN")
set(classifier_with_confmap "LIBSVM" "BOOST" "KNN" "ANN" "RF")
# Loop on classifiers
foreach(classifier ${classifierList})
......
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbCvRTreesWrapper_h
#define __otbCvRTreesWrapper_h
#include "otbOpenCVUtils.h"
#include <vector>
namespace otb
{
/** \class CvRTreesWrapper
* \brief Wrapper for OpenCV Random Trees
*
* \ingroup OTBSupervised
*/
class CV_EXPORTS_W CvRTreesWrapper : public CvRTrees
{
public:
typedef std::vector<unsigned int> VotesVectorType;
CvRTreesWrapper();
virtual ~CvRTreesWrapper();
/** Compute the number of votes for each class. */
void get_votes(const cv::Mat& sample,
const cv::Mat& missing,
VotesVectorType& vote_count) const;
/** Predict the confidence of the classifcation by computing the proportion
of trees which voted for the majority class.
*/
float predict_confidence(const cv::Mat& sample,
const cv::Mat& missing =
cv::Mat()) const;
/** Predict the confidence margin of the classifcation by computing the
difference in votes between the first and second most voted classes.
This measure is preferred to the proportion of votes of the majority
class, since it provides information about the conflict between the
most likely classes.
*/
float predict_margin(const cv::Mat& sample,
const cv::Mat& missing =
cv::Mat()) const;
};
}
#endif
......@@ -24,8 +24,9 @@
#include "itkFixedArray.h"
#include "otbMachineLearningModel.h"
#include "itkVariableSizeMatrix.h"
#include "otbCvRTreesWrapper.h"
class CvRTrees;
class CvRTreesWrapper;
namespace otb
{
......@@ -53,7 +54,7 @@ public:
//opencv typedef
typedef CvRTrees RFType;
typedef CvRTreesWrapper RFType;
/** Run-time type information (and related methods). */
itkNewMacro(Self);
......@@ -120,6 +121,9 @@ public:
itkGetMacro(TerminationCriteria, int);
itkSetMacro(TerminationCriteria, int);
itkGetMacro(ComputeMargin, bool);
itkSetMacro(ComputeMargin, bool);
/** Returns a matrix containing variable importance */
VariableImportanceMatrixType GetVariableImportance();
......@@ -145,7 +149,7 @@ private:
RandomForestsMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
CvRTrees * m_RFModel;
CvRTreesWrapper * m_RFModel;
/** The depth of the tree. A low value will likely underfit and conversely a
* high value will likely overfit. The optimal value can be obtained using cross
* validation or other suitable methods. */
......@@ -189,7 +193,7 @@ private:
* first category. */
std::vector<float> m_Priors;
/** If true then variable importance will be calculated and then it can be
* retrieved by CvRTrees::get_var_importance(). */
* retrieved by CvRTreesWrapper::get_var_importance(). */
bool m_CalculateVariableImportance;
/** The size of the randomly selected subset of features at each tree node and
* that are used to find the best split(s). If you set it to 0 then the size will
......@@ -205,6 +209,10 @@ private:
float m_ForestAccuracy;
/** The type of the termination criteria */
int m_TerminationCriteria;
/** Wether to compute margin (difference in probability between the
* 2 most voted classes) instead of confidence (probability of the most
* voted class) in prediction*/
bool m_ComputeMargin;
};
} // end namespace otb
......
......@@ -29,17 +29,18 @@ namespace otb
template <class TInputValue, class TOutputValue>
RandomForestsMachineLearningModel<TInputValue,TOutputValue>
::RandomForestsMachineLearningModel() :
m_RFModel (new CvRTrees),
m_MaxDepth(5),
m_MinSampleCount(10),
m_RegressionAccuracy(0.01),
m_ComputeSurrogateSplit(false),
m_MaxNumberOfCategories(10),
m_CalculateVariableImportance(false),
m_MaxNumberOfVariables(0),
m_MaxNumberOfTrees(100),
m_ForestAccuracy(0.01),
m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS)
m_RFModel (new CvRTreesWrapper),
m_MaxDepth(5),
m_MinSampleCount(10),
m_RegressionAccuracy(0.01),
m_ComputeSurrogateSplit(false),
m_MaxNumberOfCategories(10),
m_CalculateVariableImportance(false),
m_MaxNumberOfVariables(0),
m_MaxNumberOfTrees(100),
m_ForestAccuracy(0.01),
m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS),
m_ComputeMargin(false)
{
this->m_ConfidenceIndex = true;
this->m_IsRegressionSupported = true;
......@@ -91,7 +92,7 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue>
m_MaxNumberOfTrees, // max number of trees in the forest
m_ForestAccuracy, // forest accuracy
m_TerminationCriteria // termination criteria
);
);
cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U );
var_type.setTo(cv::Scalar(CV_VAR_NUMERICAL) ); // all inputs are numerical
......@@ -125,7 +126,10 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue>
if (quality != NULL)
{
(*quality) = m_RFModel->predict_prob(sample);
if(m_ComputeMargin)
(*quality) = m_RFModel->predict_margin(sample);
else
(*quality) = m_RFModel->predict_confidence(sample);
}
return target[0];
......@@ -158,30 +162,30 @@ bool
RandomForestsMachineLearningModel<TInputValue,TOutputValue>
::CanReadFile(const std::string & file)
{
std::ifstream ifs;
ifs.open(file.c_str());
std::ifstream ifs;
ifs.open(file.c_str());
if(!ifs)
{
std::cerr<<"Could not read file "<<file<<std::endl;
return false;
}
if(!ifs)
{
std::cerr<<"Could not read file "<<file<<std::endl;
return false;
}
while (!ifs.eof())
{
std::string line;
std::getline(ifs, line);
while (!ifs.eof())
{
std::string line;
std::getline(ifs, line);
//if (line.find(m_RFModel->getName()) != std::string::npos)
if (line.find(CV_TYPE_NAME_ML_RTREES) != std::string::npos)
//if (line.find(m_RFModel->getName()) != std::string::npos)
if (line.find(CV_TYPE_NAME_ML_RTREES) != std::string::npos)
{
//std::cout<<"Reading a "<<CV_TYPE_NAME_ML_RTREES<<" model"<<std::endl;
return true;
//std::cout<<"Reading a "<<CV_TYPE_NAME_ML_RTREES<<" model"<<std::endl;
return true;
}
}
ifs.close();
return false;
}
ifs.close();
return false;
}
template <class TInputValue, class TOutputValue>
......
set(OTBSupervised_SRC
otbCvRTreesWrapper.cxx
otbMachineLearningModelFactoryBase.cxx
otbMachineLearningUtils.cxx
)
......
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbCvRTreesWrapper.h"
#include <algorithm>
namespace otb
{
CvRTreesWrapper::CvRTreesWrapper(){}
CvRTreesWrapper::~CvRTreesWrapper(){}
void CvRTreesWrapper::get_votes(const cv::Mat& sample,
const cv::Mat& missing,
CvRTreesWrapper::VotesVectorType& vote_count) const
{
vote_count.resize(nclasses);
for( int k = 0; k < ntrees; k++ )
{
CvDTreeNode* predicted_node = trees[k]->predict( sample, missing );
int class_idx = predicted_node->class_idx;
CV_Assert( 0 <= class_idx && class_idx < nclasses );
++vote_count[class_idx];
}
}
float CvRTreesWrapper::predict_margin(const cv::Mat& sample,
const cv::Mat& missing) const
{
// Sanity check (division by ntrees later on)
if(ntrees == 0)
{
return 0.;
}
std::vector<unsigned int> classVotes;
this->get_votes(sample, missing, classVotes);
// We only sort the 2 greatest elements
std::nth_element(classVotes.begin(), classVotes.begin()+1,
classVotes.end(), std::greater<unsigned int>());
float margin = static_cast<float>(classVotes[0]-classVotes[1])/ntrees;
return margin;
}
float CvRTreesWrapper::predict_confidence(const cv::Mat& sample,
const cv::Mat& missing) const
{
// Sanity check (division by ntrees later on)
if(ntrees == 0)
{
return 0.;
}
std::vector<unsigned int> classVotes;
this->get_votes(sample, missing, classVotes);
unsigned int max_votes = *(std::max_element(classVotes.begin(),
classVotes.end()));
float confidence = static_cast<float>(max_votes)/ntrees;
return confidence;
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment