Skip to content
Snippets Groups Projects
Commit 678d67ec authored by Guillaume Pasero's avatar Guillaume Pasero
Browse files

ENH: implement margin and confidence prediction for CvRTreesWrapper

parent d2ba021f
Branches
Tags
No related merge requests found
...@@ -38,10 +38,55 @@ void CvRTreesWrapper::get_votes(const cv::Mat& sample, ...@@ -38,10 +38,55 @@ void CvRTreesWrapper::get_votes(const cv::Mat& sample,
CvRTreesWrapper::VotesVectorType& vote_count) const CvRTreesWrapper::VotesVectorType& vote_count) const
{ {
#ifdef OTB_OPENCV_3 #ifdef OTB_OPENCV_3
(void) sample; // missing samples not implemented yet
(void) missing; (void) missing;
(void) vote_count;
// TODO // Here we have to re-implement a basic "predict_tree()" since the function is
// not exposed anymore
const std::vector< cv::ml::DTrees::Node > &nodes = m_Impl->getNodes();
const std::vector< cv::ml::DTrees::Split > &splits = m_Impl->getSplits();
const std::vector<int> &roots = m_Impl->getRoots();
int ntrees = roots.size();
int nodeIdx, prevNodeIdx;
int predictedClass = -1;
const float* samplePtr = sample.ptr<float>();
std::map<int, unsigned int> votes;
for (int t=0; t<ntrees ; t++)
{
nodeIdx = roots[t];
prevNodeIdx = nodeIdx;
while(1)
{
prevNodeIdx = nodeIdx;
const cv::ml::DTrees::Node &curNode = nodes[nodeIdx];
// test if this node is a leaf
if (curNode.split < 0)
break;
const cv::ml::DTrees::Split& split = splits[curNode.split];
int varIdx = split.varIdx;
float val = samplePtr[varIdx];
nodeIdx = val <= split.c ? curNode.left : curNode.right;
}
predictedClass = nodes[prevNodeIdx].classIdx;
votes[predictedClass] += 1;
}
vote_count.resize(votes.size());
int pos=0;
for (std::map<int, unsigned int>::const_iterator it=votes.begin() ;
it != votes.end() ;
++it)
{
vote_count[pos] = it->second;
pos++;
}
if (vote_count.size() == 1)
{
// give at least 2 classes
vote_count.push_back(0);
}
#else #else
vote_count.resize(nclasses); vote_count.resize(nclasses);
for( int k = 0; k < ntrees; k++ ) for( int k = 0; k < ntrees; k++ )
...@@ -58,11 +103,8 @@ float CvRTreesWrapper::predict_margin(const cv::Mat& sample, ...@@ -58,11 +103,8 @@ float CvRTreesWrapper::predict_margin(const cv::Mat& sample,
const cv::Mat& missing) const const cv::Mat& missing) const
{ {
#ifdef OTB_OPENCV_3 #ifdef OTB_OPENCV_3
(void) sample; int ntrees = m_Impl->getRoots().size();
(void) missing; #endif
// TODO
return 0.;
#else
// Sanity check (division by ntrees later on) // Sanity check (division by ntrees later on)
if(ntrees == 0) if(ntrees == 0)
{ {
...@@ -75,18 +117,14 @@ float CvRTreesWrapper::predict_margin(const cv::Mat& sample, ...@@ -75,18 +117,14 @@ float CvRTreesWrapper::predict_margin(const cv::Mat& sample,
classVotes.end(), std::greater<unsigned int>()); classVotes.end(), std::greater<unsigned int>());
float margin = static_cast<float>(classVotes[0]-classVotes[1])/ntrees; float margin = static_cast<float>(classVotes[0]-classVotes[1])/ntrees;
return margin; return margin;
#endif
} }
float CvRTreesWrapper::predict_confidence(const cv::Mat& sample, float CvRTreesWrapper::predict_confidence(const cv::Mat& sample,
const cv::Mat& missing) const const cv::Mat& missing) const
{ {
#ifdef OTB_OPENCV_3 #ifdef OTB_OPENCV_3
(void) sample; int ntrees = m_Impl->getRoots().size();
(void) missing; #endif
// TODO
return 0.;
#else
// Sanity check (division by ntrees later on) // Sanity check (division by ntrees later on)
if(ntrees == 0) if(ntrees == 0)
{ {
...@@ -98,7 +136,6 @@ float CvRTreesWrapper::predict_confidence(const cv::Mat& sample, ...@@ -98,7 +136,6 @@ float CvRTreesWrapper::predict_confidence(const cv::Mat& sample,
classVotes.end())); classVotes.end()));
float confidence = static_cast<float>(max_votes)/ntrees; float confidence = static_cast<float>(max_votes)/ntrees;
return confidence; return confidence;
#endif
} }
#ifdef OTB_OPENCV_3 #ifdef OTB_OPENCV_3
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment