diff --git a/Modules/Learning/Supervised/include/otbImageClassificationFilter.h b/Modules/Learning/Supervised/include/otbImageClassificationFilter.h index 2c79315b3d5ca1e127d3c7a0ead88688fe14c5e4..1f34888b054a1ee01e690bd8056a86b517ec1f80 100644 --- a/Modules/Learning/Supervised/include/otbImageClassificationFilter.h +++ b/Modules/Learning/Supervised/include/otbImageClassificationFilter.h @@ -20,6 +20,7 @@ #include "itkImageToImageFilter.h" #include "otbMachineLearningModel.h" +#include "otbImage.h" namespace otb { @@ -68,6 +69,9 @@ public: typedef MachineLearningModel<ValueType, LabelType> ModelType; typedef typename ModelType::Pointer ModelPointerType; + typedef otb::Image<double> ConfidenceImageType; + typedef typename ConfidenceImageType::Pointer ConfidenceImagePointerType; + /** Set/Get the model */ itkSetObjectMacro(Model, ModelType); itkGetObjectMacro(Model, ModelType); @@ -89,6 +93,11 @@ public: */ const MaskImageType * GetInputMask(void); + /** + * Get the output confidence map + */ + const ConfidenceImageType * GetOutputConfidence(void); + protected: /** Constructor */ ImageClassificationFilter(); @@ -110,7 +119,8 @@ private: ModelPointerType m_Model; /** Default label for invalid pixels (when using a mask) */ LabelType m_DefaultLabel; - + /** Flag to produce the confidence map */ + bool m_UseConfidenceMap; }; } // End namespace otb #ifndef OTB_MANUAL_INSTANTIATION diff --git a/Modules/Learning/Supervised/include/otbImageClassificationFilter.txx b/Modules/Learning/Supervised/include/otbImageClassificationFilter.txx index 06225af19e37ce45feb5e6008121f8c06f0bd05a..ccb5ae616239b87cd4734deb71714ee30d21d25e 100644 --- a/Modules/Learning/Supervised/include/otbImageClassificationFilter.txx +++ b/Modules/Learning/Supervised/include/otbImageClassificationFilter.txx @@ -34,6 +34,11 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> this->SetNumberOfIndexedInputs(2); this->SetNumberOfRequiredInputs(1); m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue(); + + this->SetNumberOfRequiredOutputs(2); + this->SetNthOutput(0,TOutputImage::New()); + this->SetNthOutput(1,ConfidenceImageType::New()); + m_UseConfidenceMap = false; } template <class TInputImage, class TOutputImage, class TMaskImage> @@ -57,6 +62,19 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> return static_cast<const MaskImageType *>(this->itk::ProcessObject::GetInput(1)); } +template <class TInputImage, class TOutputImage, class TMaskImage> +typename ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> +::ConfidenceImageType * +ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> +::GetOutputConfidence() +{ + if (this->GetNumberOfOutputs() < 2) + { + return 0; + } + return static_cast<ConfidenceImageType *>(this->itk::ProcessObject::GetOutput(1)); +} + template <class TInputImage, class TOutputImage, class TMaskImage> void ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> @@ -66,6 +84,7 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> { itkGenericExceptionMacro(<< "No model for classification"); } + m_UseConfidenceMap = m_Model->HasConfidenceIndex(); } template <class TInputImage, class TOutputImage, class TMaskImage> @@ -77,6 +96,7 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> InputImageConstPointerType inputPtr = this->GetInput(); MaskImageConstPointerType inputMaskPtr = this->GetInputMask(); OutputImagePointerType outputPtr = this->GetOutput(); + ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence(); // Progress reporting itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels()); @@ -85,6 +105,7 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType; typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType; typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType; + typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType; InputIteratorType inIt(inputPtr, outputRegionForThread); OutputIteratorType outIt(outputPtr, outputRegionForThread); @@ -97,7 +118,16 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> maskIt.GoToBegin(); } + // setup iterator for confidence map + ConfidenceMapIteratorType confidenceIt; + if (m_UseConfidenceMap) + { + confidenceIt = ConfidenceMapIteratorType(confidencePtr,outputRegionForThread); + confidenceIt.GoToBegin(); + } + bool validPoint = true; + double confidenceIndex = 0.0; // Walk the part of the image for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd(); ++inIt, ++outIt) @@ -112,12 +142,25 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> if (validPoint) { // Classifify - outIt.Set(m_Model->Predict(inIt.Get())[0]); + if (m_UseConfidenceMap) + { + outIt.Set(m_Model->Predict(inIt.Get(),&confidenceIndex)[0]); + } + else + { + outIt.Set(m_Model->Predict(inIt.Get())[0]); + } } else { // else, set default value outIt.Set(m_DefaultLabel); + confidenceIndex = 0.0; + } + if (m_UseConfidenceMap) + { + confidenceIt.Set(confidenceIndex); + ++confidenceIt; } progress.CompletedPixel(); }