Commit 7ee4d51a authored by Jordi Inglada's avatar Jordi Inglada

ENH: Distinguish between margin and confidence

Two different methods for computing margin and confidence in random
forests.
Refactor the vote count in a separate method.
parent 679cf025
......@@ -19,6 +19,7 @@
#define __otbCvRTreesWrapper_h
#include "otbOpenCVUtils.h"
#include <vector>
namespace otb
{
......@@ -31,18 +32,31 @@ namespace otb
class CV_EXPORTS_W CvRTreesWrapper : public CvRTrees
{
public:
typedef std::vector<unsigned int> VotesVectorType;
CvRTreesWrapper();
virtual ~CvRTreesWrapper();
/** Compute the number of votes for each class. */
void get_votes(const cv::Mat& sample,
const cv::Mat& missing,
VotesVectorType& vote_count) const;
/** Predict the confidence of the classifcation by computing the
/** Predict the confidence of the classifcation by computing the proportion
of trees which voted for the majority class.
*/
float predict_confidence(const cv::Mat& sample,
const cv::Mat& missing =
cv::Mat()) const;
/** Predict the confidence margin of the classifcation by computing the
difference in votes between the first and second most voted classes.
This measure is preferred to the proportion of votes of the majority
class, since it provides information about the conflict between the
most likely classes.
*/
float predict_confidence(const cv::Mat& sample,
const cv::Mat& missing =
cv::Mat()) const;
float predict_margin(const cv::Mat& sample,
const cv::Mat& missing =
cv::Mat()) const;
};
}
......
......@@ -16,7 +16,6 @@
=========================================================================*/
#include "otbCvRTreesWrapper.h"
#include <vector>
#include <algorithm>
......@@ -27,27 +26,51 @@ CvRTreesWrapper::CvRTreesWrapper(){}
CvRTreesWrapper::~CvRTreesWrapper(){}
void CvRTreesWrapper::get_votes(const cv::Mat& sample,
const cv::Mat& missing,
CvRTreesWrapper::VotesVectorType& vote_count) const
{
vote_count.resize(nclasses);
for( int k = 0; k < ntrees; k++ )
{
CvDTreeNode* predicted_node = trees[k]->predict( sample, missing );
int class_idx = predicted_node->class_idx;
CV_Assert( 0 <= class_idx && class_idx < nclasses );
++vote_count[class_idx];
}
}
float CvRTreesWrapper::predict_margin(const cv::Mat& sample,
const cv::Mat& missing) const
{
// Sanity check (division by ntrees later on)
if(ntrees == 0)
{
return 0.;
}
std::vector<unsigned int> classVotes;
this->get_votes(sample, missing, classVotes);
// We only sort the 2 greatest elements
std::nth_element(classVotes.begin(), classVotes.begin()+1,
classVotes.end(), std::greater<unsigned int>());
float margin = static_cast<float>(classVotes[0]-classVotes[1])/ntrees;
return margin;
}
float CvRTreesWrapper::predict_confidence(const cv::Mat& sample,
const cv::Mat& missing) const
{
// Sanity check (division by ntrees later on)
if(ntrees == 0)
{
return 0.;
}
std::vector<unsigned int> classVotes(nclasses);
for( int k = 0; k < ntrees; k++ )
{
CvDTreeNode* predicted_node = trees[k]->predict( sample, missing );
int class_idx = predicted_node->class_idx;
CV_Assert( 0 <= class_idx && class_idx < nclasses );
++classVotes[class_idx];
}
// We only sort the 2 greatest elements
std::nth_element(classVotes.begin(), classVotes.begin()+1,
classVotes.end(), std::greater<unsigned int>());
float confidence = static_cast<float>(classVotes[0]-classVotes[1])/ntrees;
return confidence;
}
const cv::Mat& missing) const
{
// Sanity check (division by ntrees later on)
if(ntrees == 0)
{
return 0.;
}
std::vector<unsigned int> classVotes;
this->get_votes(sample, missing, classVotes);
unsigned int max_votes = *(std::max_element(classVotes.begin(),
classVotes.end()));
float confidence = static_cast<float>(max_votes)/ntrees;
return confidence;
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment