diff --git a/Modules/Applications/AppClassification/app/otbImageClassifier.cxx b/Modules/Applications/AppClassification/app/otbImageClassifier.cxx index 07a8a352d7622e29bed60592a69fed040ad53a8c..8cd4c5b030159f38a69b20e571dae27e149a5e2c 100644 --- a/Modules/Applications/AppClassification/app/otbImageClassifier.cxx +++ b/Modules/Applications/AppClassification/app/otbImageClassifier.cxx @@ -59,6 +59,7 @@ public: typedef ClassificationFilterType::ValueType ValueType; typedef ClassificationFilterType::LabelType LabelType; typedef otb::MachineLearningModelFactory<ValueType, LabelType> MachineLearningModelFactoryType; + typedef ClassificationFilterType::ConfidenceImageType ConfidenceImageType; private: void DoInit() @@ -94,6 +95,21 @@ private: SetParameterDescription( "out", "Output image containing class labels"); SetParameterOutputImagePixelType( "out", ImagePixelType_uint8); + AddParameter(ParameterType_OutputImage, "confmap", "Confidence map"); + SetParameterDescription( "confmap", "Confidence map of the produced classification. The confidence index depends on the model : \n" + " - LibSVM : difference between the two highest probabilities (needs a model with probability estimates, so that classes probabilities can be computed for each sample)\n" + " - OpenCV\n" + " * Boost : sum of votes\n" + " * DecisionTree : (not supported)\n" + " * GradientBoostedTree : (not supported)\n" + " * KNearestNeighbors : number of neighbors with the same label\n" + " * NeuralNetwork : difference between the two highest responses\n" + " * NormalBayes : (not supported)\n" + " * RandomForest : proportion of decision trees that classified the sample to the second class (only works for 2-class models)\n" + " * SVM : distance to margin (only works for 2-class models)\n"); + SetParameterOutputImagePixelType( "confmap", ImagePixelType_double); + MandatoryOff("confmap"); + AddRAMParameter(); // Doc example parameter settings @@ -171,6 +187,21 @@ private: } SetParameterOutputImage<OutputImageType>("out", m_ClassificationFilter->GetOutput()); + + // output confidence map + if (IsParameterEnabled("confmap") && HasValue("confmap")) + { + m_ClassificationFilter->SetUseConfidenceMap(true); + if (m_Model->HasConfidenceIndex()) + { + SetParameterOutputImage<ConfidenceImageType>("confmap",m_ClassificationFilter->GetOutputConfidence()); + } + else + { + otbAppLogWARNING("Confidence map requested but the classifier doesn't support it!"); + this->DisableParameter("confmap"); + } + } } ClassificationFilterType::Pointer m_ClassificationFilter; diff --git a/Modules/Applications/AppClassification/app/otbTrainLibSVM.cxx b/Modules/Applications/AppClassification/app/otbTrainLibSVM.cxx index 458e5b8487b8f6e11c337c3ffc0026941c8ca627..e117337708d3effd267a462a1d24f8b2f6736e18 100644 --- a/Modules/Applications/AppClassification/app/otbTrainLibSVM.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainLibSVM.cxx @@ -42,6 +42,9 @@ namespace Wrapper AddParameter(ParameterType_Empty, "classifier.libsvm.opt", "Parameters optimization"); MandatoryOff("classifier.libsvm.opt"); SetParameterDescription("classifier.libsvm.opt", "SVM parameters optimization flag."); + AddParameter(ParameterType_Empty, "classifier.libsvm.prob", "Probability estimation"); + MandatoryOff("classifier.libsvm.prob"); + SetParameterDescription("classifier.libsvm.prob", "Probability estimation flag."); } @@ -56,6 +59,10 @@ namespace Wrapper { libSVMClassifier->SetParameterOptimization(true); } + if (IsParameterEnabled("classifier.libsvm.prob")) + { + libSVMClassifier->SetDoProbabilityEstimates(true); + } libSVMClassifier->SetC(GetParameterFloat("classifier.libsvm.c")); switch (GetParameterInt("classifier.libsvm.k")) diff --git a/Modules/Applications/AppClassification/test/CMakeLists.txt b/Modules/Applications/AppClassification/test/CMakeLists.txt index 77e4a3636495955d2488546da105d7dc2c97c379..c4e275fe5a33fc8e629a13c9016a5584c8a4b0a3 100644 --- a/Modules/Applications/AppClassification/test/CMakeLists.txt +++ b/Modules/Applications/AppClassification/test/CMakeLists.txt @@ -77,7 +77,7 @@ set(rf_output_format ".rf") set(knn_output_format ".knn") # Training algorithms parameters -set(libsvm_parameters "-classifier.libsvm.opt" "true") +set(libsvm_parameters "-classifier.libsvm.opt" "true" "-classifier.libsvm.prob" "true") set(svm_parameters "-classifier.svm.opt" "true") set(boost_parameters "") set(dt_parameters "") @@ -90,6 +90,7 @@ set(knn_parameters "") # Validation depending on mode set(ascii_comparison --compare-ascii ${NOTOL}) set(raster_comparison --compare-image ${NOTOL}) +set(raster_comparison_two --compare-n-images ${NOTOL} 2) # Reference ffiles depending on modes set(ascii_ref_path ${OTBAPP_BASELINE_FILES}) @@ -102,6 +103,7 @@ endif() if(OTB_USE_OPENCV) list(APPEND classifierList "SVM" "BOOST" "DT" "GBT" "ANN" "BAYES" "RF" "KNN") endif() +set(classifier_with_confmap "LIBSVM" "BOOST" "KNN" "ANN") # Loop on classifiers foreach(classifier ${classifierList}) @@ -110,6 +112,7 @@ foreach(classifier ${classifierList}) # Derive output file name set(OUTMODELFILE cl${classifier}_ModelQB1${${lclassifier}_output_format}) set(OUTRASTER cl${classifier}LabeledImageQB1${raster_output_format}) + set(OUTCONFMAP cl${classifier}ConfidenceMapQB1${raster_output_format}) otb_test_application( NAME apTvClTrainMethod${classifier}ImagesClassifierQB1 @@ -160,7 +163,9 @@ foreach(classifier ${classifierList}) #set_tests_properties(apTvClTrainMethod${classifier}ImagesClassifierQB1_InXML1 PROPERTIES DEPENDS apTvClTrainMethod${classifier}ImagesClassifierQB1_OutXML1) - otb_test_application( + list(FIND classifier_with_confmap ${classifier} _classifier_has_confmap) + if(${_classifier_has_confmap} EQUAL -1) + otb_test_application( NAME apTvClMethod${classifier}ImageClassifierQB1 APP ImageClassifier OPTIONS -in ${INPUTDATA}/Classification/QB_1_ortho${raster_input_format} @@ -172,6 +177,23 @@ foreach(classifier ${classifierList}) ${raster_ref_path}/${OUTRASTER} ${TEMP}/${OUTRASTER} ) + else() + otb_test_application( + NAME apTvClMethod${classifier}ImageClassifierQB1 + APP ImageClassifier + OPTIONS -in ${INPUTDATA}/Classification/QB_1_ortho${raster_input_format} + -model ${ascii_ref_path}/${OUTMODELFILE} + -imstat ${ascii_ref_path}/clImageStatisticsQB1${stat_input_format} + -out ${TEMP}/${OUTRASTER} ${raster_output_option} + -confmap ${TEMP}/${OUTCONFMAP} + + VALID ${raster_comparison_two} + ${raster_ref_path}/${OUTRASTER} + ${TEMP}/${OUTRASTER} + ${raster_ref_path}/${OUTCONFMAP} + ${TEMP}/${OUTCONFMAP} + ) + endif() endforeach() diff --git a/Modules/Learning/SVMLearning/include/otbSVMModel.h b/Modules/Learning/SVMLearning/include/otbSVMModel.h index 682f3d7bb6552aba3d92435da9f81e38ae7dfc3e..f64af6c8dd751048e3db01a74d539cbdbd467cd3 100644 --- a/Modules/Learning/SVMLearning/include/otbSVMModel.h +++ b/Modules/Learning/SVMLearning/include/otbSVMModel.h @@ -307,6 +307,12 @@ public: return static_cast<bool>(m_Parameters.probability); } + /** Test if the model has probabilities */ + bool HasProbabilities(void) const + { + return static_cast<bool>(svm_check_probability_model(m_Model)); + } + /** Return number of support vectors */ int GetNumberOfSupportVectors(void) const { diff --git a/Modules/Learning/SVMLearning/include/otbSVMModel.txx b/Modules/Learning/SVMLearning/include/otbSVMModel.txx index 07bc7eef9248266fc94eacfd43f5294a3bdb1d16..a527b2eca7656efd37b52ec66e7d39937cb23dd8 100644 --- a/Modules/Learning/SVMLearning/include/otbSVMModel.txx +++ b/Modules/Learning/SVMLearning/include/otbSVMModel.txx @@ -464,7 +464,7 @@ SVMModel<TValue, TLabel>::EvaluateProbabilities(const MeasurementType& measure) itkExceptionMacro(<< "Model is not up-to-date, can not predict probabilities"); } - if (svm_check_probability_model(m_Model) == 0) + if (!this->HasProbabilities()) { throw itk::ExceptionObject(__FILE__, __LINE__, "Model does not support probability estimates", ITK_LOCATION); diff --git a/Modules/Learning/Supervised/include/otbBoostMachineLearningModel.h b/Modules/Learning/Supervised/include/otbBoostMachineLearningModel.h index 510717d4ea36312a89bc2c17459bef9d23af605a..185d10ba44bf26db0a9444c5eb61025ef4f2899a 100644 --- a/Modules/Learning/Supervised/include/otbBoostMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbBoostMachineLearningModel.h @@ -40,19 +40,17 @@ public: typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; - // Input related typedefs - typedef TInputValue InputValueType; - typedef itk::VariableLengthVector<InputValueType> InputSampleType; - typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType; - - // Target related typedefs - typedef TTargetValue TargetValueType; - typedef itk::FixedArray<TargetValueType,1> TargetSampleType; - typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType; + typedef typename Superclass::InputValueType InputValueType; + typedef typename Superclass::InputSampleType InputSampleType; + typedef typename Superclass::InputListSampleType InputListSampleType; + typedef typename Superclass::TargetValueType TargetValueType; + typedef typename Superclass::TargetSampleType TargetSampleType; + typedef typename Superclass::TargetListSampleType TargetListSampleType; + typedef typename Superclass::ConfidenceValueType ConfidenceValueType; /** Run-time type information (and related methods). */ itkNewMacro(Self); - itkTypeMacro(BoostMachineLearningModel, itk::MachineLearningModel); + itkTypeMacro(BoostMachineLearningModel, MachineLearningModel); /** Setters/Getters to the Boost type * It can be CvBoost::DISCRETE, CvBoost::REAL, CvBoost::LOGIT, CvBoost::GENTLE @@ -122,7 +120,7 @@ protected: /** Train the machine learning model */ virtual void TrainClassification(); /** Predict values using the model */ - virtual TargetSampleType PredictClassification(const InputSampleType& input) const; + virtual TargetSampleType PredictClassification(const InputSampleType& input, ConfidenceValueType *quality=NULL) const; private: BoostMachineLearningModel(const Self &); //purposely not implemented diff --git a/Modules/Learning/Supervised/include/otbBoostMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbBoostMachineLearningModel.txx index 825b119fc6d00f5c19d0c04ffb6c525521fc7640..493107f9ff2b8da6c652fc67bc31376664eed1cc 100644 --- a/Modules/Learning/Supervised/include/otbBoostMachineLearningModel.txx +++ b/Modules/Learning/Supervised/include/otbBoostMachineLearningModel.txx @@ -37,6 +37,7 @@ BoostMachineLearningModel<TInputValue,TOutputValue> m_SplitCrit(CvBoost::DEFAULT), m_MaxDepth(1) { + this->m_ConfidenceIndex = true; } @@ -76,7 +77,7 @@ template <class TInputValue, class TOutputValue> typename BoostMachineLearningModel<TInputValue,TOutputValue> ::TargetSampleType BoostMachineLearningModel<TInputValue,TOutputValue> -::PredictClassification(const InputSampleType & input) const +::PredictClassification(const InputSampleType & input, ConfidenceValueType *quality) const { //convert listsample to Mat cv::Mat sample; @@ -91,6 +92,12 @@ BoostMachineLearningModel<TInputValue,TOutputValue> target[0] = static_cast<TOutputValue>(result); + if (quality != NULL) + { + (*quality) = static_cast<ConfidenceValueType>( + m_BoostModel->predict(sample,missing,cv::Range::all(),false,true)); + } + return target; } diff --git a/Modules/Learning/Supervised/include/otbDecisionTreeMachineLearningModel.h b/Modules/Learning/Supervised/include/otbDecisionTreeMachineLearningModel.h index 75fe1401bf431cb9e827c314f1187e0e06456c0f..6415fa940e273e5d9074a130ba95a8d8105a4978 100644 --- a/Modules/Learning/Supervised/include/otbDecisionTreeMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbDecisionTreeMachineLearningModel.h @@ -40,19 +40,17 @@ public: typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; - // Input related typedefs - typedef TInputValue InputValueType; - typedef itk::VariableLengthVector<InputValueType> InputSampleType; - typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType; - - // Target related typedefs - typedef TTargetValue TargetValueType; - typedef itk::FixedArray<TargetValueType,1> TargetSampleType; - typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType; + typedef typename Superclass::InputValueType InputValueType; + typedef typename Superclass::InputSampleType InputSampleType; + typedef typename Superclass::InputListSampleType InputListSampleType; + typedef typename Superclass::TargetValueType TargetValueType; + typedef typename Superclass::TargetSampleType TargetSampleType; + typedef typename Superclass::TargetListSampleType TargetListSampleType; + typedef typename Superclass::ConfidenceValueType ConfidenceValueType; /** Run-time type information (and related methods). */ itkNewMacro(Self); - itkTypeMacro(DecisionTreeMachineLearningModel, itk::MachineLearningModel); + itkTypeMacro(DecisionTreeMachineLearningModel, MachineLearningModel); /** Setters/Getters to the maximum depth of the tree. * The maximum possible depth of the tree. That is the training algorithms attempts to split a node @@ -183,7 +181,7 @@ protected: /** Train the machine learning model */ virtual void TrainClassification(); /** Predict values using the model */ - virtual TargetSampleType PredictClassification(const InputSampleType& input) const; + virtual TargetSampleType PredictClassification(const InputSampleType& input, ConfidenceValueType *quality=NULL) const; private: DecisionTreeMachineLearningModel(const Self &); //purposely not implemented diff --git a/Modules/Learning/Supervised/include/otbDecisionTreeMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbDecisionTreeMachineLearningModel.txx index ef98de235bd7a01114d5bd07db45bb3f3e7a9adf..ada456e94c35ff1467e97b8301a5e4a315fa35f6 100644 --- a/Modules/Learning/Supervised/include/otbDecisionTreeMachineLearningModel.txx +++ b/Modules/Learning/Supervised/include/otbDecisionTreeMachineLearningModel.txx @@ -83,7 +83,7 @@ template <class TInputValue, class TOutputValue> typename DecisionTreeMachineLearningModel<TInputValue,TOutputValue> ::TargetSampleType DecisionTreeMachineLearningModel<TInputValue,TOutputValue> -::PredictClassification(const InputSampleType & input) const +::PredictClassification(const InputSampleType & input, ConfidenceValueType *quality) const { //convert listsample to Mat cv::Mat sample; @@ -96,6 +96,14 @@ DecisionTreeMachineLearningModel<TInputValue,TOutputValue> target[0] = static_cast<TOutputValue>(result); + if (quality != NULL) + { + if (!this->m_ConfidenceIndex) + { + itkExceptionMacro("Confidence index not available for this classifier !"); + } + } + return target; } diff --git a/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.h b/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.h index 00901e68d4b93f513c810c3e9d7fd86919e6fcf8..3126f44c0ff7073def41f025df615e74013b590b 100644 --- a/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.h @@ -40,19 +40,17 @@ public: typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; - // Input related typedefs - typedef TInputValue InputValueType; - typedef itk::VariableLengthVector<InputValueType> InputSampleType; - typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType; - - // Target related typedefs - typedef TTargetValue TargetValueType; - typedef itk::FixedArray<TargetValueType,1> TargetSampleType; - typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType; + typedef typename Superclass::InputValueType InputValueType; + typedef typename Superclass::InputSampleType InputSampleType; + typedef typename Superclass::InputListSampleType InputListSampleType; + typedef typename Superclass::TargetValueType TargetValueType; + typedef typename Superclass::TargetSampleType TargetSampleType; + typedef typename Superclass::TargetListSampleType TargetListSampleType; + typedef typename Superclass::ConfidenceValueType ConfidenceValueType; /** Run-time type information (and related methods). */ itkNewMacro(Self); - itkTypeMacro(GradientBoostedTreeMachineLearningModel, itk::MachineLearningModel); + itkTypeMacro(GradientBoostedTreeMachineLearningModel, MachineLearningModel); /** Type of the loss function used for training. * It must be one of the following types: CvGBTrees::SQUARED_LOSS, CvGBTrees::ABSOLUTE_LOSS, @@ -130,7 +128,7 @@ protected: /** Train the machine learning model */ virtual void TrainClassification(); /** Predict values using the model */ - virtual TargetSampleType PredictClassification(const InputSampleType& input) const; + virtual TargetSampleType PredictClassification(const InputSampleType& input, ConfidenceValueType *quality=NULL) const; private: GradientBoostedTreeMachineLearningModel(const Self &); //purposely not implemented diff --git a/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.txx index c7421a55b199bbe47b130d59a995508c4980746c..595f1bb060b8ffbcb533352915f41b69c2bd1458 100644 --- a/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.txx +++ b/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.txx @@ -80,7 +80,7 @@ template <class TInputValue, class TOutputValue> typename GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue> ::TargetSampleType GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue> -::PredictClassification(const InputSampleType & input) const +::PredictClassification(const InputSampleType & input, ConfidenceValueType *quality) const { //convert listsample to Mat cv::Mat sample; @@ -93,6 +93,14 @@ GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue> target[0] = static_cast<TOutputValue>(result); + if (quality != NULL) + { + if (!this->m_ConfidenceIndex) + { + itkExceptionMacro("Confidence index not available for this classifier !"); + } + } + return target; } diff --git a/Modules/Learning/Supervised/include/otbImageClassificationFilter.h b/Modules/Learning/Supervised/include/otbImageClassificationFilter.h index 2c79315b3d5ca1e127d3c7a0ead88688fe14c5e4..f6ed68aa1db6cec92aa37bdc870f0e19cbd59651 100644 --- a/Modules/Learning/Supervised/include/otbImageClassificationFilter.h +++ b/Modules/Learning/Supervised/include/otbImageClassificationFilter.h @@ -20,6 +20,7 @@ #include "itkImageToImageFilter.h" #include "otbMachineLearningModel.h" +#include "otbImage.h" namespace otb { @@ -68,6 +69,9 @@ public: typedef MachineLearningModel<ValueType, LabelType> ModelType; typedef typename ModelType::Pointer ModelPointerType; + typedef otb::Image<double> ConfidenceImageType; + typedef typename ConfidenceImageType::Pointer ConfidenceImagePointerType; + /** Set/Get the model */ itkSetObjectMacro(Model, ModelType); itkGetObjectMacro(Model, ModelType); @@ -76,6 +80,10 @@ public: itkSetMacro(DefaultLabel, LabelType); itkGetMacro(DefaultLabel, LabelType); + /** Set/Get the confidence map flag */ + itkSetMacro(UseConfidenceMap, bool); + itkGetMacro(UseConfidenceMap, bool); + /** * If set, only pixels within the mask will be classified. * All pixels with a value greater than 0 in the mask, will be classified. @@ -89,6 +97,11 @@ public: */ const MaskImageType * GetInputMask(void); + /** + * Get the output confidence map + */ + ConfidenceImageType * GetOutputConfidence(void); + protected: /** Constructor */ ImageClassificationFilter(); @@ -110,7 +123,8 @@ private: ModelPointerType m_Model; /** Default label for invalid pixels (when using a mask) */ LabelType m_DefaultLabel; - + /** Flag to produce the confidence map (if the model supports it) */ + bool m_UseConfidenceMap; }; } // End namespace otb #ifndef OTB_MANUAL_INSTANTIATION diff --git a/Modules/Learning/Supervised/include/otbImageClassificationFilter.txx b/Modules/Learning/Supervised/include/otbImageClassificationFilter.txx index 06225af19e37ce45feb5e6008121f8c06f0bd05a..88d2ad012441455cc2c2d221e79e175731b3e31f 100644 --- a/Modules/Learning/Supervised/include/otbImageClassificationFilter.txx +++ b/Modules/Learning/Supervised/include/otbImageClassificationFilter.txx @@ -34,6 +34,11 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> this->SetNumberOfIndexedInputs(2); this->SetNumberOfRequiredInputs(1); m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue(); + + this->SetNumberOfRequiredOutputs(2); + this->SetNthOutput(0,TOutputImage::New()); + this->SetNthOutput(1,ConfidenceImageType::New()); + m_UseConfidenceMap = false; } template <class TInputImage, class TOutputImage, class TMaskImage> @@ -57,6 +62,19 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> return static_cast<const MaskImageType *>(this->itk::ProcessObject::GetInput(1)); } +template <class TInputImage, class TOutputImage, class TMaskImage> +typename ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> +::ConfidenceImageType * +ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> +::GetOutputConfidence() +{ + if (this->GetNumberOfOutputs() < 2) + { + return 0; + } + return static_cast<ConfidenceImageType *>(this->itk::ProcessObject::GetOutput(1)); +} + template <class TInputImage, class TOutputImage, class TMaskImage> void ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> @@ -77,6 +95,7 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> InputImageConstPointerType inputPtr = this->GetInput(); MaskImageConstPointerType inputMaskPtr = this->GetInputMask(); OutputImagePointerType outputPtr = this->GetOutput(); + ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence(); // Progress reporting itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels()); @@ -85,6 +104,7 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType; typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType; typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType; + typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType; InputIteratorType inIt(inputPtr, outputRegionForThread); OutputIteratorType outIt(outputPtr, outputRegionForThread); @@ -97,7 +117,17 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> maskIt.GoToBegin(); } + // setup iterator for confidence map + bool computeConfidenceMap(m_UseConfidenceMap && m_Model->HasConfidenceIndex()); + ConfidenceMapIteratorType confidenceIt; + if (computeConfidenceMap) + { + confidenceIt = ConfidenceMapIteratorType(confidencePtr,outputRegionForThread); + confidenceIt.GoToBegin(); + } + bool validPoint = true; + double confidenceIndex = 0.0; // Walk the part of the image for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd(); ++inIt, ++outIt) @@ -112,12 +142,25 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> if (validPoint) { // Classifify - outIt.Set(m_Model->Predict(inIt.Get())[0]); + if (computeConfidenceMap) + { + outIt.Set(m_Model->Predict(inIt.Get(),&confidenceIndex)[0]); + } + else + { + outIt.Set(m_Model->Predict(inIt.Get())[0]); + } } else { // else, set default value outIt.Set(m_DefaultLabel); + confidenceIndex = 0.0; + } + if (computeConfidenceMap) + { + confidenceIt.Set(confidenceIndex); + ++confidenceIt; } progress.CompletedPixel(); } diff --git a/Modules/Learning/Supervised/include/otbKNearestNeighborsMachineLearningModel.h b/Modules/Learning/Supervised/include/otbKNearestNeighborsMachineLearningModel.h index d3548c9c4064a5a8388e12b35861d85ea9ba9174..b8524b7fb677aff35b706a54b87b6a917251754b 100644 --- a/Modules/Learning/Supervised/include/otbKNearestNeighborsMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbKNearestNeighborsMachineLearningModel.h @@ -39,19 +39,17 @@ public: typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; - // Input related typedefs - typedef TInputValue InputValueType; - typedef itk::VariableLengthVector<InputValueType> InputSampleType; - typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType; - - // Target related typedefs - typedef TTargetValue TargetValueType; - typedef itk::FixedArray<TargetValueType,1> TargetSampleType; - typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType; + typedef typename Superclass::InputValueType InputValueType; + typedef typename Superclass::InputSampleType InputSampleType; + typedef typename Superclass::InputListSampleType InputListSampleType; + typedef typename Superclass::TargetValueType TargetValueType; + typedef typename Superclass::TargetSampleType TargetSampleType; + typedef typename Superclass::TargetListSampleType TargetListSampleType; + typedef typename Superclass::ConfidenceValueType ConfidenceValueType; /** Run-time type information (and related methods). */ itkNewMacro(Self); - itkTypeMacro(KNearestNeighborsMachineLearningModel, itk::MachineLearningModel); + itkTypeMacro(KNearestNeighborsMachineLearningModel, MachineLearningModel); /** Setters/Getters to the number of neighbors to use * Default is 32 @@ -95,7 +93,7 @@ protected: /** Train the machine learning model */ virtual void TrainClassification(); /** Predict values using the model */ - virtual TargetSampleType PredictClassification(const InputSampleType& input) const; + virtual TargetSampleType PredictClassification(const InputSampleType& input, ConfidenceValueType *quality=NULL) const; private: KNearestNeighborsMachineLearningModel(const Self &); //purposely not implemented diff --git a/Modules/Learning/Supervised/include/otbKNearestNeighborsMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbKNearestNeighborsMachineLearningModel.txx index c1048f63604b87a48ef2c84fe58b6a6f347f1d9d..e890b9137ac284848eda3c10885bfd736a4c28ac 100644 --- a/Modules/Learning/Supervised/include/otbKNearestNeighborsMachineLearningModel.txx +++ b/Modules/Learning/Supervised/include/otbKNearestNeighborsMachineLearningModel.txx @@ -35,6 +35,7 @@ KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue> m_K(32), m_IsRegression(false) { + this->m_ConfidenceIndex = true; } @@ -66,18 +67,36 @@ template <class TInputValue, class TTargetValue> typename KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue> ::TargetSampleType KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue> -::PredictClassification(const InputSampleType & input) const +::PredictClassification(const InputSampleType & input, ConfidenceValueType *quality) const { //convert listsample to Mat cv::Mat sample; otb::SampleToMat<InputSampleType>(input, sample); - double result = m_KNearestModel->find_nearest(sample, m_K); + float result; + + if (quality != NULL) + { + cv::Mat nearest(1,m_K,CV_32FC1); + result = m_KNearestModel->find_nearest(sample, m_K,0,0,&nearest,0); + unsigned int accuracy = 0; + for (int k=0 ; k < m_K ; ++k) + { + if (nearest.at<float>(0,k) == result) + { + accuracy++; + } + } + (*quality) = static_cast<ConfidenceValueType>(accuracy); + } + else + { + result = m_KNearestModel->find_nearest(sample, m_K); + } TargetSampleType target; target[0] = static_cast<TTargetValue>(result); - return target; } diff --git a/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.h b/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.h index 5c408173bd8a22b38f48a9818d54c50c331703bc..18eba7fbb28d336967ac669d05cda51bc5ebdc4c 100644 --- a/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.h @@ -40,15 +40,13 @@ public: typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; - // Input related typedefs - typedef TInputValue InputValueType; - typedef itk::VariableLengthVector<InputValueType> InputSampleType; - typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType; - - // Target related typedefs - typedef TTargetValue TargetValueType; - typedef itk::FixedArray<TargetValueType,1> TargetSampleType; - typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType; + typedef typename Superclass::InputValueType InputValueType; + typedef typename Superclass::InputSampleType InputSampleType; + typedef typename Superclass::InputListSampleType InputListSampleType; + typedef typename Superclass::TargetValueType TargetValueType; + typedef typename Superclass::TargetSampleType TargetSampleType; + typedef typename Superclass::TargetListSampleType TargetListSampleType; + typedef typename Superclass::ConfidenceValueType ConfidenceValueType; // LibSVM related typedefs typedef otb::Functor::VariableLengthVectorToMeasurementVectorFunctor<InputSampleType> MeasurementVectorFunctorType; @@ -59,7 +57,7 @@ public: /** Run-time type information (and related methods). */ itkNewMacro(Self); - itkTypeMacro(SVMMachineLearningModel, itk::MachineLearningModel); + itkTypeMacro(SVMMachineLearningModel, MachineLearningModel); /** Save the model to file */ virtual void Save(const std::string &filename, const std::string & name=""); @@ -108,7 +106,7 @@ protected: /** Train the machine learning model */ virtual void TrainClassification(); /** Predict values using the model */ - virtual TargetSampleType PredictClassification(const InputSampleType& input) const; + virtual TargetSampleType PredictClassification(const InputSampleType& input, ConfidenceValueType *quality=NULL) const; private: LibSVMMachineLearningModel(const Self &); //purposely not implemented diff --git a/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.txx index 0cf15f4110ffaaa18d27443003ad0778a8de0fdc..24e91e54a25406d100a8cae8f4b09741f124e80f 100644 --- a/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.txx +++ b/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.txx @@ -74,13 +74,15 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue> m_SVMestimator->SetTrainingSampleList(this->GetTargetListSample()); m_SVMestimator->Update(); + + this->m_ConfidenceIndex = m_DoProbabilityEstimates; } template <class TInputValue, class TOutputValue> typename LibSVMMachineLearningModel<TInputValue,TOutputValue> ::TargetSampleType LibSVMMachineLearningModel<TInputValue,TOutputValue> -::PredictClassification(const InputSampleType & input) const +::PredictClassification(const InputSampleType & input, ConfidenceValueType *quality) const { TargetSampleType target; @@ -88,6 +90,31 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue> target = m_SVMestimator->GetModel()->EvaluateLabel(mfunctor(input)); + if (quality != NULL) + { + if (!this->m_ConfidenceIndex) + { + itkExceptionMacro("Confidence index not available for this classifier !"); + } + typename SVMEstimatorType::ModelType::ProbabilitiesVectorType probaVector = + m_SVMestimator->GetModel()->EvaluateProbabilities(mfunctor(input)); + double maxProb = 0.0; + double secProb = 0.0; + for (unsigned int i=0 ; i<probaVector.Size() ; ++i) + { + if (maxProb < probaVector[i]) + { + secProb = maxProb; + maxProb = probaVector[i]; + } + else if (secProb < probaVector[i]) + { + secProb = probaVector[i]; + } + } + (*quality) = static_cast<ConfidenceValueType>(maxProb - secProb); + } + return target; } @@ -105,6 +132,8 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue> ::Load(const std::string & filename, const std::string & itkNotUsed(name)) { m_SVMestimator->GetModel()->LoadModel(filename.c_str()); + + this->m_ConfidenceIndex = m_SVMestimator->GetModel()->HasProbabilities(); } template <class TInputValue, class TOutputValue> diff --git a/Modules/Learning/Supervised/include/otbMachineLearningModel.h b/Modules/Learning/Supervised/include/otbMachineLearningModel.h index 7d750ed6c4047e63f65cbf578b6d7d1f2828e514..8a6d55ce07652951779fd2f0ed05c09433182fb4 100644 --- a/Modules/Learning/Supervised/include/otbMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbMachineLearningModel.h @@ -60,7 +60,7 @@ namespace otb * * \ingroup OTBSupervised */ -template <class TInputValue, class TTargetValue> +template <class TInputValue, class TTargetValue, class TConfidenceValue = double > class ITK_EXPORT MachineLearningModel : public itk::Object { @@ -87,6 +87,9 @@ public: typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType; //@} + /**\name Confidence value typedef */ + typedef TConfidenceValue ConfidenceValueType; + /**\name Standard macros */ //@{ /** Run-time type information (and related methods). */ @@ -97,7 +100,7 @@ public: void Train(); /** Predict values using the model */ - TargetSampleType Predict(const InputSampleType& input) const; + TargetSampleType Predict(const InputSampleType& input, ConfidenceValueType *quality = NULL) const; /** Classify all samples in InputListSample and fill TargetListSample with the associated label */ void PredictAll(); @@ -120,6 +123,9 @@ public: virtual bool CanWriteFile(const std::string &) = 0; //@} + /** Query capacity to produce a confidence index */ + bool HasConfidenceIndex() const {return m_ConfidenceIndex;} + /**\name Input list of samples accessors */ //@{ itkSetObjectMacro(InputListSample,InputListSampleType); @@ -163,9 +169,12 @@ protected: (void)input; } - virtual TargetSampleType PredictClassification(const InputSampleType& input) const = 0; + virtual TargetSampleType PredictClassification(const InputSampleType& input, ConfidenceValueType *quality = NULL) const = 0; bool m_RegressionMode; + + /** flag that tells if the model support confidence index output */ + bool m_ConfidenceIndex; private: MachineLearningModel(const Self &); //purposely not implemented void operator =(const Self&); //purposely not implemented diff --git a/Modules/Learning/Supervised/include/otbMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbMachineLearningModel.txx index 5164fec3e2440e30f4c2a541d45cd7cadd261676..b4b6127c11c252c3d5a73d0a50716d48da6e7a49 100644 --- a/Modules/Learning/Supervised/include/otbMachineLearningModel.txx +++ b/Modules/Learning/Supervised/include/otbMachineLearningModel.txx @@ -23,20 +23,20 @@ namespace otb { -template <class TInputValue, class TOutputValue> -MachineLearningModel<TInputValue,TOutputValue> -::MachineLearningModel() : m_RegressionMode(false) +template <class TInputValue, class TOutputValue, class TConfidenceValue> +MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue> +::MachineLearningModel() : m_RegressionMode(false),m_ConfidenceIndex(false) {} -template <class TInputValue, class TOutputValue> -MachineLearningModel<TInputValue,TOutputValue> +template <class TInputValue, class TOutputValue, class TConfidenceValue> +MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue> ::~MachineLearningModel() {} -template <class TInputValue, class TOutputValue> +template <class TInputValue, class TOutputValue, class TConfidenceValue> void -MachineLearningModel<TInputValue,TOutputValue> +MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue> ::Train() { if(m_RegressionMode) @@ -45,20 +45,20 @@ MachineLearningModel<TInputValue,TOutputValue> return this->TrainClassification(); } -template <class TInputValue, class TOutputValue> -typename MachineLearningModel<TInputValue,TOutputValue>::TargetSampleType -MachineLearningModel<TInputValue,TOutputValue> -::Predict(const typename MachineLearningModel<TInputValue,TOutputValue>::InputSampleType& input) const +template <class TInputValue, class TOutputValue, class TConfidenceValue> +typename MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue>::TargetSampleType +MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue> +::Predict(const InputSampleType& input, ConfidenceValueType *quality) const { if(m_RegressionMode) return this->PredictRegression(input); else - return this->PredictClassification(input); + return this->PredictClassification(input,quality); } -template <class TInputValue, class TOutputValue> +template <class TInputValue, class TOutputValue, class TConfidenceValue> void -MachineLearningModel<TInputValue,TOutputValue> +MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue> ::PredictAll() { TargetListSampleType * targets = this->GetTargetListSample(); @@ -71,9 +71,9 @@ MachineLearningModel<TInputValue,TOutputValue> } } -template <class TInputValue, class TOutputValue> +template <class TInputValue, class TOutputValue, class TConfidenceValue> void -MachineLearningModel<TInputValue,TOutputValue> +MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue> ::PrintSelf(std::ostream& os, itk::Indent indent) const { // Call superclass implementation diff --git a/Modules/Learning/Supervised/include/otbNeuralNetworkMachineLearningModel.h b/Modules/Learning/Supervised/include/otbNeuralNetworkMachineLearningModel.h index 3890830682222f94b8421a0d44540a25a4cd4784..2eab72042d9ff5b3228df2d628ee0618127b01f4 100644 --- a/Modules/Learning/Supervised/include/otbNeuralNetworkMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbNeuralNetworkMachineLearningModel.h @@ -40,21 +40,19 @@ public: typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; - // Input related typedefs - typedef TInputValue InputValueType; - typedef itk::VariableLengthVector<InputValueType> InputSampleType; - typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType; - - // Target related typedefs - typedef TTargetValue TargetValueType; - typedef itk::FixedArray<TargetValueType,1> TargetSampleType; - typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType; + typedef typename Superclass::InputValueType InputValueType; + typedef typename Superclass::InputSampleType InputSampleType; + typedef typename Superclass::InputListSampleType InputListSampleType; + typedef typename Superclass::TargetValueType TargetValueType; + typedef typename Superclass::TargetSampleType TargetSampleType; + typedef typename Superclass::TargetListSampleType TargetListSampleType; + typedef typename Superclass::ConfidenceValueType ConfidenceValueType; typedef std::map<TargetValueType, unsigned int> MapOfLabelsType; /** Run-time type information (and related methods). */ itkNewMacro(Self); - itkTypeMacro(NeuralNetworkMachineLearningModel, itk::MachineLearningModel); + itkTypeMacro(NeuralNetworkMachineLearningModel, MachineLearningModel); /** Setters/Getters to the train method * 2 methods are available: @@ -183,7 +181,7 @@ protected: /** Train the machine learning model */ virtual void TrainClassification(); /** Predict values using the model */ - virtual TargetSampleType PredictClassification(const InputSampleType& input) const; + virtual TargetSampleType PredictClassification(const InputSampleType& input, ConfidenceValueType *quality=NULL) const; private: NeuralNetworkMachineLearningModel(const Self &); //purposely not implemented diff --git a/Modules/Learning/Supervised/include/otbNeuralNetworkMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbNeuralNetworkMachineLearningModel.txx index 95b4c1502fe457bc19d68b0def8f4fd43b98ac4c..01c31811d7c793630dfcfe5f0026cc6b7be73356 100644 --- a/Modules/Learning/Supervised/include/otbNeuralNetworkMachineLearningModel.txx +++ b/Modules/Learning/Supervised/include/otbNeuralNetworkMachineLearningModel.txx @@ -42,6 +42,7 @@ NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::NeuralNetworkMachi m_Epsilon(0.01), m_CvMatOfLabels(0) { + this->m_ConfidenceIndex = true; } template<class TInputValue, class TOutputValue> @@ -169,7 +170,7 @@ void NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::TrainClassifi template<class TInputValue, class TOutputValue> typename NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::TargetSampleType NeuralNetworkMachineLearningModel< - TInputValue, TOutputValue>::PredictClassification(const InputSampleType & input) const + TInputValue, TOutputValue>::PredictClassification(const InputSampleType & input, ConfidenceValueType *quality) const { //convert listsample to Mat cv::Mat sample; @@ -182,6 +183,7 @@ typename NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::TargetSam TargetSampleType target; float currentResponse = 0; float maxResponse = response.at<float> (0, 0); + float secondResponse = -1e10; target[0] = m_CvMatOfLabels->data.i[0]; unsigned int nbClasses = m_CvMatOfLabels->cols; @@ -190,9 +192,22 @@ typename NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::TargetSam currentResponse = response.at<float> (0, itLabel); if (currentResponse > maxResponse) { + secondResponse = maxResponse; maxResponse = currentResponse; target[0] = m_CvMatOfLabels->data.i[itLabel]; } + else + { + if (currentResponse > secondResponse) + { + secondResponse = currentResponse; + } + } + } + + if (quality != NULL) + { + (*quality) = static_cast<ConfidenceValueType>(maxResponse) - static_cast<ConfidenceValueType>(secondResponse); } return target; diff --git a/Modules/Learning/Supervised/include/otbNormalBayesMachineLearningModel.h b/Modules/Learning/Supervised/include/otbNormalBayesMachineLearningModel.h index 7650666fb8eea7315aadc2a96975f999d89e21cb..31b5373ce16510512aa90065e151408a764898ed 100644 --- a/Modules/Learning/Supervised/include/otbNormalBayesMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbNormalBayesMachineLearningModel.h @@ -40,19 +40,17 @@ public: typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; - // Input related typedefs - typedef TInputValue InputValueType; - typedef itk::VariableLengthVector<InputValueType> InputSampleType; - typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType; - - // Target related typedefs - typedef TTargetValue TargetValueType; - typedef itk::FixedArray<TargetValueType,1> TargetSampleType; - typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType; + typedef typename Superclass::InputValueType InputValueType; + typedef typename Superclass::InputSampleType InputSampleType; + typedef typename Superclass::InputListSampleType InputListSampleType; + typedef typename Superclass::TargetValueType TargetValueType; + typedef typename Superclass::TargetSampleType TargetSampleType; + typedef typename Superclass::TargetListSampleType TargetListSampleType; + typedef typename Superclass::ConfidenceValueType ConfidenceValueType; /** Run-time type information (and related methods). */ itkNewMacro(Self); - itkTypeMacro(NormalBayesMachineLearningModel, itk::MachineLearningModel); + itkTypeMacro(NormalBayesMachineLearningModel, MachineLearningModel); /** Save the model to file */ virtual void Save(const std::string & filename, const std::string & name=""); @@ -82,7 +80,7 @@ protected: /** Train the machine learning model */ virtual void TrainClassification(); /** Predict values using the model */ - virtual TargetSampleType PredictClassification(const InputSampleType& input) const; + virtual TargetSampleType PredictClassification(const InputSampleType& input, ConfidenceValueType *quality=NULL) const; private: NormalBayesMachineLearningModel(const Self &); //purposely not implemented diff --git a/Modules/Learning/Supervised/include/otbNormalBayesMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbNormalBayesMachineLearningModel.txx index 1e9d8f17216531d0c6a3916ebad8ef17c706e148..bfc163553ddb58eb7e3e9f2150fc9cb63dbfaca2 100644 --- a/Modules/Learning/Supervised/include/otbNormalBayesMachineLearningModel.txx +++ b/Modules/Learning/Supervised/include/otbNormalBayesMachineLearningModel.txx @@ -61,7 +61,7 @@ template <class TInputValue, class TOutputValue> typename NormalBayesMachineLearningModel<TInputValue,TOutputValue> ::TargetSampleType NormalBayesMachineLearningModel<TInputValue,TOutputValue> -::PredictClassification(const InputSampleType & input) const +::PredictClassification(const InputSampleType & input, ConfidenceValueType *quality) const { //convert listsample to Mat cv::Mat sample; @@ -76,6 +76,14 @@ NormalBayesMachineLearningModel<TInputValue,TOutputValue> target[0] = static_cast<TOutputValue>(result); + if (quality != NULL) + { + if (!this->HasConfidenceIndex()) + { + itkExceptionMacro("Confidence index not available for this classifier !"); + } + } + return target; } diff --git a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h index b132ffefabf05a2af1e2b3b9fbc0b8d6da23b65e..20aff52a7e200f47f9127005addc442d25fd5799 100644 --- a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h @@ -40,15 +40,13 @@ public: typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; - // Input related typedefs - typedef TInputValue InputValueType; - typedef itk::VariableLengthVector<InputValueType> InputSampleType; - typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType; - - // Target related typedefs - typedef TTargetValue TargetValueType; - typedef itk::FixedArray<TargetValueType,1> TargetSampleType; - typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType; + typedef typename Superclass::InputValueType InputValueType; + typedef typename Superclass::InputSampleType InputSampleType; + typedef typename Superclass::InputListSampleType InputListSampleType; + typedef typename Superclass::TargetValueType TargetValueType; + typedef typename Superclass::TargetSampleType TargetSampleType; + typedef typename Superclass::TargetListSampleType TargetListSampleType; + typedef typename Superclass::ConfidenceValueType ConfidenceValueType; // Other typedef itk::VariableSizeMatrix<float> VariableImportanceMatrixType; @@ -59,7 +57,7 @@ public: /** Run-time type information (and related methods). */ itkNewMacro(Self); - itkTypeMacro(RandomForestsMachineLearningModel, itk::MachineLearningModel); + itkTypeMacro(RandomForestsMachineLearningModel, MachineLearningModel); /** Save the model to file */ virtual void Save(const std::string & filename, const std::string & name=""); @@ -144,7 +142,7 @@ protected: /** Train the machine learning model */ virtual void TrainClassification(); /** Predict values using the model */ - virtual TargetSampleType PredictClassification(const InputSampleType& input) const; + virtual TargetSampleType PredictClassification(const InputSampleType& input, ConfidenceValueType *quality=NULL) const; private: RandomForestsMachineLearningModel(const Self &); //purposely not implemented diff --git a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx index 9c06e330a25690125366f813789c519fadaffc97..0797c2fabf18077612470b70604d8a3492a64446 100644 --- a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx +++ b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.txx @@ -42,6 +42,7 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS), m_RegressionMode(false) { + this->m_ConfidenceIndex = true; } @@ -109,7 +110,7 @@ template <class TInputValue, class TOutputValue> typename RandomForestsMachineLearningModel<TInputValue,TOutputValue> ::TargetSampleType RandomForestsMachineLearningModel<TInputValue,TOutputValue> -::PredictClassification(const InputSampleType & value) const +::PredictClassification(const InputSampleType & value, ConfidenceValueType *quality) const { //convert listsample to Mat cv::Mat sample; @@ -122,6 +123,11 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> target[0] = static_cast<TOutputValue>(result); + if (quality != NULL) + { + (*quality) = m_RFModel->predict_prob(sample); + } + return target[0]; } diff --git a/Modules/Learning/Supervised/include/otbSVMMachineLearningModel.h b/Modules/Learning/Supervised/include/otbSVMMachineLearningModel.h index 4ff5ce670315e59789076d03360fc9d202ff2859..406bdd10e6b04a808fbd2be168e0f7504aa79578 100644 --- a/Modules/Learning/Supervised/include/otbSVMMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbSVMMachineLearningModel.h @@ -39,15 +39,13 @@ public: typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; - // Input related typedefs - typedef TInputValue InputValueType; - typedef itk::VariableLengthVector<InputValueType> InputSampleType; - typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType; - - // Target related typedefs - typedef TTargetValue TargetValueType; - typedef itk::FixedArray<TargetValueType,1> TargetSampleType; - typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType; + typedef typename Superclass::InputValueType InputValueType; + typedef typename Superclass::InputSampleType InputSampleType; + typedef typename Superclass::InputListSampleType InputListSampleType; + typedef typename Superclass::TargetValueType TargetValueType; + typedef typename Superclass::TargetSampleType TargetSampleType; + typedef typename Superclass::TargetListSampleType TargetListSampleType; + typedef typename Superclass::ConfidenceValueType ConfidenceValueType; /** Run-time type information (and related methods). */ itkNewMacro(Self); @@ -131,7 +129,7 @@ protected: /** Train the machine learning model */ virtual void TrainClassification(); /** Predict values using the model */ - virtual TargetSampleType PredictClassification(const InputSampleType& input) const; + virtual TargetSampleType PredictClassification(const InputSampleType& input, ConfidenceValueType *quality=NULL) const; private: SVMMachineLearningModel(const Self &); //purposely not implemented diff --git a/Modules/Learning/Supervised/include/otbSVMMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbSVMMachineLearningModel.txx index beca36102ae5bb865bed47514e250bd1637d7ae6..064db060d1670146d4b618d94fa529017475758e 100644 --- a/Modules/Learning/Supervised/include/otbSVMMachineLearningModel.txx +++ b/Modules/Learning/Supervised/include/otbSVMMachineLearningModel.txx @@ -49,6 +49,7 @@ SVMMachineLearningModel<TInputValue,TOutputValue> m_OutputNu(0), m_OutputP(0) { + this->m_ConfidenceIndex = true; } @@ -108,7 +109,7 @@ template <class TInputValue, class TOutputValue> typename SVMMachineLearningModel<TInputValue,TOutputValue> ::TargetSampleType SVMMachineLearningModel<TInputValue,TOutputValue> -::PredictClassification(const InputSampleType & input) const +::PredictClassification(const InputSampleType & input, ConfidenceValueType *quality) const { //convert listsample to Mat cv::Mat sample; @@ -121,6 +122,11 @@ SVMMachineLearningModel<TInputValue,TOutputValue> target[0] = static_cast<TOutputValue>(result); + if (quality != NULL) + { + (*quality) = m_SVMModel->predict(sample,true); + } + return target; }