diff --git a/Modules/Applications/AppClassification/app/otbImageClassifier.cxx b/Modules/Applications/AppClassification/app/otbImageClassifier.cxx index e05dcc1964584268ebe1b6318b907cd6188462db..49bae7fd69dcbb466437cef0c0f86c2bc1b9f4bf 100644 --- a/Modules/Applications/AppClassification/app/otbImageClassifier.cxx +++ b/Modules/Applications/AppClassification/app/otbImageClassifier.cxx @@ -64,6 +64,7 @@ public: typedef ClassificationFilterType::LabelType LabelType; typedef otb::MachineLearningModelFactory<ValueType, LabelType> MachineLearningModelFactoryType; typedef ClassificationFilterType::ConfidenceImageType ConfidenceImageType; + typedef ClassificationFilterType::ProbaImageType ProbaImageType; protected: @@ -111,6 +112,7 @@ private: SetDefaultParameterInt("nodatalabel", 0); MandatoryOff("nodatalabel"); + AddParameter(ParameterType_OutputImage, "out", "Output Image"); SetParameterDescription( "out", "Output image containing class labels"); SetDefaultOutputPixelType( "out", ImagePixelType_uint8); @@ -129,8 +131,16 @@ private: SetDefaultOutputPixelType( "confmap", ImagePixelType_double); MandatoryOff("confmap"); + AddParameter(ParameterType_OutputImage,"probamap", "Probability map"); + SetParameterDescription("probamap","Probability of each class for each pixel. This is an image having a number of bands equal to the number of classes in the model. This is only implemented for the Shark Random Forest classifier at this point."); + SetDefaultOutputPixelType("probamap",ImagePixelType_uint16); + MandatoryOff("probamap"); AddRAMParameter(); + AddParameter(ParameterType_Int, "nbclasses", "Number of classes in the model"); + SetDefaultParameterInt("nbclasses", 20); + SetParameterDescription("nbclasses","The number of classes is needed for the probamap output in order to set the number of output bands."); + // Doc example parameter settings SetDocExampleParameterValue("in", "QB_1_ortho.tif"); SetDocExampleParameterValue("imstat", "EstimateImageStatisticsQB1.xml"); @@ -173,7 +183,7 @@ private: // Classify m_ClassificationFilter = ClassificationFilterType::New(); m_ClassificationFilter->SetModel(m_Model); - + m_ClassificationFilter->SetDefaultLabel(GetParameterInt("nodatalabel")); // Normalize input image if asked @@ -208,9 +218,9 @@ private: m_ClassificationFilter->SetInputMask(inMask); } - SetParameterOutputImage<OutputImageType>("out", m_ClassificationFilter->GetOutput()); - + + // output confidence map if (IsParameterEnabled("confmap") && HasValue("confmap")) { @@ -225,6 +235,21 @@ private: this->DisableParameter("confmap"); } } + if(IsParameterEnabled("probamap") && HasValue("probamap")) + { + m_ClassificationFilter->SetUseProbaMap(true); + if(m_Model->HasProbaIndex()) + { + m_ClassificationFilter->SetNumberOfClasses(GetParameterInt("nbclasses")); + SetParameterOutputImage<ProbaImageType>("probamap",m_ClassificationFilter->GetOutputProba()); + } + else + { + otbAppLogWARNING("Probability map requested but the classifier doesn't support it!"); + this->DisableParameter("probamap"); + } + } + } ClassificationFilterType::Pointer m_ClassificationFilter; diff --git a/Modules/Applications/AppClassification/test/CMakeLists.txt b/Modules/Applications/AppClassification/test/CMakeLists.txt index 46f4e82dcf3142883cd94d7508458647dd02c47a..13db483c00d9fe033a0b7b7f531abdb1df5a6fe7 100644 --- a/Modules/Applications/AppClassification/test/CMakeLists.txt +++ b/Modules/Applications/AppClassification/test/CMakeLists.txt @@ -115,6 +115,7 @@ set(sharkkm_parameters "") set(ascii_comparison --compare-ascii ${EPSILON_6}) set(raster_comparison --compare-image ${NOTOL}) set(raster_comparison_two --compare-n-images ${NOTOL} 2) +set(raster_comparison_three --compare-n-images ${NOTOL} 3) # Reference ffiles depending on modes set(ascii_ref_path ${OTBAPP_BASELINE_FILES}) @@ -136,6 +137,7 @@ if(OTB_USE_SHARK) endif() set(classifier_with_confmap "LIBSVM" "BOOST" "KNN" "ANN" "RF") +set(classifier_with_probamap "SHARKRF") # This is a black list for classifier that can not have a baseline # because they are using randomness and seed can not be set @@ -156,6 +158,7 @@ foreach(classifier ${classifierList}) set(OUTRASTER cl${classifier}LabeledImageQB1${raster_output_format}) set(OUTCONFMAP cl${classifier}ConfidenceMapQB1${raster_output_format}) + set(OUTPROBAMAP cl${classifier}ProbabilityMapQB1${raster_output_format}) list(FIND classifier_without_baseline ${classifier} _classifier_has_baseline) if(${_classifier_has_baseline} EQUAL -1) @@ -198,6 +201,7 @@ foreach(classifier ${classifierList}) set_tests_properties(apTvClTrainMethod${classifier}ImagesClassifierQB1_InXML1 PROPERTIES DEPENDS apTvClTrainMethod${classifier}ImagesClassifierQB1) list(FIND classifier_with_confmap ${classifier} _classifier_has_confmap) + list(FIND classifier_with_probamap ${classifier} _classifier_has_probamap) if(${_classifier_has_confmap} EQUAL -1) otb_test_application( NAME apTvClMethod${classifier}ImageClassifierQB1 @@ -212,21 +216,44 @@ foreach(classifier ${classifierList}) ${TEMP}/${OUTRASTER} ) else() - otb_test_application( - NAME apTvClMethod${classifier}ImageClassifierQB1 - APP ImageClassifier - OPTIONS -in ${INPUTDATA}/Classification/QB_1_ortho${raster_input_format} - -model ${INPUTDATA}/Classification/${OUTMODELFILE} - -imstat ${INPUTDATA}/Classification/clImageStatisticsQB1${stat_input_format} - -out ${TEMP}/${OUTRASTER} ${raster_output_option} - -confmap ${TEMP}/${OUTCONFMAP} - - VALID ${raster_comparison_two} - ${raster_ref_path}/${OUTRASTER} - ${TEMP}/${OUTRASTER} - ${raster_ref_path}/${OUTCONFMAP} - ${TEMP}/${OUTCONFMAP} - ) + if(${_classifier_has_probamap} EQUAL -1) + otb_test_application( + NAME apTvClMethod${classifier}ImageClassifierQB1 + APP ImageClassifier + OPTIONS -in ${INPUTDATA}/Classification/QB_1_ortho${raster_input_format} + -model ${INPUTDATA}/Classification/${OUTMODELFILE} + -imstat ${INPUTDATA}/Classification/clImageStatisticsQB1${stat_input_format} + -out ${TEMP}/${OUTRASTER} ${raster_output_option} + -confmap ${TEMP}/${OUTCONFMAP} + + VALID ${raster_comparison_two} + ${raster_ref_path}/${OUTRASTER} + ${TEMP}/${OUTRASTER} + ${raster_ref_path}/${OUTCONFMAP} + ${TEMP}/${OUTCONFMAP} + ) + else() + message(${classifier}) + otb_test_application( + NAME apTvClMethod${classifier}ImageClassifierQB1 + APP ImageClassifier + OPTIONS -in ${INPUTDATA}/Classification/QB_1_ortho${raster_input_format} + -model ${INPUTDATA}/Classification/${OUTMODELFILE} + -imstat ${INPUTDATA}/Classification/clImageStatisticsQB1${stat_input_format} + -out ${TEMP}/${OUTRASTER} ${raster_output_option} + -confmap ${TEMP}/${OUTCONFMAP} + -nbclasses 4 + -probamap ${TEMP}/${OUTPROBAMAP} + + VALID ${raster_comparison_three} + ${raster_ref_path}/${OUTRASTER} + ${TEMP}/${OUTRASTER} + ${raster_ref_path}/${OUTCONFMAP} + ${TEMP}/${OUTCONFMAP} + ${raster_ref_path}/${OUTPROBAMAP} + ${TEMP}/${OUTPROBAMAP} + ) + endif() endif() endforeach() diff --git a/Modules/Learning/DimensionalityReductionLearning/include/otbAutoencoderModel.h b/Modules/Learning/DimensionalityReductionLearning/include/otbAutoencoderModel.h index 71ed7482073a8090e7ef854436707cca60a88528..1c133864de36de022b2b1a1104cdb22cef7de949 100644 --- a/Modules/Learning/DimensionalityReductionLearning/include/otbAutoencoderModel.h +++ b/Modules/Learning/DimensionalityReductionLearning/include/otbAutoencoderModel.h @@ -84,6 +84,8 @@ public: typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType; typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType; + typedef typename Superclass::ProbaSampleType ProbaSampleType; + typedef typename Superclass::ProbaListSampleType ProbaListSampleType; /// Neural network related typedefs typedef shark::ConcatenatedModel<shark::RealVector> ModelType; typedef shark::LinearModel<shark::RealVector,NeuronType> LayerType; @@ -162,14 +164,16 @@ protected: virtual TargetSampleType DoPredict( const InputSampleType& input, - ConfidenceValueType * quality = nullptr) const override; + ConfidenceValueType * quality = nullptr, + ProbaSampleType * proba = nullptr) const override; virtual void DoPredictBatch( const InputListSampleType *, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType *, - ConfidenceListSampleType * quality = nullptr) const override; + ConfidenceListSampleType * quality = nullptr, + ProbaListSampleType * proba = nullptr) const override; private: /** Internal Network */ diff --git a/Modules/Learning/DimensionalityReductionLearning/include/otbAutoencoderModel.hxx b/Modules/Learning/DimensionalityReductionLearning/include/otbAutoencoderModel.hxx index 636f62cfd22ec339da109a2a7a7fea75832481bf..695f7e29e97502d96bd5a8cee8440913941fb1cb 100644 --- a/Modules/Learning/DimensionalityReductionLearning/include/otbAutoencoderModel.hxx +++ b/Modules/Learning/DimensionalityReductionLearning/include/otbAutoencoderModel.hxx @@ -375,7 +375,7 @@ AutoencoderModel<TInputValue,NeuronType> template <class TInputValue, class NeuronType> typename AutoencoderModel<TInputValue,NeuronType>::TargetSampleType AutoencoderModel<TInputValue,NeuronType> -::DoPredict(const InputSampleType & value, ConfidenceValueType * /*quality*/) const +::DoPredict(const InputSampleType & value, ConfidenceValueType * /*quality*/, ProbaSampleType * /*proba*/) const { shark::RealVector samples(value.Size()); for(size_t i = 0; i < value.Size();i++) @@ -408,7 +408,8 @@ AutoencoderModel<TInputValue,NeuronType> const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * targets, - ConfidenceListSampleType * /*quality*/) const + ConfidenceListSampleType * /*quality*/, + ProbaListSampleType * /*proba*/) const { std::vector<shark::RealVector> features; Shark::ListSampleRangeToSharkVector(input, features,startIndex,size); diff --git a/Modules/Learning/DimensionalityReductionLearning/include/otbPCAModel.h b/Modules/Learning/DimensionalityReductionLearning/include/otbPCAModel.h index c1aa97b938b350824da50b0d2886537fed3828f4..cc3971fe001e6cda8f8c84158e2b82ecad546394 100644 --- a/Modules/Learning/DimensionalityReductionLearning/include/otbPCAModel.h +++ b/Modules/Learning/DimensionalityReductionLearning/include/otbPCAModel.h @@ -79,6 +79,9 @@ public: typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType; typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType; + typedef typename Superclass::ProbaSampleType ProbaSampleType; + typedef typename Superclass::ProbaListSampleType ProbaListSampleType; + itkNewMacro(Self); itkTypeMacro(PCAModel, DimensionalityReductionModel); @@ -101,14 +104,16 @@ protected: virtual TargetSampleType DoPredict( const InputSampleType& input, - ConfidenceValueType * quality = nullptr) const override; + ConfidenceValueType * quality = nullptr, + ProbaSampleType * proba = nullptr) const override; virtual void DoPredictBatch( const InputListSampleType *, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType *, - ConfidenceListSampleType * quality = nullptr) const override; + ConfidenceListSampleType * quality = nullptr, + ProbaListSampleType* proba = nullptr) const override; private: shark::LinearModel<> m_Encoder; diff --git a/Modules/Learning/DimensionalityReductionLearning/include/otbPCAModel.hxx b/Modules/Learning/DimensionalityReductionLearning/include/otbPCAModel.hxx index 85601662729a23d931b2b5d841494b7633b8728a..9c6973367242ec77ecad2240a215dabaee36369f 100644 --- a/Modules/Learning/DimensionalityReductionLearning/include/otbPCAModel.hxx +++ b/Modules/Learning/DimensionalityReductionLearning/include/otbPCAModel.hxx @@ -148,7 +148,7 @@ PCAModel<TInputValue>::Load(const std::string & filename, const std::string & /* template <class TInputValue> typename PCAModel<TInputValue>::TargetSampleType -PCAModel<TInputValue>::DoPredict(const InputSampleType & value, ConfidenceValueType * /*quality*/) const +PCAModel<TInputValue>::DoPredict(const InputSampleType & value, ConfidenceValueType * /*quality*/, ProbaSampleType * /*proba*/) const { shark::RealVector samples(value.Size()); for(size_t i = 0; i < value.Size();i++) @@ -173,7 +173,7 @@ PCAModel<TInputValue>::DoPredict(const InputSampleType & value, ConfidenceValueT template <class TInputValue> void PCAModel<TInputValue> -::DoPredictBatch(const InputListSampleType *input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * targets, ConfidenceListSampleType * /*quality*/) const +::DoPredictBatch(const InputListSampleType *input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * targets, ConfidenceListSampleType * /*quality*/,ProbaListSampleType * /*proba*/) const { std::vector<shark::RealVector> features; Shark::ListSampleRangeToSharkVector(input, features,startIndex,size); diff --git a/Modules/Learning/DimensionalityReductionLearning/include/otbSOMModel.h b/Modules/Learning/DimensionalityReductionLearning/include/otbSOMModel.h index 11c5fe592a7b56ca04a6ca81c55edfdfb4ec0f7e..304fdac23d293d8e7d3275ab9f76e9e921875955 100644 --- a/Modules/Learning/DimensionalityReductionLearning/include/otbSOMModel.h +++ b/Modules/Learning/DimensionalityReductionLearning/include/otbSOMModel.h @@ -64,7 +64,8 @@ public: typedef typename Superclass::ConfidenceValueType ConfidenceValueType; typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType; typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType; - + typedef typename Superclass::ProbaSampleType ProbaSampleType; + typedef typename Superclass::ProbaListSampleType ProbaListSampleType; typedef SOMMap< itk::VariableLengthVector<TInputValue>, itk::Statistics::EuclideanDistanceMetric< @@ -118,7 +119,8 @@ private: virtual TargetSampleType DoPredict( const InputSampleType& input, - ConfidenceValueType * quality = nullptr) const override; + ConfidenceValueType * quality = nullptr, + ProbaSampleType * proba = nullptr) const override; /** Map size (width, height) */ SizeType m_MapSize; diff --git a/Modules/Learning/DimensionalityReductionLearning/include/otbSOMModel.hxx b/Modules/Learning/DimensionalityReductionLearning/include/otbSOMModel.hxx index 2c2abca261bb0533745eaf55fad822f9d9397f03..a2f93184f4ae05f42fe8205355ec0863121a9ab7 100644 --- a/Modules/Learning/DimensionalityReductionLearning/include/otbSOMModel.hxx +++ b/Modules/Learning/DimensionalityReductionLearning/include/otbSOMModel.hxx @@ -212,7 +212,8 @@ template <class TInputValue, unsigned int MapDimension> typename SOMModel<TInputValue, MapDimension>::TargetSampleType SOMModel<TInputValue, MapDimension>::DoPredict( const InputSampleType & value, - ConfidenceValueType * /*quality*/) const + ConfidenceValueType * /*quality*/, + ProbaSampleType * /*proba*/) const { TargetSampleType target; target.SetSize(this->m_Dimension); diff --git a/Modules/Learning/LearningBase/include/otbImageClassificationFilter.h b/Modules/Learning/LearningBase/include/otbImageClassificationFilter.h index 4609ce047de79e67fcb7a5e09a1174579ba172a0..2ff458e7d4ff5935631b676d00d4142a2756d76b 100644 --- a/Modules/Learning/LearningBase/include/otbImageClassificationFilter.h +++ b/Modules/Learning/LearningBase/include/otbImageClassificationFilter.h @@ -24,6 +24,7 @@ #include "itkImageToImageFilter.h" #include "otbMachineLearningModel.h" #include "otbImage.h" +#include "otbVectorImage.h" namespace otb { @@ -75,6 +76,11 @@ public: typedef otb::Image<double> ConfidenceImageType; typedef typename ConfidenceImageType::Pointer ConfidenceImagePointerType; + /**Output type for Proba */ + typedef otb::VectorImage<double> ProbaImageType; + + typedef typename ProbaImageType::Pointer ProbaImagePointerType; + typedef itk::VariableLengthVector<double> ProbaSampleType; /** Set/Get the model */ itkSetObjectMacro(Model, ModelType); itkGetObjectMacro(Model, ModelType); @@ -87,10 +93,16 @@ public: itkSetMacro(UseConfidenceMap, bool); itkGetMacro(UseConfidenceMap, bool); + /** Set/Get the proba map flag */ + itkSetMacro(UseProbaMap, bool); + itkGetMacro(UseProbaMap, bool); + itkSetMacro(BatchMode, bool); itkGetMacro(BatchMode, bool); itkBooleanMacro(BatchMode); + itkSetMacro(NumberOfClasses, unsigned int); + itkGetMacro(NumberOfClasses, unsigned int); /** * If set, only pixels within the mask will be classified. * All pixels with a value greater than 0 in the mask, will be classified. @@ -108,7 +120,7 @@ public: * Get the output confidence map */ ConfidenceImageType * GetOutputConfidence(void); - + ProbaImageType * GetOutputProba(void); protected: /** Constructor */ ImageClassificationFilter(); @@ -124,6 +136,12 @@ protected: /**PrintSelf method */ void PrintSelf(std::ostream& os, itk::Indent indent) const override; + void GenerateOutputInformation() override + { + Superclass::GenerateOutputInformation(); + // Define the number of output bands + this->GetOutputProba()->SetNumberOfComponentsPerPixel(m_NumberOfClasses); + } private: ImageClassificationFilter(const Self &) = delete; void operator =(const Self&) = delete; @@ -134,7 +152,9 @@ private: LabelType m_DefaultLabel; /** Flag to produce the confidence map (if the model supports it) */ bool m_UseConfidenceMap; + bool m_UseProbaMap; bool m_BatchMode; + unsigned int m_NumberOfClasses; }; } // End namespace otb #ifndef OTB_MANUAL_INSTANTIATION diff --git a/Modules/Learning/LearningBase/include/otbImageClassificationFilter.hxx b/Modules/Learning/LearningBase/include/otbImageClassificationFilter.hxx index fda4ab00349481e59c459257999737dd6b744b27..4aea60a2dade1cfa0a281d19573e8a6084da0918 100644 --- a/Modules/Learning/LearningBase/include/otbImageClassificationFilter.hxx +++ b/Modules/Learning/LearningBase/include/otbImageClassificationFilter.hxx @@ -38,11 +38,14 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> this->SetNumberOfRequiredInputs(1); m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue(); - this->SetNumberOfRequiredOutputs(2); + this->SetNumberOfRequiredOutputs(3); this->SetNthOutput(0,TOutputImage::New()); this->SetNthOutput(1,ConfidenceImageType::New()); + this->SetNthOutput(2,ProbaImageType::New()); m_UseConfidenceMap = false; + m_UseProbaMap = false; m_BatchMode = true; + m_NumberOfClasses = 1; } template <class TInputImage, class TOutputImage, class TMaskImage> @@ -79,6 +82,19 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> return static_cast<ConfidenceImageType *>(this->itk::ProcessObject::GetOutput(1)); } +template <class TInputImage, class TOutputImage, class TMaskImage> +typename ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> +::ProbaImageType * +ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> +::GetOutputProba() +{ + if (this->GetNumberOfOutputs() < 2) + { + return nullptr; + } + return static_cast<ProbaImageType *>(this->itk::ProcessObject::GetOutput(2)); +} + template <class TInputImage, class TOutputImage, class TMaskImage> void ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> @@ -100,22 +116,25 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> template <class TInputImage, class TOutputImage, class TMaskImage> void ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> -::ClassicThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId) +::ClassicThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, + itk::ThreadIdType threadId) { // Get the input pointers InputImageConstPointerType inputPtr = this->GetInput(); MaskImageConstPointerType inputMaskPtr = this->GetInputMask(); OutputImagePointerType outputPtr = this->GetOutput(); ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence(); - + ProbaImagePointerType probaPtr = this->GetOutputProba(); // Progress reporting - itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels()); + itk::ProgressReporter progress(this, threadId, + outputRegionForThread.GetNumberOfPixels()); // Define iterators typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType; typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType; typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType; typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType; + typedef itk::ImageRegionIterator<ProbaImageType> ProbaMapIteratorType; InputIteratorType inIt(inputPtr, outputRegionForThread); OutputIteratorType outIt(outputPtr, outputRegionForThread); @@ -129,7 +148,8 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> } // setup iterator for confidence map - bool computeConfidenceMap(m_UseConfidenceMap && m_Model->HasConfidenceIndex() && !m_Model->GetRegressionMode()); + bool computeConfidenceMap(m_UseConfidenceMap && m_Model->HasConfidenceIndex() && + !m_Model->GetRegressionMode()); ConfidenceMapIteratorType confidenceIt; if (computeConfidenceMap) { @@ -137,11 +157,25 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> confidenceIt.GoToBegin(); } + // setup iterator for proba map + bool computeProbaMap(m_UseProbaMap && m_Model->HasProbaIndex() && + !m_Model->GetRegressionMode()); + + ProbaMapIteratorType probaIt; + + if(computeProbaMap) + { + probaIt = ProbaMapIteratorType(probaPtr,outputRegionForThread); + probaIt.GoToBegin(); + } + bool validPoint = true; double confidenceIndex = 0.0; - + ProbaSampleType probaVector{m_NumberOfClasses}; + probaVector.Fill(0); // Walk the part of the image - for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd(); ++inIt, ++outIt) + for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd(); + ++inIt, ++outIt) { // Check pixel validity if (inputMaskPtr) @@ -151,17 +185,22 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> } // If point is valid if (validPoint) - { + { // Classifify - if (computeConfidenceMap) - { + if (computeProbaMap) + { + outIt.Set(m_Model->Predict(inIt.Get(),&confidenceIndex, + &probaVector)[0]); + } + else if (computeConfidenceMap) + { outIt.Set(m_Model->Predict(inIt.Get(),&confidenceIndex)[0]); - } + } else - { + { outIt.Set(m_Model->Predict(inIt.Get())[0]); - } } + } else { // else, set default value @@ -173,6 +212,11 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> confidenceIt.Set(confidenceIndex); ++confidenceIt; } + if (computeProbaMap) + { + probaIt.Set(probaVector); + ++probaIt; + } progress.CompletedPixel(); } @@ -181,24 +225,31 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> template <class TInputImage, class TOutputImage, class TMaskImage> void ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> -::BatchThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId) +::BatchThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, + itk::ThreadIdType threadId) { - bool computeConfidenceMap(m_UseConfidenceMap && m_Model->HasConfidenceIndex() + bool computeConfidenceMap(m_UseConfidenceMap && m_Model->HasConfidenceIndex() && !m_Model->GetRegressionMode()); + + bool computeProbaMap(m_UseProbaMap && m_Model->HasProbaIndex() + && !m_Model->GetRegressionMode()); // Get the input pointers InputImageConstPointerType inputPtr = this->GetInput(); MaskImageConstPointerType inputMaskPtr = this->GetInputMask(); OutputImagePointerType outputPtr = this->GetOutput(); ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence(); - + ProbaImagePointerType probaPtr = this->GetOutputProba(); + // Progress reporting - itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels()); + itk::ProgressReporter progress(this, threadId, + outputRegionForThread.GetNumberOfPixels()); // Define iterators typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType; typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType; typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType; typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType; + typedef itk::ImageRegionIterator<ProbaImageType> ProbaMapIteratorType; InputIteratorType inIt(inputPtr, outputRegionForThread); OutputIteratorType outIt(outputPtr, outputRegionForThread); @@ -210,16 +261,12 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> maskIt.GoToBegin(); } - // typedef typename ModelType::InputValueType InputValueType; typedef typename ModelType::InputSampleType InputSampleType; typedef typename ModelType::InputListSampleType InputListSampleType; typedef typename ModelType::TargetValueType TargetValueType; - // typedef typename ModelType::TargetSampleType TargetSampleType; typedef typename ModelType::TargetListSampleType TargetListSampleType; - // typedef typename ModelType::ConfidenceValueType ConfidenceValueType; - // typedef typename ModelType::ConfidenceSampleType ConfidenceSampleType; typedef typename ModelType::ConfidenceListSampleType ConfidenceListSampleType; - + typedef typename ModelType::ProbaListSampleType ProbaListSampleType; typename InputListSampleType::Pointer samples = InputListSampleType::New(); unsigned int num_features = inputPtr->GetNumberOfComponentsPerPixel(); samples->SetMeasurementVectorSize(num_features); @@ -247,11 +294,14 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> //Make the batch prediction typename TargetListSampleType::Pointer labels; typename ConfidenceListSampleType::Pointer confidences; + typename ProbaListSampleType::Pointer probas; if(computeConfidenceMap) confidences = ConfidenceListSampleType::New(); + if(computeProbaMap) + probas = ProbaListSampleType::New(); // This call is threadsafe - labels = m_Model->PredictBatch(samples,confidences); + labels = m_Model->PredictBatch(samples,confidences,probas); // Set the output values ConfidenceMapIteratorType confidenceIt; @@ -261,12 +311,19 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> confidenceIt.GoToBegin(); } + ProbaMapIteratorType probaIt; + if (computeProbaMap) + { + probaIt = ProbaMapIteratorType(probaPtr,outputRegionForThread); + probaIt.GoToBegin(); + } typename TargetListSampleType::ConstIterator labIt = labels->Begin(); maskIt.GoToBegin(); for (outIt.GoToBegin(); !outIt.IsAtEnd(); ++outIt) { double confidenceIndex = 0.0; TargetValueType labelValue(m_DefaultLabel); + ProbaSampleType probaValues{m_NumberOfClasses}; if (inputMaskPtr) { validPoint = maskIt.Get() > 0; @@ -278,16 +335,26 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> if(computeConfidenceMap) { - confidenceIndex = confidences->GetMeasurementVector(labIt.GetInstanceIdentifier())[0]; + confidenceIndex = + confidences->GetMeasurementVector(labIt.GetInstanceIdentifier())[0]; } - - ++labIt; + if(computeProbaMap) + { + //The probas may have different size than the m_NumberOfClasses set by the user + auto tempProbaValues = probas->GetMeasurementVector(labIt.GetInstanceIdentifier()); + for(unsigned int i=0; i<m_NumberOfClasses; ++i) + { + if(i<tempProbaValues.Size()) probaValues[i] = tempProbaValues[i]; + else probaValues[i] = 0; + } + } + ++labIt; } else { labelValue = m_DefaultLabel; } - + outIt.Set(labelValue); if(computeConfidenceMap) @@ -295,14 +362,20 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> confidenceIt.Set(confidenceIndex); ++confidenceIt; } - + if(computeProbaMap) + { + probaIt.Set(probaValues); + ++probaIt; + } progress.CompletedPixel(); } } template <class TInputImage, class TOutputImage, class TMaskImage> void -ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> -::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId) +ImageClassificationFilter<TInputImage, TOutputImage, + TMaskImage>::ThreadedGenerateData( + const OutputImageRegionType& outputRegionForThread, + itk::ThreadIdType threadId) { if(m_BatchMode) { diff --git a/Modules/Learning/LearningBase/include/otbMachineLearningModel.h b/Modules/Learning/LearningBase/include/otbMachineLearningModel.h index 01f522a4feccf6afbc246c7f61f5b14aa0e0cea5..9eb583e08477c14b30d7348256e578cff6aa4643 100644 --- a/Modules/Learning/LearningBase/include/otbMachineLearningModel.h +++ b/Modules/Learning/LearningBase/include/otbMachineLearningModel.h @@ -99,6 +99,9 @@ public: typedef typename MLMTargetTraits<TConfidenceValue>::SampleType ConfidenceSampleType; typedef itk::Statistics::ListSample<ConfidenceSampleType> ConfidenceListSampleType; + + typedef itk::VariableLengthVector<double> ProbaSampleType; + typedef itk::Statistics::ListSample<ProbaSampleType> ProbaListSampleType; /**\name Standard macros */ //@{ /** Run-time type information (and related methods). */ @@ -114,7 +117,7 @@ public: * quality value, or NULL * \return The predicted label */ - TargetSampleType Predict(const InputSampleType& input, ConfidenceValueType *quality = nullptr) const; + TargetSampleType Predict(const InputSampleType& input, ConfidenceValueType *quality = nullptr, ProbaSampleType *proba = nullptr) const; /**\name Set and get the dimension of the model for dimensionality reduction models */ //@{ @@ -130,7 +133,7 @@ public: * Note that this method will be multi-threaded if OTB is built * with OpenMP. */ - typename TargetListSampleType::Pointer PredictBatch(const InputListSampleType * input, ConfidenceListSampleType * quality = nullptr) const; + typename TargetListSampleType::Pointer PredictBatch(const InputListSampleType * input, ConfidenceListSampleType * quality = nullptr, ProbaListSampleType * proba = nullptr) const; /**\name Classification model file manipulation */ //@{ @@ -152,6 +155,8 @@ public: /** Query capacity to produce a confidence index */ bool HasConfidenceIndex() const {return m_ConfidenceIndex;} + /** Query capacity to produce probability values */ + bool HasProbaIndex() const {return m_ProbaIndex;} /**\name Input list of samples accessors */ //@{ @@ -208,6 +213,9 @@ protected: /** flag that tells if the model support confidence index output */ bool m_ConfidenceIndex; + /** flag that tells if the model support probability output */ + bool m_ProbaIndex; + /** Is DoPredictBatch multi-threaded ? */ bool m_IsDoPredictBatchMultiThreaded; @@ -230,7 +238,7 @@ private: * Also set m_IsDoPredictBatchMultiThreaded to true if internal * implementation allows for parallel batch prediction. */ - virtual void DoPredictBatch(const InputListSampleType * input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * target, ConfidenceListSampleType * quality = nullptr) const; + virtual void DoPredictBatch(const InputListSampleType * input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * target, ConfidenceListSampleType * quality = nullptr, ProbaListSampleType * proba = nullptr) const; /** Actual implementation of single sample prediction * \param input sample to predict @@ -238,7 +246,7 @@ private: * or NULL * \return The predicted label */ - virtual TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType * quality= nullptr) const = 0; + virtual TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType * quality= nullptr, ProbaSampleType *proba=nullptr) const = 0; MachineLearningModel(const Self &) = delete; void operator =(const Self&) = delete; diff --git a/Modules/Learning/LearningBase/include/otbMachineLearningModel.hxx b/Modules/Learning/LearningBase/include/otbMachineLearningModel.hxx index a2c353a89a90865d2510eeafba9a75de2c18ee3d..c2956a76627b21c21540bdb45259e50efcbd9182 100644 --- a/Modules/Learning/LearningBase/include/otbMachineLearningModel.hxx +++ b/Modules/Learning/LearningBase/include/otbMachineLearningModel.hxx @@ -38,6 +38,7 @@ MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue> m_RegressionMode(false), m_IsRegressionSupported(false), m_ConfidenceIndex(false), + m_ProbaIndex(false), m_IsDoPredictBatchMultiThreaded(false), m_Dimension(0) {} @@ -68,10 +69,10 @@ template <class TInputValue, class TOutputValue, class TConfidenceValue> typename MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue> ::TargetSampleType MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue> -::Predict(const InputSampleType& input, ConfidenceValueType *quality) const +::Predict(const InputSampleType& input, ConfidenceValueType *quality, ProbaSampleType *proba) const { // Call protected specialization entry point - return this->DoPredict(input,quality); + return this->DoPredict(input,quality,proba); } @@ -79,8 +80,9 @@ template <class TInputValue, class TOutputValue, class TConfidenceValue> typename MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue> ::TargetListSampleType::Pointer MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue> -::PredictBatch(const InputListSampleType * input, ConfidenceListSampleType * quality) const +::PredictBatch(const InputListSampleType * input, ConfidenceListSampleType * quality, ProbaListSampleType *proba) const { + //std::cout << "Enter batch predict" << std::endl; typename TargetListSampleType::Pointer targets = TargetListSampleType::New(); targets->Resize(input->Size()); @@ -89,16 +91,19 @@ MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue> quality->Clear(); quality->Resize(input->Size()); } - + if(proba!=ITK_NULLPTR) + { + proba->Clear(); + proba->Resize(input->Size()); + } if(m_IsDoPredictBatchMultiThreaded) { // Simply calls DoPredictBatch - this->DoPredictBatch(input,0,input->Size(),targets,quality); - return targets; + this->DoPredictBatch(input,0,input->Size(),targets,quality,proba); + return targets; } else { - #ifdef _OPENMP // OpenMP threading here unsigned int nb_threads(0), threadId(0), nb_batches(0); @@ -120,11 +125,11 @@ MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue> batch_size+=input->Size()%nb_batches; } - this->DoPredictBatch(input,batch_start,batch_size,targets,quality); + this->DoPredictBatch(input,batch_start,batch_size,targets,quality,proba); } } #else - this->DoPredictBatch(input,0,input->Size(),targets,quality); + this->DoPredictBatch(input,0,input->Size(),targets,quality,proba); #endif return targets; } @@ -135,29 +140,42 @@ MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue> template <class TInputValue, class TOutputValue, class TConfidenceValue> void MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue> -::DoPredictBatch(const InputListSampleType * input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * targets, ConfidenceListSampleType * quality) const + ::DoPredictBatch(const InputListSampleType * input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * targets, ConfidenceListSampleType * quality, ProbaListSampleType * proba) const { assert(input != nullptr); assert(targets != nullptr); assert(input->Size()==targets->Size()&&"Input sample list and target label list do not have the same size."); assert(((quality==nullptr)||(quality->Size()==input->Size()))&&"Quality samples list is not null and does not have the same size as input samples list"); + assert((proba==nullptr)||(input->Size()==proba->Size())&&"Proba sample list and target label list do not have the same size."); if(startIndex+size>input->Size()) { itkExceptionMacro(<<"requested range ["<<startIndex<<", "<<startIndex+size<<"[ partially outside input sample list range.[0,"<<input->Size()<<"["); } - if(quality != nullptr) + if (proba != nullptr) { for(unsigned int id = startIndex;id<startIndex+size;++id) { + ProbaSampleType prob; ConfidenceValueType confidence = 0; - const TargetSampleType target = this->DoPredict(input->GetMeasurementVector(id),&confidence); + const TargetSampleType target = this->DoPredict(input->GetMeasurementVector(id),&confidence, &prob); quality->SetMeasurementVector(id,confidence); + proba->SetMeasurementVector(id,prob); targets->SetMeasurementVector(id,target); } } + else if(quality != ITK_NULLPTR) + { + for(unsigned int id = startIndex;id<startIndex+size;++id) + { + ConfidenceValueType confidence = 0; + const TargetSampleType target = this->DoPredict(input->GetMeasurementVector(id),&confidence); + quality->SetMeasurementVector(id,confidence); + targets->SetMeasurementVector(id,target); + } + } else { for(unsigned int id = startIndex;id<startIndex+size;++id) @@ -176,6 +194,6 @@ void // Call superclass implementation Superclass::PrintSelf(os,indent); } - } +} #endif diff --git a/Modules/Learning/Supervised/include/otbBoostMachineLearningModel.h b/Modules/Learning/Supervised/include/otbBoostMachineLearningModel.h index 6b61f1f03201d12caf6e0ba557129017de0fc32f..93fa4da45cd239d9a9f7deae5bbf942e40f53461 100644 --- a/Modules/Learning/Supervised/include/otbBoostMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbBoostMachineLearningModel.h @@ -53,7 +53,7 @@ public: typedef typename Superclass::TargetSampleType TargetSampleType; typedef typename Superclass::TargetListSampleType TargetListSampleType; typedef typename Superclass::ConfidenceValueType ConfidenceValueType; - + typedef typename Superclass::ProbaSampleType ProbaSampleType; /** Run-time type information (and related methods). */ itkNewMacro(Self); itkTypeMacro(BoostMachineLearningModel, MachineLearningModel); @@ -124,8 +124,7 @@ protected: ~BoostMachineLearningModel() override; /** Predict values using the model */ - TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr) const override; - + TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override; /** PrintSelf method */ void PrintSelf(std::ostream& os, itk::Indent indent) const override; diff --git a/Modules/Learning/Supervised/include/otbBoostMachineLearningModel.hxx b/Modules/Learning/Supervised/include/otbBoostMachineLearningModel.hxx index 8112d51d6f338bb376d892507274cb1cc4ff847d..86421b3b4b01bc1fc2313d55d8a01b1debb27f21 100644 --- a/Modules/Learning/Supervised/include/otbBoostMachineLearningModel.hxx +++ b/Modules/Learning/Supervised/include/otbBoostMachineLearningModel.hxx @@ -104,7 +104,7 @@ template <class TInputValue, class TOutputValue> typename BoostMachineLearningModel<TInputValue,TOutputValue> ::TargetSampleType BoostMachineLearningModel<TInputValue,TOutputValue> -::DoPredict(const InputSampleType & input, ConfidenceValueType *quality) const +::DoPredict(const InputSampleType & input, ConfidenceValueType *quality, ProbaSampleType *proba) const { TargetSampleType target; @@ -132,6 +132,8 @@ BoostMachineLearningModel<TInputValue,TOutputValue> #endif ); } + if (proba != nullptr && !this->m_ProbaIndex) + itkExceptionMacro("Probability per class not available for this classifier !"); target[0] = static_cast<TOutputValue>(result); return target; diff --git a/Modules/Learning/Supervised/include/otbCvRTreesWrapper.h b/Modules/Learning/Supervised/include/otbCvRTreesWrapper.h index 97d881371c0768bc1fb5b1584de8c79fe75d95f7..7d8c4ad6cb60570da4ff07759daa5b0112aacd27 100644 --- a/Modules/Learning/Supervised/include/otbCvRTreesWrapper.h +++ b/Modules/Learning/Supervised/include/otbCvRTreesWrapper.h @@ -65,7 +65,7 @@ public: float predict_margin(const cv::Mat& sample, const cv::Mat& missing = cv::Mat()) const; - + #ifdef OTB_OPENCV_3 #define OTB_CV_WRAP_PROPERTY(type,name) \ diff --git a/Modules/Learning/Supervised/include/otbDecisionTreeMachineLearningModel.h b/Modules/Learning/Supervised/include/otbDecisionTreeMachineLearningModel.h index b0940ce737a6ff649319537508f6d7b7334b1888..b90982f2b98ab0ae5d16d7c1976c7f2e93a7ef9d 100644 --- a/Modules/Learning/Supervised/include/otbDecisionTreeMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbDecisionTreeMachineLearningModel.h @@ -53,7 +53,7 @@ public: typedef typename Superclass::TargetSampleType TargetSampleType; typedef typename Superclass::TargetListSampleType TargetListSampleType; typedef typename Superclass::ConfidenceValueType ConfidenceValueType; - + typedef typename Superclass::ProbaSampleType ProbaSampleType; /** Run-time type information (and related methods). */ itkNewMacro(Self); itkTypeMacro(DecisionTreeMachineLearningModel, MachineLearningModel); @@ -179,7 +179,7 @@ protected: ~DecisionTreeMachineLearningModel() override; /** Predict values using the model */ - TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr) const override; + TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override; /** PrintSelf method */ void PrintSelf(std::ostream& os, itk::Indent indent) const override; diff --git a/Modules/Learning/Supervised/include/otbDecisionTreeMachineLearningModel.hxx b/Modules/Learning/Supervised/include/otbDecisionTreeMachineLearningModel.hxx index 52255ac51e365c359177e5d8be2c4ebed4fef38c..b5a7b618eb6ebad88010dc642196bbb18e2835e2 100644 --- a/Modules/Learning/Supervised/include/otbDecisionTreeMachineLearningModel.hxx +++ b/Modules/Learning/Supervised/include/otbDecisionTreeMachineLearningModel.hxx @@ -117,7 +117,7 @@ template <class TInputValue, class TOutputValue> typename DecisionTreeMachineLearningModel<TInputValue,TOutputValue> ::TargetSampleType DecisionTreeMachineLearningModel<TInputValue,TOutputValue> -::DoPredict(const InputSampleType & input, ConfidenceValueType *quality) const +::DoPredict(const InputSampleType & input, ConfidenceValueType *quality, ProbaSampleType *proba) const { TargetSampleType target; @@ -140,8 +140,10 @@ DecisionTreeMachineLearningModel<TInputValue,TOutputValue> itkExceptionMacro("Confidence index not available for this classifier !"); } } + if (proba != nullptr && !this->m_ProbaIndex) + itkExceptionMacro("Probability per class not available for this classifier !"); - return target; +return target; } template <class TInputValue, class TOutputValue> diff --git a/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.h b/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.h index 3249ef81a9e2e22a8139db8b419c349c968b2155..15a8df438697b6bf5031eb73e1f3e4852f57af4d 100644 --- a/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.h @@ -51,7 +51,7 @@ public: typedef typename Superclass::TargetSampleType TargetSampleType; typedef typename Superclass::TargetListSampleType TargetListSampleType; typedef typename Superclass::ConfidenceValueType ConfidenceValueType; - + typedef typename Superclass::ProbaSampleType ProbaSampleType; /** Run-time type information (and related methods). */ itkNewMacro(Self); itkTypeMacro(GradientBoostedTreeMachineLearningModel, MachineLearningModel); @@ -130,8 +130,7 @@ protected: ~GradientBoostedTreeMachineLearningModel() override; /** Predict values using the model */ - TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr) const override; - + TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override; /** PrintSelf method */ void PrintSelf(std::ostream& os, itk::Indent indent) const override; diff --git a/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.hxx b/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.hxx index b40394593bddb97d055b420b311f0cd34fd1dda6..132dd7bc3d12530260a2cc8f0ec2328775a94cbf 100644 --- a/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.hxx +++ b/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.hxx @@ -83,7 +83,7 @@ template <class TInputValue, class TOutputValue> typename GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue> ::TargetSampleType GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue> -::DoPredict(const InputSampleType & input, ConfidenceValueType *quality) const +::DoPredict(const InputSampleType & input, ConfidenceValueType *quality, ProbaSampleType *proba) const { //convert listsample to Mat cv::Mat sample; @@ -103,6 +103,8 @@ GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue> itkExceptionMacro("Confidence index not available for this classifier !"); } } + if (proba != nullptr && !this->m_ProbaIndex) + itkExceptionMacro("Probability per class not available for this classifier !"); return target; } diff --git a/Modules/Learning/Supervised/include/otbKNearestNeighborsMachineLearningModel.h b/Modules/Learning/Supervised/include/otbKNearestNeighborsMachineLearningModel.h index fff1dd22567ed636752218d94d9c70cf992184ab..b9bceffb1c5c9f6a876a7047ce1d0dadec22f64d 100644 --- a/Modules/Learning/Supervised/include/otbKNearestNeighborsMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbKNearestNeighborsMachineLearningModel.h @@ -53,7 +53,7 @@ public: typedef typename Superclass::TargetSampleType TargetSampleType; typedef typename Superclass::TargetListSampleType TargetListSampleType; typedef typename Superclass::ConfidenceValueType ConfidenceValueType; - + typedef typename Superclass::ProbaSampleType ProbaSampleType; /** Run-time type information (and related methods). */ itkNewMacro(Self); itkTypeMacro(KNearestNeighborsMachineLearningModel, MachineLearningModel); @@ -104,8 +104,7 @@ protected: ~KNearestNeighborsMachineLearningModel() override; /** Predict values using the model */ - TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr) const override; - + TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override; /** PrintSelf method */ void PrintSelf(std::ostream& os, itk::Indent indent) const override; diff --git a/Modules/Learning/Supervised/include/otbKNearestNeighborsMachineLearningModel.hxx b/Modules/Learning/Supervised/include/otbKNearestNeighborsMachineLearningModel.hxx index 230bd44357d3e52dd82a6598dd1893a3abd901ac..2d12cceb360808be01a6a2579b7b6ed58422d0bd 100644 --- a/Modules/Learning/Supervised/include/otbKNearestNeighborsMachineLearningModel.hxx +++ b/Modules/Learning/Supervised/include/otbKNearestNeighborsMachineLearningModel.hxx @@ -106,7 +106,7 @@ template <class TInputValue, class TTargetValue> typename KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue> ::TargetSampleType KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue> -::DoPredict(const InputSampleType & input, ConfidenceValueType *quality) const +::DoPredict(const InputSampleType & input, ConfidenceValueType *quality, ProbaSampleType *proba) const { TargetSampleType target; @@ -135,6 +135,8 @@ KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue> } (*quality) = static_cast<ConfidenceValueType>(accuracy); } + if (proba != nullptr && !this->m_ProbaIndex) + itkExceptionMacro("Probability per class not available for this classifier !"); // Decision rule : // VOTING is OpenCV default behaviour for classification diff --git a/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.h b/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.h index 1600b022bc545ace4446d523d865e2b376f79b19..da96ae0f47dc71c7648be456f39358ff119377e4 100644 --- a/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.h @@ -47,7 +47,7 @@ public: typedef typename Superclass::TargetSampleType TargetSampleType; typedef typename Superclass::TargetListSampleType TargetListSampleType; typedef typename Superclass::ConfidenceValueType ConfidenceValueType; - + typedef typename Superclass::ProbaSampleType ProbaSampleType; /** enum to choose the way confidence is computed * CM_INDEX : compute the difference between highest and second highest probability * CM_PROBA : returns probabilities for all classes @@ -272,7 +272,7 @@ protected: ~LibSVMMachineLearningModel() override; /** Predict values using the model */ - TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr) const override; + TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override; /** PrintSelf method */ void PrintSelf(std::ostream& os, itk::Indent indent) const override; diff --git a/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.hxx b/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.hxx index 5588ccbe7aa87ea58510da4a4dcbd423a1fd1f77..8435fecce363c43093cf3f37f9c3a1eaf48e80dc 100644 --- a/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.hxx +++ b/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.hxx @@ -107,7 +107,7 @@ template <class TInputValue, class TOutputValue> typename LibSVMMachineLearningModel<TInputValue,TOutputValue> ::TargetSampleType LibSVMMachineLearningModel<TInputValue,TOutputValue> -::DoPredict(const InputSampleType & input, ConfidenceValueType *quality) const +::DoPredict(const InputSampleType & input, ConfidenceValueType *quality, ProbaSampleType *proba) const { TargetSampleType target; target.Fill(0); @@ -129,6 +129,8 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue> // terminate node x[input.Size()].index = -1; x[input.Size()].value = 0; + if (proba != nullptr && !this->m_ProbaIndex) + itkExceptionMacro("Probability per class not available for this classifier !"); if (quality != nullptr) { diff --git a/Modules/Learning/Supervised/include/otbNeuralNetworkMachineLearningModel.h b/Modules/Learning/Supervised/include/otbNeuralNetworkMachineLearningModel.h index 283b5ce015c53c3c9ac4e8d6c9d91d3c8668e2a0..f461a9e7d4d5f941706c0c234349cdf8881714ec 100644 --- a/Modules/Learning/Supervised/include/otbNeuralNetworkMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbNeuralNetworkMachineLearningModel.h @@ -48,7 +48,7 @@ public: typedef typename Superclass::TargetSampleType TargetSampleType; typedef typename Superclass::TargetListSampleType TargetListSampleType; typedef typename Superclass::ConfidenceValueType ConfidenceValueType; - + typedef typename Superclass::ProbaSampleType ProbaSampleType; typedef std::map<TargetValueType, unsigned int> MapOfLabelsType; /** Run-time type information (and related methods). */ @@ -178,7 +178,7 @@ protected: ~NeuralNetworkMachineLearningModel() override; /** Predict values using the model */ - TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr) const override; + TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override; void LabelsToMat(const TargetListSampleType * listSample, cv::Mat & output); diff --git a/Modules/Learning/Supervised/include/otbNeuralNetworkMachineLearningModel.hxx b/Modules/Learning/Supervised/include/otbNeuralNetworkMachineLearningModel.hxx index 2f807898aa660c52e88ac366808badeb12f6e667..1a54679b29e30bf567b37174120aebb294d76dd8 100644 --- a/Modules/Learning/Supervised/include/otbNeuralNetworkMachineLearningModel.hxx +++ b/Modules/Learning/Supervised/include/otbNeuralNetworkMachineLearningModel.hxx @@ -234,7 +234,7 @@ void NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::Train() template<class TInputValue, class TOutputValue> typename NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::TargetSampleType NeuralNetworkMachineLearningModel< - TInputValue, TOutputValue>::DoPredict(const InputSampleType & input, ConfidenceValueType *quality) const + TInputValue, TOutputValue>::DoPredict(const InputSampleType & input, ConfidenceValueType *quality, ProbaSampleType *proba) const { TargetSampleType target; //convert listsample to Mat @@ -282,6 +282,9 @@ typename NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::TargetSam { (*quality) = static_cast<ConfidenceValueType>(maxResponse) - static_cast<ConfidenceValueType>(secondResponse); } + if (proba != nullptr && !this->m_ProbaIndex) + itkExceptionMacro("Probability per class not available for this classifier !"); + return target; } diff --git a/Modules/Learning/Supervised/include/otbNormalBayesMachineLearningModel.h b/Modules/Learning/Supervised/include/otbNormalBayesMachineLearningModel.h index cb3da6e3025e279b47794c8ba48dca748c832e4d..6589f6c4c174f8afe16dce22194ecb0e23280a11 100644 --- a/Modules/Learning/Supervised/include/otbNormalBayesMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbNormalBayesMachineLearningModel.h @@ -53,7 +53,7 @@ public: typedef typename Superclass::TargetSampleType TargetSampleType; typedef typename Superclass::TargetListSampleType TargetListSampleType; typedef typename Superclass::ConfidenceValueType ConfidenceValueType; - + typedef typename Superclass::ProbaSampleType ProbaSampleType; /** Run-time type information (and related methods). */ itkNewMacro(Self); itkTypeMacro(NormalBayesMachineLearningModel, MachineLearningModel); @@ -84,8 +84,7 @@ protected: ~NormalBayesMachineLearningModel() override; /** Predict values using the model */ - TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr) const override; - + TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override; /** PrintSelf method */ void PrintSelf(std::ostream& os, itk::Indent indent) const override; diff --git a/Modules/Learning/Supervised/include/otbNormalBayesMachineLearningModel.hxx b/Modules/Learning/Supervised/include/otbNormalBayesMachineLearningModel.hxx index 9e0ca05c8bec54e4a80765eea447ed37ba9a8768..1285d211e1c0a906c68d1a38acf1184fe91d8c0f 100644 --- a/Modules/Learning/Supervised/include/otbNormalBayesMachineLearningModel.hxx +++ b/Modules/Learning/Supervised/include/otbNormalBayesMachineLearningModel.hxx @@ -85,7 +85,7 @@ template <class TInputValue, class TOutputValue> typename NormalBayesMachineLearningModel<TInputValue,TOutputValue> ::TargetSampleType NormalBayesMachineLearningModel<TInputValue,TOutputValue> -::DoPredict(const InputSampleType & input, ConfidenceValueType *quality) const +::DoPredict(const InputSampleType & input, ConfidenceValueType *quality, ProbaSampleType *proba) const { TargetSampleType target; @@ -107,6 +107,9 @@ NormalBayesMachineLearningModel<TInputValue,TOutputValue> itkExceptionMacro("Confidence index not available for this classifier !"); } } + if (proba != nullptr && !this->m_ProbaIndex) + itkExceptionMacro("Probability per class not available for this classifier !"); + return target; } diff --git a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h index 276e9ca3492d21db03d2872b38e4c80cc44002b9..adbade7db5d62c6d3aea1faf809461a01c5e751b 100644 --- a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.h @@ -50,7 +50,7 @@ public: typedef typename Superclass::TargetSampleType TargetSampleType; typedef typename Superclass::TargetListSampleType TargetListSampleType; typedef typename Superclass::ConfidenceValueType ConfidenceValueType; - + typedef typename Superclass::ProbaSampleType ProbaSampleType; // Other typedef itk::VariableSizeMatrix<float> VariableImportanceMatrixType; @@ -137,8 +137,7 @@ protected: ~RandomForestsMachineLearningModel() override; /** Predict values using the model */ - TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr) const override; - + TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override; /** PrintSelf method */ void PrintSelf(std::ostream& os, itk::Indent indent) const override; diff --git a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.hxx b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.hxx index 10508ead57bb89d90b76cc5675731ce864e7786b..a78f668fcbff653f4016f5cdb7a98fa28760951d 100644 --- a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.hxx +++ b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.hxx @@ -50,6 +50,7 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> m_ComputeMargin(false) { this->m_ConfidenceIndex = true; + this->m_ProbaIndex = false; this->m_IsRegressionSupported = true; } @@ -171,8 +172,9 @@ template <class TInputValue, class TOutputValue> typename RandomForestsMachineLearningModel<TInputValue,TOutputValue> ::TargetSampleType RandomForestsMachineLearningModel<TInputValue,TOutputValue> -::DoPredict(const InputSampleType & value, ConfidenceValueType *quality) const +::DoPredict(const InputSampleType & value, ConfidenceValueType *quality, ProbaSampleType *proba) const { + //std::cout << "Enter predict" << std::endl; TargetSampleType target; //convert listsample to Mat cv::Mat sample; @@ -190,6 +192,10 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> else (*quality) = m_RFModel->predict_confidence(sample); } + + if (proba != nullptr && !this->m_ProbaIndex) + itkExceptionMacro("Probability per class not available for this classifier !"); + return target[0]; } diff --git a/Modules/Learning/Supervised/include/otbSVMMachineLearningModel.h b/Modules/Learning/Supervised/include/otbSVMMachineLearningModel.h index dca029a2c8cf2b2a09cea9eb3162a100e7cbbe64..6fcc923971388bcd04c13f21ee89560296d023f4 100644 --- a/Modules/Learning/Supervised/include/otbSVMMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbSVMMachineLearningModel.h @@ -60,7 +60,7 @@ public: typedef typename Superclass::TargetSampleType TargetSampleType; typedef typename Superclass::TargetListSampleType TargetListSampleType; typedef typename Superclass::ConfidenceValueType ConfidenceValueType; - + typedef typename Superclass::ProbaSampleType ProbaSampleType; /** Run-time type information (and related methods). */ itkNewMacro(Self); itkTypeMacro(SVMMachineLearningModel, MachineLearningModel); @@ -141,8 +141,7 @@ protected: ~SVMMachineLearningModel() override; /** Predict values using the model */ - TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr) const override; - + TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override; /** PrintSelf method */ void PrintSelf(std::ostream& os, itk::Indent indent) const override; diff --git a/Modules/Learning/Supervised/include/otbSVMMachineLearningModel.hxx b/Modules/Learning/Supervised/include/otbSVMMachineLearningModel.hxx index a52163ddee88a537e23fb2be5bb1244f34cee585..dc8d55dcddc2d04981b55fac4eabfc05418b0559 100644 --- a/Modules/Learning/Supervised/include/otbSVMMachineLearningModel.hxx +++ b/Modules/Learning/Supervised/include/otbSVMMachineLearningModel.hxx @@ -174,7 +174,7 @@ template <class TInputValue, class TOutputValue> typename SVMMachineLearningModel<TInputValue,TOutputValue> ::TargetSampleType SVMMachineLearningModel<TInputValue,TOutputValue> -::DoPredict(const InputSampleType & input, ConfidenceValueType *quality) const +::DoPredict(const InputSampleType & input, ConfidenceValueType *quality, ProbaSampleType *proba) const { TargetSampleType target; //convert listsample to Mat @@ -198,7 +198,10 @@ SVMMachineLearningModel<TInputValue,TOutputValue> (*quality) = m_SVMModel->predict(sample,true); #endif } - return target; + if (proba != nullptr && !this->m_ProbaIndex) + itkExceptionMacro("Probability per class not available for this classifier !"); + +return target; } template <class TInputValue, class TOutputValue> diff --git a/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h b/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h index 836afb8411431d4633d2c0b997391ac477866c37..b80fc244f0448ca0729cfcd0ed491ab0c107b9cb 100644 --- a/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h @@ -86,7 +86,8 @@ public: typedef typename Superclass::ConfidenceValueType ConfidenceValueType; typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType; typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType; - + typedef typename Superclass::ProbaSampleType ProbaSampleType; + typedef typename Superclass::ProbaListSampleType ProbaListSampleType; /** Run-time type information (and related methods). */ itkNewMacro(Self); itkTypeMacro(SharkRandomForestsMachineLearningModel, MachineLearningModel); @@ -155,10 +156,9 @@ protected: virtual ~SharkRandomForestsMachineLearningModel(); /** Predict values using the model */ - virtual TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr) const override; - + TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override; - virtual void DoPredictBatch(const InputListSampleType *, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType *, ConfidenceListSampleType * = nullptr) const override; + void DoPredictBatch(const InputListSampleType *, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType *, ConfidenceListSampleType * = nullptr, ProbaListSampleType * = nullptr) const override; /** PrintSelf method */ void PrintSelf(std::ostream& os, itk::Indent indent) const override; diff --git a/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.hxx b/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.hxx index 35dd418ed62b4f6efc5564c4fc69497dd73d000a..5f7726938e1993bfc1e9beaf007a133f34430082 100644 --- a/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.hxx +++ b/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.hxx @@ -49,6 +49,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> ::SharkRandomForestsMachineLearningModel() { this->m_ConfidenceIndex = true; + this->m_ProbaIndex = true; this->m_IsRegressionSupported = false; this->m_IsDoPredictBatchMultiThreaded = true; this->m_NormalizeClassLabels = true; @@ -120,20 +121,32 @@ template <class TInputValue, class TOutputValue> typename SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> ::TargetSampleType SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> -::DoPredict(const InputSampleType & value, ConfidenceValueType *quality) const +::DoPredict(const InputSampleType & value, ConfidenceValueType *quality, ProbaSampleType *proba) const { shark::RealVector samples(value.Size()); for(size_t i = 0; i < value.Size();i++) { samples.push_back(value[i]); } - if (quality != nullptr) - { + if (quality != nullptr || proba != nullptr) + { shark::RealVector probas = m_RFModel.decisionFunction()(samples); - (*quality) = ComputeConfidence(probas, m_ComputeMargin); + if (quality != nullptr) + { + (*quality) = ComputeConfidence(probas, m_ComputeMargin); + } + if (proba != nullptr) + { + for(size_t i =0; i< probas.size();i++) + { + //probas contain the N class probability indexed between 0 and N-1 + (*proba)[i] = static_cast<unsigned int>(probas[i]*1000); + } } + } unsigned int res{0}; m_RFModel.eval(samples, res); + TargetSampleType target; if(m_NormalizeClassLabels) { @@ -149,14 +162,15 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> template <class TInputValue, class TOutputValue> void SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> -::DoPredictBatch(const InputListSampleType *input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * targets, ConfidenceListSampleType * quality) const +::DoPredictBatch(const InputListSampleType *input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * targets, ConfidenceListSampleType * quality, ProbaListSampleType * proba) const { assert(input != nullptr); assert(targets != nullptr); assert(input->Size()==targets->Size()&&"Input sample list and target label list do not have the same size."); assert(((quality==nullptr)||(quality->Size()==input->Size()))&&"Quality samples list is not null and does not have the same size as input samples list"); - + assert((proba==nullptr)||(input->Size()==proba->Size())&&"Proba sample list and target label list do not have the same size."); + if(startIndex+size>input->Size()) { itkExceptionMacro(<<"requested range ["<<startIndex<<", "<<startIndex+size<<"[ partially outside input sample list range.[0,"<<input->Size()<<"["); @@ -168,11 +182,27 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> #ifdef _OPENMP omp_set_num_threads(itk::MultiThreader::GetGlobalDefaultNumberOfThreads()); -#endif - + + #endif + if( proba !=nullptr || quality != nullptr) + { + shark::Data<shark::RealVector> probas = m_RFModel.decisionFunction()(inputSamples); + if( proba !=nullptr) + { + unsigned int id = startIndex; + for(shark::RealVector && p : probas.elements()) + { + ProbaSampleType prob{(unsigned int)p.size()}; + for(size_t i =0; i< p.size();i++) + { + prob[i] =p[i]*1000; + } + proba->SetMeasurementVector(id,prob); + ++id; + } + } if(quality != nullptr) { - shark::Data<shark::RealVector> probas = m_RFModel.decisionFunction()(inputSamples); unsigned int id = startIndex; for(shark::RealVector && p : probas.elements()) { @@ -183,7 +213,8 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> ++id; } } - + } + auto prediction = m_RFModel(inputSamples); unsigned int id = startIndex; for(const auto& p : prediction.elements()) diff --git a/Modules/Learning/Supervised/src/otbCvRTreesWrapper.cxx b/Modules/Learning/Supervised/src/otbCvRTreesWrapper.cxx index ddc3a821d8f3bb1b9aae69c2559f9b8af543dad1..7c62782107bd4b82a064e3bafc47e1bbcecad97a 100644 --- a/Modules/Learning/Supervised/src/otbCvRTreesWrapper.cxx +++ b/Modules/Learning/Supervised/src/otbCvRTreesWrapper.cxx @@ -140,7 +140,7 @@ float CvRTreesWrapper::predict_confidence(const cv::Mat& sample, float confidence = static_cast<float>(max_votes)/ntrees; return confidence; } - + #ifdef OTB_OPENCV_3 #define OTB_CV_WRAP_IMPL(type,name) \ type CvRTreesWrapper::get##name() const \ diff --git a/Modules/Learning/Supervised/test/otbSharkImageClassificationFilter.cxx b/Modules/Learning/Supervised/test/otbSharkImageClassificationFilter.cxx index 1753b05b36bcb3e96393a485707a7068e7725ee1..010781f1d25ec99c14e4f3468075344f57406057 100644 --- a/Modules/Learning/Supervised/test/otbSharkImageClassificationFilter.cxx +++ b/Modules/Learning/Supervised/test/otbSharkImageClassificationFilter.cxx @@ -100,16 +100,18 @@ void buildModel(unsigned int num_classes, unsigned int num_samples, int otbSharkImageClassificationFilter(int argc, char * argv[]) { - if(argc<5 || argc>7) + if(argc<6 || argc>8) { - std::cout << "Usage: input_image output_image output_confidence batchmode [in_model_name] [mask_name]\n"; + std::cout << "Usage: input_image output_image output_confidence output_proba batchmode [in_model_name] [mask_name]\n"; } std::string imfname = argv[1]; std::string outfname = argv[2]; std::string conffname = argv[3]; - bool batch = (std::string(argv[4])=="1"); + std::string probafname = argv[4]; + bool batch = (std::string(argv[5])=="1"); std::string modelfname = "/tmp/rf_model.txt"; std::string maskfname{}; + int num_classes = 3; MaskReaderType::Pointer mask_reader = MaskReaderType::New(); ReaderType::Pointer reader = ReaderType::New(); @@ -120,13 +122,15 @@ int otbSharkImageClassificationFilter(int argc, char * argv[]) std::cout << "Image has " << num_features << " bands\n"; - if(argc>5) + if(argc>6) { - modelfname = argv[5]; + modelfname = argv[6]; + // We don't know the number of classes, so we set it to a high number + num_classes = 10; } else { - buildModel(3, 1000, num_features, modelfname); + buildModel(num_classes, 1000, num_features, modelfname); } ClassificationFilterType::Pointer filter = ClassificationFilterType::New(); @@ -135,9 +139,10 @@ int otbSharkImageClassificationFilter(int argc, char * argv[]) model->Load(modelfname); filter->SetModel(model); filter->SetInput(reader->GetOutput()); - if(argc==7) + filter->SetNumberOfClasses(num_classes); + if(argc==8) { - maskfname = argv[6]; + maskfname = argv[7]; mask_reader->SetFileName(maskfname); filter->SetInputMask(mask_reader->GetOutput()); } @@ -148,6 +153,7 @@ int otbSharkImageClassificationFilter(int argc, char * argv[]) std::cout << "Classification\n"; filter->SetBatchMode(batch); filter->SetUseConfidenceMap(true); + filter->SetUseProbaMap(true); using TimeT = std::chrono::milliseconds; auto start = std::chrono::system_clock::now(); writer->Update(); @@ -161,5 +167,35 @@ int otbSharkImageClassificationFilter(int argc, char * argv[]) confWriter->SetFileName(conffname); confWriter->Update(); + auto probaWriter = otb::ImageFileWriter<ClassificationFilterType::ProbaImageType>::New(); + probaWriter->SetInput(filter->GetOutputProba()); + probaWriter->SetFileName(probafname); + probaWriter->Update(); + + // Check that the chosen labels correspond to the max proba + + itk::ImageRegionConstIterator<LabeledImageType> labIt(filter->GetOutput(), + filter->GetOutput()->GetLargestPossibleRegion()); + itk::ImageRegionConstIterator<ClassificationFilterType::ProbaImageType> probIt(filter->GetOutputProba(), + filter->GetOutputProba()->GetLargestPossibleRegion()); + + for (labIt.GoToBegin(), probIt.GoToBegin(); !labIt.IsAtEnd(); + ++labIt, ++probIt) + { + if(labIt.Get()>0) //Pixel is not masked + { + auto first = probIt.Get().GetDataPointer(); + auto last = probIt.Get().GetDataPointer(); + std::advance(last, num_classes); + auto max_proba = std::distance(first, std::max_element(first, last)) + 1; + if(labIt.Get() != max_proba) + { + std::cout << "Chosen label " << labIt.Get() << " and max proba position " + << max_proba << " from " << probIt.Get() << " don't match\n"; + return EXIT_FAILURE; + } + } + } + return EXIT_SUCCESS; } diff --git a/Modules/Learning/Supervised/test/tests-shark.cmake b/Modules/Learning/Supervised/test/tests-shark.cmake index 87fbcbfc48077ea8fcbf0f07a6ea25b9da195e81..9c44f9eb23f428c88043aeb2cf8951373c33075b 100644 --- a/Modules/Learning/Supervised/test/tests-shark.cmake +++ b/Modules/Learning/Supervised/test/tests-shark.cmake @@ -46,6 +46,7 @@ otb_add_test(NAME leTvImageClassificationFilterSharkFast COMMAND otbSupervisedT ${INPUTDATA}/Classification/QB_1_ortho.tif ${TEMP}/leSharkImageClassificationFilterOutput.tif ${TEMP}/leSharkImageClassificationFilterConfidence.tif + ${TEMP}/leSharkImageClassificationFilterProba.tif 1 ${INPUTDATA}/Classification/otbSharkImageClassificationFilter_RFmodel.txt ) @@ -72,6 +73,7 @@ otb_add_test(NAME leTvImageClassificationFilterSharkFastMask COMMAND otbSupervi ${INPUTDATA}/Classification/QB_1_ortho.tif ${TEMP}/leSharkImageClassificationFilterWithMaskOutput.tif ${TEMP}/leSharkImageClassificationFilterWithMaskConfidence.tif + ${TEMP}/leSharkImageClassificationFilterWithMaskProba.tif 1 ${INPUTDATA}/Classification/otbSharkImageClassificationFilter_RFmodel.txt ${INPUTDATA}/Classification/QB_1_ortho_mask.tif diff --git a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h index 05a408cd35b3c66fbdac66eefdc039b49ffe3d9e..873bb5e3578b2bdc4d1b0bc3c29a88563b9ef135 100644 --- a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h +++ b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h @@ -87,8 +87,8 @@ public: typedef typename Superclass::ConfidenceValueType ConfidenceValueType; typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType; typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType; - - + typedef typename Superclass::ProbaSampleType ProbaSampleType; + typedef typename Superclass::ProbaListSampleType ProbaListSampleType; typedef shark::HardClusteringModel<shark::RealVector> ClusteringModelType; typedef ClusteringModelType::OutputType ClusteringOutputType; @@ -137,11 +137,10 @@ protected: /** Predict values using the model */ virtual TargetSampleType - DoPredict(const InputSampleType &input, ConfidenceValueType *quality = nullptr) const override; - + DoPredict(const InputSampleType &input, ConfidenceValueType *quality = nullptr, ProbaSampleType *proba=nullptr) const override; virtual void DoPredictBatch(const InputListSampleType *, const unsigned int &startIndex, const unsigned int &size, - TargetListSampleType *, ConfidenceListSampleType * = nullptr) const override; + TargetListSampleType *, ConfidenceListSampleType * = nullptr, ProbaListSampleType * = nullptr) const override; template<typename DataType> DataType NormalizeData(const DataType &data) const; diff --git a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.hxx b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.hxx index 1cd01e13c47a7d1187251e46dbf4d83ac109a5b5..82f16e23d7d75b924623ca67753f37038ee98dce 100644 --- a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.hxx +++ b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.hxx @@ -102,7 +102,7 @@ template<class TInputValue, class TOutputValue> typename SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::TargetSampleType SharkKMeansMachineLearningModel<TInputValue, TOutputValue> -::DoPredict(const InputSampleType &value, ConfidenceValueType *quality) const +::DoPredict(const InputSampleType &value, ConfidenceValueType *quality, ProbaSampleType *proba) const { shark::RealVector data( value.Size()); for( size_t i = 0; i < value.Size(); i++ ) @@ -117,6 +117,13 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ( *quality ) = ConfidenceValueType( 1.); } + if (proba != nullptr) + { + if (!this->m_ProbaIndex) + { + itkExceptionMacro("Probability per class not available for this classifier !"); + } + } TargetSampleType target; ClusteringOutputType predictedValue = (*m_ClusteringModel)( data ); target[0] = static_cast<TOutputValue>(predictedValue); @@ -130,7 +137,8 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> const unsigned int &startIndex, const unsigned int &size, TargetListSampleType *targets, - ConfidenceListSampleType *quality) const + ConfidenceListSampleType *quality, + ProbaListSampleType * proba) const { // Perform check on input values @@ -180,6 +188,10 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> quality->SetMeasurementVector( qid, static_cast<ConfidenceValueType>(1.) ); } } + if (proba !=nullptr && !this->m_ProbaIndex) + { + itkExceptionMacro("Probability per class not available for this classifier !"); + } }