From e33657b4117b3867b322944e15f018c624b9ec9f Mon Sep 17 00:00:00 2001
From: Jordi Inglada <jordi.inglada@cesbio.cnes.fr>
Date: Wed, 21 Oct 2015 14:29:26 +0200
Subject: [PATCH] ENH: simplify the vote counting

---
 .../Learning/Supervised/include/otbCvRTrees.h | 44 +++++--------------
 1 file changed, 11 insertions(+), 33 deletions(-)

diff --git a/Modules/Learning/Supervised/include/otbCvRTrees.h b/Modules/Learning/Supervised/include/otbCvRTrees.h
index 2eab1cd9f7..18da88e3d2 100644
--- a/Modules/Learning/Supervised/include/otbCvRTrees.h
+++ b/Modules/Learning/Supervised/include/otbCvRTrees.h
@@ -23,50 +23,28 @@
 
 class CV_EXPORTS_W CvRTreesWrapper : public CvRTrees
 {
-  struct ClassVotes
-  {
-    unsigned int votes;
-    unsigned int class_idx;
-  };
-
-  struct MoreVotes
-  {
-    bool operator()(ClassVotes a, ClassVotes b)
-    {
-      return (a.votes > b.votes);
-    }
-  };
-
-  typedef std::vector<ClassVotes> ClassVotesVectorType;
-
 public:
-  CV_WRAP CvRTreesWrapper(){};
+  CvRTreesWrapper(){};
   virtual ~CvRTreesWrapper(){};
   
-  const int get_nclasses() const
-  {
-    return nclasses;
-  };
-
-  //  CV_WRAP virtual float predict(const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const{return CvRTrees::predict(sample,missing);};
-  
-  virtual const float predict_confidence(const cv::Mat& sample, const cv::Mat& missing = cv::Mat()) const
+  float predict_confidence(const cv::Mat& sample, 
+                           const cv::Mat& missing = 
+                           cv::Mat()) const
   {
-    cv::AutoBuffer<int> _votes(nclasses);
-    ClassVotesVectorType classVotes(nclasses);
+    std::vector<unsigned int> classVotes(nclasses);
     for( int k = 0; k < ntrees; k++ )
       {
       CvDTreeNode* predicted_node = trees[k]->predict( sample, missing );
       int class_idx = predicted_node->class_idx;
       CV_Assert( 0 <= class_idx && class_idx < nclasses );
-      classVotes[class_idx].votes += 1;
-      classVotes[class_idx].class_idx = class_idx;
+      ++classVotes[class_idx];
       }
     std::nth_element(classVotes.begin(), classVotes.begin()+1, 
-                     classVotes.end(), MoreVotes());
-    unsigned int maxVotes = classVotes[0].votes;
-    unsigned int secondVotes = classVotes[2].votes;
-    return static_cast<float>(maxVotes-secondVotes)/ntrees;
+                     classVotes.end(), std::greater<>());
+    unsigned int maxVotes = classVotes[0];
+    unsigned int secondVotes = classVotes[1];
+    float confidence = static_cast<float>(maxVotes-secondVotes)/ntrees;
+    return confidence;
   };
 
 };
-- 
GitLab