diff --git a/Modules/Learning/Supervised/include/otbCvRTreesWrapper.h b/Modules/Learning/Supervised/include/otbCvRTreesWrapper.h index c989bc36e22f811881ecf64982ff45118137774c..536a7608a9720fdfa1e14e07f2d2e84707cb8f1c 100644 --- a/Modules/Learning/Supervised/include/otbCvRTreesWrapper.h +++ b/Modules/Learning/Supervised/include/otbCvRTreesWrapper.h @@ -19,6 +19,7 @@ #define __otbCvRTreesWrapper_h #include "otbOpenCVUtils.h" +#include <vector> namespace otb { @@ -31,18 +32,31 @@ namespace otb class CV_EXPORTS_W CvRTreesWrapper : public CvRTrees { public: + typedef std::vector<unsigned int> VotesVectorType; CvRTreesWrapper(); virtual ~CvRTreesWrapper(); + + /** Compute the number of votes for each class. */ + void get_votes(const cv::Mat& sample, + const cv::Mat& missing, + VotesVectorType& vote_count) const; - /** Predict the confidence of the classifcation by computing the + /** Predict the confidence of the classifcation by computing the proportion + of trees which voted for the majority class. + */ + float predict_confidence(const cv::Mat& sample, + const cv::Mat& missing = + cv::Mat()) const; + + /** Predict the confidence margin of the classifcation by computing the difference in votes between the first and second most voted classes. This measure is preferred to the proportion of votes of the majority class, since it provides information about the conflict between the most likely classes. */ - float predict_confidence(const cv::Mat& sample, - const cv::Mat& missing = - cv::Mat()) const; + float predict_margin(const cv::Mat& sample, + const cv::Mat& missing = + cv::Mat()) const; }; } diff --git a/Modules/Learning/Supervised/src/otbCvRTreesWrapper.cxx b/Modules/Learning/Supervised/src/otbCvRTreesWrapper.cxx index 9207e0d33753c30c55ff6fe4d355af84e550898d..e4ac2d9f403aeb18814c1ceebc3f2bdbffec810c 100644 --- a/Modules/Learning/Supervised/src/otbCvRTreesWrapper.cxx +++ b/Modules/Learning/Supervised/src/otbCvRTreesWrapper.cxx @@ -16,7 +16,6 @@ =========================================================================*/ #include "otbCvRTreesWrapper.h" -#include <vector> #include <algorithm> @@ -27,27 +26,51 @@ CvRTreesWrapper::CvRTreesWrapper(){} CvRTreesWrapper::~CvRTreesWrapper(){} +void CvRTreesWrapper::get_votes(const cv::Mat& sample, + const cv::Mat& missing, + CvRTreesWrapper::VotesVectorType& vote_count) const +{ + vote_count.resize(nclasses); + for( int k = 0; k < ntrees; k++ ) + { + CvDTreeNode* predicted_node = trees[k]->predict( sample, missing ); + int class_idx = predicted_node->class_idx; + CV_Assert( 0 <= class_idx && class_idx < nclasses ); + ++vote_count[class_idx]; + } +} + +float CvRTreesWrapper::predict_margin(const cv::Mat& sample, + const cv::Mat& missing) const +{ + // Sanity check (division by ntrees later on) + if(ntrees == 0) + { + return 0.; + } + std::vector<unsigned int> classVotes; + this->get_votes(sample, missing, classVotes); +// We only sort the 2 greatest elements + std::nth_element(classVotes.begin(), classVotes.begin()+1, + classVotes.end(), std::greater<unsigned int>()); + float margin = static_cast<float>(classVotes[0]-classVotes[1])/ntrees; + return margin; +} + float CvRTreesWrapper::predict_confidence(const cv::Mat& sample, - const cv::Mat& missing) const - { - // Sanity check (division by ntrees later on) - if(ntrees == 0) - { - return 0.; - } - - std::vector<unsigned int> classVotes(nclasses); - for( int k = 0; k < ntrees; k++ ) - { - CvDTreeNode* predicted_node = trees[k]->predict( sample, missing ); - int class_idx = predicted_node->class_idx; - CV_Assert( 0 <= class_idx && class_idx < nclasses ); - ++classVotes[class_idx]; - } - // We only sort the 2 greatest elements - std::nth_element(classVotes.begin(), classVotes.begin()+1, - classVotes.end(), std::greater<unsigned int>()); - float confidence = static_cast<float>(classVotes[0]-classVotes[1])/ntrees; - return confidence; - } + const cv::Mat& missing) const +{ + // Sanity check (division by ntrees later on) + if(ntrees == 0) + { + return 0.; + } + std::vector<unsigned int> classVotes; + this->get_votes(sample, missing, classVotes); + unsigned int max_votes = *(std::max_element(classVotes.begin(), + classVotes.end())); + float confidence = static_cast<float>(max_votes)/ntrees; + return confidence; +} + }