diff --git a/Modules/Applications/AppClassification/app/otbImageClassifier.cxx b/Modules/Applications/AppClassification/app/otbImageClassifier.cxx index a47a006c2acd93642036f558be763c70ba57f119..ba23a0d564104525d8f173136dc46f979fff87e7 100644 --- a/Modules/Applications/AppClassification/app/otbImageClassifier.cxx +++ b/Modules/Applications/AppClassification/app/otbImageClassifier.cxx @@ -105,7 +105,7 @@ private: " * KNearestNeighbors : number of neighbors with the same label\n" " * NeuralNetwork : difference between the two highest responses\n" " * NormalBayes : (not supported)\n" - " * RandomForest : proportion of decision trees that classified the sample to the second class (only works for 2-class models)\n" + " * RandomForest : Confidence (proportion of votes for the majority class). Margin (normalized difference of the votes of the 2 majority classes) is not available for now.\n" " * SVM : distance to margin (only works for 2-class models)\n"); SetDefaultOutputPixelType( "confmap", ImagePixelType_double); MandatoryOff("confmap"); diff --git a/Modules/Applications/AppClassification/test/CMakeLists.txt b/Modules/Applications/AppClassification/test/CMakeLists.txt index 63a5c0918f5df9138ab25e78968b8305743f2409..58c5d6ca1cdaa6c8acf11d1a5bed0c4fc87cbe9f 100644 --- a/Modules/Applications/AppClassification/test/CMakeLists.txt +++ b/Modules/Applications/AppClassification/test/CMakeLists.txt @@ -103,7 +103,7 @@ endif() if(OTB_USE_OPENCV) list(APPEND classifierList "SVM" "BOOST" "DT" "GBT" "ANN" "BAYES" "RF" "KNN") endif() -set(classifier_with_confmap "LIBSVM" "BOOST" "KNN" "ANN") +set(classifier_with_confmap "LIBSVM" "BOOST" "KNN" "ANN" "RF") # Loop on classifiers foreach(classifier ${classifierList}) diff --git a/Modules/Learning/Supervised/include/otbCvRTreesWrapper.h b/Modules/Learning/Supervised/include/otbCvRTreesWrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..536a7608a9720fdfa1e14e07f2d2e84707cb8f1c --- /dev/null +++ b/Modules/Learning/Supervised/include/otbCvRTreesWrapper.h @@ -0,0 +1,64 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef __otbCvRTreesWrapper_h +#define __otbCvRTreesWrapper_h + +#include "otbOpenCVUtils.h" +#include <vector> + +namespace otb +{ + +/** \class CvRTreesWrapper + * \brief Wrapper for OpenCV Random Trees + * + * \ingroup OTBSupervised + */ +class CV_EXPORTS_W CvRTreesWrapper : public CvRTrees +{ +public: + typedef std::vector<unsigned int> VotesVectorType; + CvRTreesWrapper(); + virtual ~CvRTreesWrapper(); + + /** Compute the number of votes for each class. */ + void get_votes(const cv::Mat& sample, + const cv::Mat& missing, + VotesVectorType& vote_count) const; + + /** Predict the confidence of the classifcation by computing the proportion + of trees which voted for the majority class. + */ + float predict_confidence(const cv::Mat& sample, + const cv::Mat& missing = + cv::Mat()) const; + + /** Predict the confidence margin of the classifcation by computing the + difference in votes between the first and second most voted classes. + This measure is preferred to the proportion of votes of the majority + class, since it provides information about the conflict between the + most likely classes. + */ + float predict_margin(const cv::Mat& sample, + const cv::Mat& missing = + cv::Mat()) const; +}; + +} + +#endif diff --git a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h index ec62b761358634daa8e268c1488770bb801d3543..942e3de986956154529dbda14622de898c9d5b3e 100644 --- a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h @@ -24,8 +24,9 @@ #include "itkFixedArray.h" #include "otbMachineLearningModel.h" #include "itkVariableSizeMatrix.h" +#include "otbCvRTreesWrapper.h" -class CvRTrees; +class CvRTreesWrapper; namespace otb { @@ -53,7 +54,7 @@ public: //opencv typedef - typedef CvRTrees RFType; + typedef CvRTreesWrapper RFType; /** Run-time type information (and related methods). */ itkNewMacro(Self); @@ -120,6 +121,9 @@ public: itkGetMacro(TerminationCriteria, int); itkSetMacro(TerminationCriteria, int); + itkGetMacro(ComputeMargin, bool); + itkSetMacro(ComputeMargin, bool); + /** Returns a matrix containing variable importance */ VariableImportanceMatrixType GetVariableImportance(); @@ -145,7 +149,7 @@ private: RandomForestsMachineLearningModel(const Self &); //purposely not implemented void operator =(const Self&); //purposely not implemented - CvRTrees * m_RFModel; + CvRTreesWrapper * m_RFModel; /** The depth of the tree. A low value will likely underfit and conversely a * high value will likely overfit. The optimal value can be obtained using cross * validation or other suitable methods. */ @@ -189,7 +193,7 @@ private: * first category. */ std::vector<float> m_Priors; /** If true then variable importance will be calculated and then it can be - * retrieved by CvRTrees::get_var_importance(). */ + * retrieved by CvRTreesWrapper::get_var_importance(). */ bool m_CalculateVariableImportance; /** The size of the randomly selected subset of features at each tree node and * that are used to find the best split(s). If you set it to 0 then the size will @@ -205,6 +209,10 @@ private: float m_ForestAccuracy; /** The type of the termination criteria */ int m_TerminationCriteria; + /** Wether to compute margin (difference in probability between the + * 2 most voted classes) instead of confidence (probability of the most + * voted class) in prediction*/ + bool m_ComputeMargin; }; } // end namespace otb diff --git a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx index 78642f1212ac9d75dc75d3bd9e9482d73bc948dc..aa0f054aa7a13940a2cd710e2adf326f01b9f3f5 100644 --- a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx +++ b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx @@ -29,17 +29,18 @@ namespace otb template <class TInputValue, class TOutputValue> RandomForestsMachineLearningModel<TInputValue,TOutputValue> ::RandomForestsMachineLearningModel() : - m_RFModel (new CvRTrees), - m_MaxDepth(5), - m_MinSampleCount(10), - m_RegressionAccuracy(0.01), - m_ComputeSurrogateSplit(false), - m_MaxNumberOfCategories(10), - m_CalculateVariableImportance(false), - m_MaxNumberOfVariables(0), - m_MaxNumberOfTrees(100), - m_ForestAccuracy(0.01), - m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS) + m_RFModel (new CvRTreesWrapper), + m_MaxDepth(5), + m_MinSampleCount(10), + m_RegressionAccuracy(0.01), + m_ComputeSurrogateSplit(false), + m_MaxNumberOfCategories(10), + m_CalculateVariableImportance(false), + m_MaxNumberOfVariables(0), + m_MaxNumberOfTrees(100), + m_ForestAccuracy(0.01), + m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS), + m_ComputeMargin(false) { this->m_ConfidenceIndex = true; this->m_IsRegressionSupported = true; @@ -91,7 +92,7 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> m_MaxNumberOfTrees, // max number of trees in the forest m_ForestAccuracy, // forest accuracy m_TerminationCriteria // termination criteria - ); + ); cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U ); var_type.setTo(cv::Scalar(CV_VAR_NUMERICAL) ); // all inputs are numerical @@ -125,7 +126,10 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> if (quality != NULL) { - (*quality) = m_RFModel->predict_prob(sample); + if(m_ComputeMargin) + (*quality) = m_RFModel->predict_margin(sample); + else + (*quality) = m_RFModel->predict_confidence(sample); } return target[0]; @@ -158,30 +162,30 @@ bool RandomForestsMachineLearningModel<TInputValue,TOutputValue> ::CanReadFile(const std::string & file) { - std::ifstream ifs; - ifs.open(file.c_str()); + std::ifstream ifs; + ifs.open(file.c_str()); - if(!ifs) - { - std::cerr<<"Could not read file "<<file<<std::endl; - return false; - } + if(!ifs) + { + std::cerr<<"Could not read file "<<file<<std::endl; + return false; + } - while (!ifs.eof()) - { - std::string line; - std::getline(ifs, line); + while (!ifs.eof()) + { + std::string line; + std::getline(ifs, line); - //if (line.find(m_RFModel->getName()) != std::string::npos) - if (line.find(CV_TYPE_NAME_ML_RTREES) != std::string::npos) + //if (line.find(m_RFModel->getName()) != std::string::npos) + if (line.find(CV_TYPE_NAME_ML_RTREES) != std::string::npos) { - //std::cout<<"Reading a "<<CV_TYPE_NAME_ML_RTREES<<" model"<<std::endl; - return true; + //std::cout<<"Reading a "<<CV_TYPE_NAME_ML_RTREES<<" model"<<std::endl; + return true; } - } - ifs.close(); - return false; + } + ifs.close(); + return false; } template <class TInputValue, class TOutputValue> diff --git a/Modules/Learning/Supervised/src/CMakeLists.txt b/Modules/Learning/Supervised/src/CMakeLists.txt index 67598c94a143f5b6f16327889e648994b4e928db..bab85f52ac967624f58f55953f756bcffc83fcb0 100644 --- a/Modules/Learning/Supervised/src/CMakeLists.txt +++ b/Modules/Learning/Supervised/src/CMakeLists.txt @@ -1,4 +1,5 @@ set(OTBSupervised_SRC + otbCvRTreesWrapper.cxx otbMachineLearningModelFactoryBase.cxx otbMachineLearningUtils.cxx ) diff --git a/Modules/Learning/Supervised/src/otbCvRTreesWrapper.cxx b/Modules/Learning/Supervised/src/otbCvRTreesWrapper.cxx new file mode 100644 index 0000000000000000000000000000000000000000..e4ac2d9f403aeb18814c1ceebc3f2bdbffec810c --- /dev/null +++ b/Modules/Learning/Supervised/src/otbCvRTreesWrapper.cxx @@ -0,0 +1,76 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#include "otbCvRTreesWrapper.h" +#include <algorithm> + + +namespace otb +{ + +CvRTreesWrapper::CvRTreesWrapper(){} + +CvRTreesWrapper::~CvRTreesWrapper(){} + +void CvRTreesWrapper::get_votes(const cv::Mat& sample, + const cv::Mat& missing, + CvRTreesWrapper::VotesVectorType& vote_count) const +{ + vote_count.resize(nclasses); + for( int k = 0; k < ntrees; k++ ) + { + CvDTreeNode* predicted_node = trees[k]->predict( sample, missing ); + int class_idx = predicted_node->class_idx; + CV_Assert( 0 <= class_idx && class_idx < nclasses ); + ++vote_count[class_idx]; + } +} + +float CvRTreesWrapper::predict_margin(const cv::Mat& sample, + const cv::Mat& missing) const +{ + // Sanity check (division by ntrees later on) + if(ntrees == 0) + { + return 0.; + } + std::vector<unsigned int> classVotes; + this->get_votes(sample, missing, classVotes); +// We only sort the 2 greatest elements + std::nth_element(classVotes.begin(), classVotes.begin()+1, + classVotes.end(), std::greater<unsigned int>()); + float margin = static_cast<float>(classVotes[0]-classVotes[1])/ntrees; + return margin; +} + +float CvRTreesWrapper::predict_confidence(const cv::Mat& sample, + const cv::Mat& missing) const +{ + // Sanity check (division by ntrees later on) + if(ntrees == 0) + { + return 0.; + } + std::vector<unsigned int> classVotes; + this->get_votes(sample, missing, classVotes); + unsigned int max_votes = *(std::max_element(classVotes.begin(), + classVotes.end())); + float confidence = static_cast<float>(max_votes)/ntrees; + return confidence; +} + +}