diff --git a/Modules/Applications/AppClassification/app/otbImageClassifier.cxx b/Modules/Applications/AppClassification/app/otbImageClassifier.cxx index a47a006c2acd93642036f558be763c70ba57f119..ba23a0d564104525d8f173136dc46f979fff87e7 100644 --- a/Modules/Applications/AppClassification/app/otbImageClassifier.cxx +++ b/Modules/Applications/AppClassification/app/otbImageClassifier.cxx @@ -105,7 +105,7 @@ private: " * KNearestNeighbors : number of neighbors with the same label\n" " * NeuralNetwork : difference between the two highest responses\n" " * NormalBayes : (not supported)\n" - " * RandomForest : proportion of decision trees that classified the sample to the second class (only works for 2-class models)\n" + " * RandomForest : Confidence (proportion of votes for the majority class). Margin (normalized difference of the votes of the 2 majority classes) is not available for now.\n" " * SVM : distance to margin (only works for 2-class models)\n"); SetDefaultOutputPixelType( "confmap", ImagePixelType_double); MandatoryOff("confmap"); diff --git a/Modules/Applications/AppClassification/test/CMakeLists.txt b/Modules/Applications/AppClassification/test/CMakeLists.txt index 63a5c0918f5df9138ab25e78968b8305743f2409..58c5d6ca1cdaa6c8acf11d1a5bed0c4fc87cbe9f 100644 --- a/Modules/Applications/AppClassification/test/CMakeLists.txt +++ b/Modules/Applications/AppClassification/test/CMakeLists.txt @@ -103,7 +103,7 @@ endif() if(OTB_USE_OPENCV) list(APPEND classifierList "SVM" "BOOST" "DT" "GBT" "ANN" "BAYES" "RF" "KNN") endif() -set(classifier_with_confmap "LIBSVM" "BOOST" "KNN" "ANN") +set(classifier_with_confmap "LIBSVM" "BOOST" "KNN" "ANN" "RF") # Loop on classifiers foreach(classifier ${classifierList}) diff --git a/Modules/Learning/Supervised/include/otbCvRTreesWrapper.h b/Modules/Learning/Supervised/include/otbCvRTreesWrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..536a7608a9720fdfa1e14e07f2d2e84707cb8f1c --- /dev/null +++ b/Modules/Learning/Supervised/include/otbCvRTreesWrapper.h @@ -0,0 +1,64 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef __otbCvRTreesWrapper_h +#define __otbCvRTreesWrapper_h + +#include "otbOpenCVUtils.h" +#include <vector> + +namespace otb +{ + +/** \class CvRTreesWrapper + * \brief Wrapper for OpenCV Random Trees + * + * \ingroup OTBSupervised + */ +class CV_EXPORTS_W CvRTreesWrapper : public CvRTrees +{ +public: + typedef std::vector<unsigned int> VotesVectorType; + CvRTreesWrapper(); + virtual ~CvRTreesWrapper(); + + /** Compute the number of votes for each class. */ + void get_votes(const cv::Mat& sample, + const cv::Mat& missing, + VotesVectorType& vote_count) const; + + /** Predict the confidence of the classifcation by computing the proportion + of trees which voted for the majority class. + */ + float predict_confidence(const cv::Mat& sample, + const cv::Mat& missing = + cv::Mat()) const; + + /** Predict the confidence margin of the classifcation by computing the + difference in votes between the first and second most voted classes. + This measure is preferred to the proportion of votes of the majority + class, since it provides information about the conflict between the + most likely classes. + */ + float predict_margin(const cv::Mat& sample, + const cv::Mat& missing = + cv::Mat()) const; +}; + +} + +#endif diff --git a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h index ec62b761358634daa8e268c1488770bb801d3543..942e3de986956154529dbda14622de898c9d5b3e 100644 --- a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h @@ -24,8 +24,9 @@ #include "itkFixedArray.h" #include "otbMachineLearningModel.h" #include "itkVariableSizeMatrix.h" +#include "otbCvRTreesWrapper.h" -class CvRTrees; +class CvRTreesWrapper; namespace otb { @@ -53,7 +54,7 @@ public: //opencv typedef - typedef CvRTrees RFType; + typedef CvRTreesWrapper RFType; /** Run-time type information (and related methods). */ itkNewMacro(Self); @@ -120,6 +121,9 @@ public: itkGetMacro(TerminationCriteria, int); itkSetMacro(TerminationCriteria, int); + itkGetMacro(ComputeMargin, bool); + itkSetMacro(ComputeMargin, bool); + /** Returns a matrix containing variable importance */ VariableImportanceMatrixType GetVariableImportance(); @@ -145,7 +149,7 @@ private: RandomForestsMachineLearningModel(const Self &); //purposely not implemented void operator =(const Self&); //purposely not implemented - CvRTrees * m_RFModel; + CvRTreesWrapper * m_RFModel; /** The depth of the tree. A low value will likely underfit and conversely a * high value will likely overfit. The optimal value can be obtained using cross * validation or other suitable methods. */ @@ -189,7 +193,7 @@ private: * first category. */ std::vector<float> m_Priors; /** If true then variable importance will be calculated and then it can be - * retrieved by CvRTrees::get_var_importance(). */ + * retrieved by CvRTreesWrapper::get_var_importance(). */ bool m_CalculateVariableImportance; /** The size of the randomly selected subset of features at each tree node and * that are used to find the best split(s). If you set it to 0 then the size will @@ -205,6 +209,10 @@ private: float m_ForestAccuracy; /** The type of the termination criteria */ int m_TerminationCriteria; + /** Wether to compute margin (difference in probability between the + * 2 most voted classes) instead of confidence (probability of the most + * voted class) in prediction*/ + bool m_ComputeMargin; }; } // end namespace otb diff --git a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx index 78642f1212ac9d75dc75d3bd9e9482d73bc948dc..aa0f054aa7a13940a2cd710e2adf326f01b9f3f5 100644 --- a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx +++ b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx @@ -29,17 +29,18 @@ namespace otb template <class TInputValue, class TOutputValue> RandomForestsMachineLearningModel<TInputValue,TOutputValue> ::RandomForestsMachineLearningModel() : - m_RFModel (new CvRTrees), - m_MaxDepth(5), - m_MinSampleCount(10), - m_RegressionAccuracy(0.01), - m_ComputeSurrogateSplit(false), - m_MaxNumberOfCategories(10), - m_CalculateVariableImportance(false), - m_MaxNumberOfVariables(0), - m_MaxNumberOfTrees(100), - m_ForestAccuracy(0.01), - m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS) + m_RFModel (new CvRTreesWrapper), + m_MaxDepth(5), + m_MinSampleCount(10), + m_RegressionAccuracy(0.01), + m_ComputeSurrogateSplit(false), + m_MaxNumberOfCategories(10), + m_CalculateVariableImportance(false), + m_MaxNumberOfVariables(0), + m_MaxNumberOfTrees(100), + m_ForestAccuracy(0.01), + m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS), + m_ComputeMargin(false) { this->m_ConfidenceIndex = true; this->m_IsRegressionSupported = true; @@ -91,7 +92,7 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> m_MaxNumberOfTrees, // max number of trees in the forest m_ForestAccuracy, // forest accuracy m_TerminationCriteria // termination criteria - ); + ); cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U ); var_type.setTo(cv::Scalar(CV_VAR_NUMERICAL) ); // all inputs are numerical @@ -125,7 +126,10 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> if (quality != NULL) { - (*quality) = m_RFModel->predict_prob(sample); + if(m_ComputeMargin) + (*quality) = m_RFModel->predict_margin(sample); + else + (*quality) = m_RFModel->predict_confidence(sample); } return target[0]; @@ -158,30 +162,30 @@ bool RandomForestsMachineLearningModel<TInputValue,TOutputValue> ::CanReadFile(const std::string & file) { - std::ifstream ifs; - ifs.open(file.c_str()); + std::ifstream ifs; + ifs.open(file.c_str()); - if(!ifs) - { - std::cerr<<"Could not read file "<<file<<std::endl; - return false; - } + if(!ifs) + { + std::cerr<<"Could not read file "<<file<<std::endl; + return false; + } - while (!ifs.eof()) - { - std::string line; - std::getline(ifs, line); + while (!ifs.eof()) + { + std::string line; + std::getline(ifs, line); - //if (line.find(m_RFModel->getName()) != std::string::npos) - if (line.find(CV_TYPE_NAME_ML_RTREES) != std::string::npos) + //if (line.find(m_RFModel->getName()) != std::string::npos) + if (line.find(CV_TYPE_NAME_ML_RTREES) != std::string::npos) { - //std::cout<<"Reading a "<<CV_TYPE_NAME_ML_RTREES<<" model"<<std::endl; - return true; + //std::cout<<"Reading a "<<CV_TYPE_NAME_ML_RTREES<<" model"<<std::endl; + return true; } - } - ifs.close(); - return false; + } + ifs.close(); + return false; } template <class TInputValue, class TOutputValue> diff --git a/Modules/Learning/Supervised/src/CMakeLists.txt b/Modules/Learning/Supervised/src/CMakeLists.txt index 67598c94a143f5b6f16327889e648994b4e928db..bab85f52ac967624f58f55953f756bcffc83fcb0 100644 --- a/Modules/Learning/Supervised/src/CMakeLists.txt +++ b/Modules/Learning/Supervised/src/CMakeLists.txt @@ -1,4 +1,5 @@ set(OTBSupervised_SRC + otbCvRTreesWrapper.cxx otbMachineLearningModelFactoryBase.cxx otbMachineLearningUtils.cxx ) diff --git a/Modules/Learning/Supervised/src/otbCvRTreesWrapper.cxx b/Modules/Learning/Supervised/src/otbCvRTreesWrapper.cxx new file mode 100644 index 0000000000000000000000000000000000000000..e4ac2d9f403aeb18814c1ceebc3f2bdbffec810c --- /dev/null +++ b/Modules/Learning/Supervised/src/otbCvRTreesWrapper.cxx @@ -0,0 +1,76 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#include "otbCvRTreesWrapper.h" +#include <algorithm> + + +namespace otb +{ + +CvRTreesWrapper::CvRTreesWrapper(){} + +CvRTreesWrapper::~CvRTreesWrapper(){} + +void CvRTreesWrapper::get_votes(const cv::Mat& sample, + const cv::Mat& missing, + CvRTreesWrapper::VotesVectorType& vote_count) const +{ + vote_count.resize(nclasses); + for( int k = 0; k < ntrees; k++ ) + { + CvDTreeNode* predicted_node = trees[k]->predict( sample, missing ); + int class_idx = predicted_node->class_idx; + CV_Assert( 0 <= class_idx && class_idx < nclasses ); + ++vote_count[class_idx]; + } +} + +float CvRTreesWrapper::predict_margin(const cv::Mat& sample, + const cv::Mat& missing) const +{ + // Sanity check (division by ntrees later on) + if(ntrees == 0) + { + return 0.; + } + std::vector<unsigned int> classVotes; + this->get_votes(sample, missing, classVotes); +// We only sort the 2 greatest elements + std::nth_element(classVotes.begin(), classVotes.begin()+1, + classVotes.end(), std::greater<unsigned int>()); + float margin = static_cast<float>(classVotes[0]-classVotes[1])/ntrees; + return margin; +} + +float CvRTreesWrapper::predict_confidence(const cv::Mat& sample, + const cv::Mat& missing) const +{ + // Sanity check (division by ntrees later on) + if(ntrees == 0) + { + return 0.; + } + std::vector<unsigned int> classVotes; + this->get_votes(sample, missing, classVotes); + unsigned int max_votes = *(std::max_element(classVotes.begin(), + classVotes.end())); + float confidence = static_cast<float>(max_votes)/ntrees; + return confidence; +} + +} diff --git a/Modules/Segmentation/OGRProcessing/include/otbOGRLayerStreamStitchingFilter.txx b/Modules/Segmentation/OGRProcessing/include/otbOGRLayerStreamStitchingFilter.txx index 2e3e56f9e399ab5a38ace84471207d347f7a9967..b1a88bff25b7bbae67f41416a400136cebb273ee 100644 --- a/Modules/Segmentation/OGRProcessing/include/otbOGRLayerStreamStitchingFilter.txx +++ b/Modules/Segmentation/OGRProcessing/include/otbOGRLayerStreamStitchingFilter.txx @@ -318,8 +318,22 @@ OGRLayerStreamStitchingFilter<TInputImage> try { #ifdef OTB_USE_GDAL_20 - fusionFeature[0].SetValue(field.GetValue<GIntBig>()); + // In this case, the feature id can be either + // OFTInteger64 or OFTInteger + switch(field.GetType()) + { + case OFTInteger64: + { + fusionFeature[0].SetValue(field.GetValue<GIntBig>()); + break; + } + default: + { + fusionFeature[0].SetValue(field.GetValue<int>()); + } + } #else + // Only OFTInteger supported in this case fusionFeature[0].SetValue(field.GetValue<int>()); #endif m_OGRLayer.CreateFeature(fusionFeature); diff --git a/Modules/ThirdParty/ITK/include/otbWarpImageFilter.h b/Modules/ThirdParty/ITK/include/otbWarpImageFilter.h index 0a2041d60b3796a6ec2cd007bfe4335370a15c31..19f46d9a65ff4f4e1a6abc8aa8870e3a8de0f3d9 100644 --- a/Modules/ThirdParty/ITK/include/otbWarpImageFilter.h +++ b/Modules/ThirdParty/ITK/include/otbWarpImageFilter.h @@ -257,7 +257,7 @@ private: /** This function should be in an interpolator but none of the ITK * interpolators at this point handle edge conditions properly */ - DisplacementType EvaluateDisplacementAtPhysicalPoint(const PointType &p); + DisplacementType EvaluateDisplacementAtPhysicalPoint(const PointType &p, const DisplacementFieldType *fieldPtr); PixelType m_EdgePaddingValue; SpacingType m_OutputSpacing; diff --git a/Modules/ThirdParty/ITK/include/otbWarpImageFilter.txx b/Modules/ThirdParty/ITK/include/otbWarpImageFilter.txx index a57a72f045d97f8ab7916a5e7835a48890f4d794..7f0719b187d0e3007bdf965eaa39c8b28fa18ce6 100644 --- a/Modules/ThirdParty/ITK/include/otbWarpImageFilter.txx +++ b/Modules/ThirdParty/ITK/include/otbWarpImageFilter.txx @@ -214,9 +214,8 @@ typename WarpImageFilter<TInputImage, TOutputImage, TDisplacementField>::DisplacementType WarpImageFilter<TInputImage,TOutputImage,TDisplacementField> -::EvaluateDisplacementAtPhysicalPoint(const PointType &point) +::EvaluateDisplacementAtPhysicalPoint(const PointType &point, const DisplacementFieldType *fieldPtr) { - DisplacementFieldPointer fieldPtr = this->GetDisplacementField(); itk::ContinuousIndex<double,ImageDimension> index; fieldPtr->TransformPhysicalPointToContinuousIndex(point,index); unsigned int dim; // index over dimension @@ -377,7 +376,7 @@ WarpImageFilter<TInputImage,TOutputImage,TDisplacementField> index = outputIt.GetIndex(); outputPtr->TransformIndexToPhysicalPoint( index, point ); - displacement = this->EvaluateDisplacementAtPhysicalPoint(point); + displacement = this->EvaluateDisplacementAtPhysicalPoint( point, fieldPtr ); // compute the required input image point for(unsigned int j = 0; j < ImageDimension; j++ ) { diff --git a/Modules/Wrappers/ApplicationEngine/include/otbWrapperInputProcessXMLParameter.h b/Modules/Wrappers/ApplicationEngine/include/otbWrapperInputProcessXMLParameter.h index be5f9dc2154e8f13ebee5cefc1f14fb4dfa60112..1db4f6ecc2cacbee7ee630af05c5a61a0030b6cf 100644 --- a/Modules/Wrappers/ApplicationEngine/include/otbWrapperInputProcessXMLParameter.h +++ b/Modules/Wrappers/ApplicationEngine/include/otbWrapperInputProcessXMLParameter.h @@ -43,19 +43,10 @@ public: // Get Value //TODO otbGetObjectMemberMacro(StringParam, Value , std::string); - void SetFileName(std::string value) - { - this->SetValue(value); - } + bool SetFileName(std::string value); // Set Value - virtual void SetValue(const std::string value) - { - itkDebugMacro("setting member m_FileName to " << value); - this->m_FileName = value; - SetActive(true); - this->Modified(); - } + virtual void SetValue(const std::string value); ImagePixelType GetPixelTypeFromString(std::string pixTypeAsString); diff --git a/Modules/Wrappers/ApplicationEngine/src/otbWrapperApplication.cxx b/Modules/Wrappers/ApplicationEngine/src/otbWrapperApplication.cxx index 2612a7cb97f475038c835f2b678e4fcefe6508b0..3df1c1937b3efdacfef62d4d37cacb606cf4174d 100644 --- a/Modules/Wrappers/ApplicationEngine/src/otbWrapperApplication.cxx +++ b/Modules/Wrappers/ApplicationEngine/src/otbWrapperApplication.cxx @@ -830,7 +830,8 @@ void Application::SetParameterString(std::string parameter, std::string value) else if (dynamic_cast<InputProcessXMLParameter*>(param)) { InputProcessXMLParameter* paramDown = dynamic_cast<InputProcessXMLParameter*>(param); - paramDown->SetValue(value); + if ( !paramDown->SetFileName(value) ) + otbAppLogCRITICAL( <<"Invalid XML parameter filename " << value <<"."); } } diff --git a/Modules/Wrappers/ApplicationEngine/src/otbWrapperInputProcessXMLParameter.cxx b/Modules/Wrappers/ApplicationEngine/src/otbWrapperInputProcessXMLParameter.cxx index 87b94e868d178446e6996e216751dbb502300711..be17bb948f81ea97a661a1e7fb6165db44f0ccd3 100644 --- a/Modules/Wrappers/ApplicationEngine/src/otbWrapperInputProcessXMLParameter.cxx +++ b/Modules/Wrappers/ApplicationEngine/src/otbWrapperInputProcessXMLParameter.cxx @@ -41,6 +41,34 @@ InputProcessXMLParameter::~InputProcessXMLParameter() } +bool +InputProcessXMLParameter::SetFileName(std::string value) +{ + // Check if the filename is not empty + if(!value.empty()) + { + // Check that the right extension is given : expected .xml + if (itksys::SystemTools::GetFilenameLastExtension(value) == ".xml") + { + if (itksys::SystemTools::FileExists(value.c_str(),true)) + { + this->SetValue(value); + return true; + } + } + } + return false; +} + +void +InputProcessXMLParameter::SetValue(const std::string value) +{ + itkDebugMacro("setting member m_FileName to " << value); + this->m_FileName = value; + SetActive(true); + this->Modified(); +} + ImagePixelType InputProcessXMLParameter::GetPixelTypeFromString(std::string strType) { @@ -134,18 +162,6 @@ InputProcessXMLParameter::GetChildNodeTextOf(TiXmlElement *parentElement, std::s int InputProcessXMLParameter::Read(Application::Pointer this_) { - - // Check if the filename is not empty - if(m_FileName.empty()) - itkExceptionMacro(<<"The XML input FileName is empty, please set the filename via the method SetFileName"); - - // Check that the right extension is given : expected .xml - if (itksys::SystemTools::GetFilenameLastExtension(m_FileName) != ".xml") - { - itkExceptionMacro(<<itksys::SystemTools::GetFilenameLastExtension(m_FileName) << " " << m_FileName << " " - <<" is a wrong Extension FileName : Expected .xml"); - } - // Open the xml file TiXmlDocument doc; diff --git a/Modules/Wrappers/QtWidget/src/otbWrapperQtWidgetInputProcessXMLParameter.cxx b/Modules/Wrappers/QtWidget/src/otbWrapperQtWidgetInputProcessXMLParameter.cxx index 9703946e3b9102023df876716bf239f5d000f3b3..c30e5e33b08df629336f8023aaa6a5f33e2463ea 100644 --- a/Modules/Wrappers/QtWidget/src/otbWrapperQtWidgetInputProcessXMLParameter.cxx +++ b/Modules/Wrappers/QtWidget/src/otbWrapperQtWidgetInputProcessXMLParameter.cxx @@ -34,10 +34,13 @@ QtWidgetInputProcessXMLParameter::~QtWidgetInputProcessXMLParameter() void QtWidgetInputProcessXMLParameter::DoUpdateGUI() { - // Update the lineEdit - QString text( m_XMLParam->GetFileName() ); - if (text != m_Input->text()) - m_Input->setText(text); + if (m_XMLParam->HasUserValue()) + { + // Update the lineEdit + QString text( m_XMLParam->GetFileName() ); + if (text != m_Input->text()) + m_Input->setText(text); + } } void QtWidgetInputProcessXMLParameter::DoCreateWidget() @@ -98,15 +101,13 @@ void QtWidgetInputProcessXMLParameter::SelectFile() void QtWidgetInputProcessXMLParameter::SetFileName(const QString& value) { // load xml file name - m_XMLParam->SetValue(value.toAscii().constData()); - - // notify of value change - QString key( m_XMLParam->GetKey() ); - - emit ParameterChanged(key); - - GetModel()->UpdateAllWidgets(); - + if (m_XMLParam->SetFileName(value.toAscii().constData())) + { + // notify of value change + QString key( m_XMLParam->GetKey() ); + emit ParameterChanged(key); + GetModel()->UpdateAllWidgets(); + } } } diff --git a/Modules/Wrappers/QtWidget/src/otbWrapperQtWidgetInputVectorDataParameter.cxx b/Modules/Wrappers/QtWidget/src/otbWrapperQtWidgetInputVectorDataParameter.cxx index 70018c01e4490d7b153306f686f68ceb26e96a15..9aaa6e857abdced364377e4c06e5798428d25eb2 100644 --- a/Modules/Wrappers/QtWidget/src/otbWrapperQtWidgetInputVectorDataParameter.cxx +++ b/Modules/Wrappers/QtWidget/src/otbWrapperQtWidgetInputVectorDataParameter.cxx @@ -38,9 +38,12 @@ QtWidgetInputVectorDataParameter::~QtWidgetInputVectorDataParameter() void QtWidgetInputVectorDataParameter::DoUpdateGUI() { //update lineedit - QString text( m_InputVectorDataParam->GetFileName().c_str() ); - if (text != m_Input->text()) - m_Input->setText(text); + if(m_InputVectorDataParam->HasUserValue()) + { + QString text( m_InputVectorDataParam->GetFileName().c_str() ); + if (text != m_Input->text()) + m_Input->setText(text); + } } void QtWidgetInputVectorDataParameter::DoCreateWidget()