Skip to content
Snippets Groups Projects
Commit c563cee0 authored by Guillaume Pasero's avatar Guillaume Pasero
Browse files

ENH: add an output to classification filter

parent fa164838
No related branches found
No related tags found
No related merge requests found
......@@ -20,6 +20,7 @@
#include "itkImageToImageFilter.h"
#include "otbMachineLearningModel.h"
#include "otbImage.h"
namespace otb
{
......@@ -68,6 +69,9 @@ public:
typedef MachineLearningModel<ValueType, LabelType> ModelType;
typedef typename ModelType::Pointer ModelPointerType;
typedef otb::Image<double> ConfidenceImageType;
typedef typename ConfidenceImageType::Pointer ConfidenceImagePointerType;
/** Set/Get the model */
itkSetObjectMacro(Model, ModelType);
itkGetObjectMacro(Model, ModelType);
......@@ -89,6 +93,11 @@ public:
*/
const MaskImageType * GetInputMask(void);
/**
* Get the output confidence map
*/
const ConfidenceImageType * GetOutputConfidence(void);
protected:
/** Constructor */
ImageClassificationFilter();
......@@ -110,7 +119,8 @@ private:
ModelPointerType m_Model;
/** Default label for invalid pixels (when using a mask) */
LabelType m_DefaultLabel;
/** Flag to produce the confidence map */
bool m_UseConfidenceMap;
};
} // End namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
......
......@@ -34,6 +34,11 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
this->SetNumberOfIndexedInputs(2);
this->SetNumberOfRequiredInputs(1);
m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue();
this->SetNumberOfRequiredOutputs(2);
this->SetNthOutput(0,TOutputImage::New());
this->SetNthOutput(1,ConfidenceImageType::New());
m_UseConfidenceMap = false;
}
template <class TInputImage, class TOutputImage, class TMaskImage>
......@@ -57,6 +62,19 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
return static_cast<const MaskImageType *>(this->itk::ProcessObject::GetInput(1));
}
template <class TInputImage, class TOutputImage, class TMaskImage>
typename ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
::ConfidenceImageType *
ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
::GetOutputConfidence()
{
if (this->GetNumberOfOutputs() < 2)
{
return 0;
}
return static_cast<ConfidenceImageType *>(this->itk::ProcessObject::GetOutput(1));
}
template <class TInputImage, class TOutputImage, class TMaskImage>
void
ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
......@@ -66,6 +84,7 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
{
itkGenericExceptionMacro(<< "No model for classification");
}
m_UseConfidenceMap = m_Model->HasConfidenceIndex();
}
template <class TInputImage, class TOutputImage, class TMaskImage>
......@@ -77,6 +96,7 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
InputImageConstPointerType inputPtr = this->GetInput();
MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
OutputImagePointerType outputPtr = this->GetOutput();
ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence();
// Progress reporting
itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
......@@ -85,6 +105,7 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType;
InputIteratorType inIt(inputPtr, outputRegionForThread);
OutputIteratorType outIt(outputPtr, outputRegionForThread);
......@@ -97,7 +118,16 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
maskIt.GoToBegin();
}
// setup iterator for confidence map
ConfidenceMapIteratorType confidenceIt;
if (m_UseConfidenceMap)
{
confidenceIt = ConfidenceMapIteratorType(confidencePtr,outputRegionForThread);
confidenceIt.GoToBegin();
}
bool validPoint = true;
double confidenceIndex = 0.0;
// Walk the part of the image
for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd(); ++inIt, ++outIt)
......@@ -112,12 +142,25 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
if (validPoint)
{
// Classifify
outIt.Set(m_Model->Predict(inIt.Get())[0]);
if (m_UseConfidenceMap)
{
outIt.Set(m_Model->Predict(inIt.Get(),&confidenceIndex)[0]);
}
else
{
outIt.Set(m_Model->Predict(inIt.Get())[0]);
}
}
else
{
// else, set default value
outIt.Set(m_DefaultLabel);
confidenceIndex = 0.0;
}
if (m_UseConfidenceMap)
{
confidenceIt.Set(confidenceIndex);
++confidenceIt;
}
progress.CompletedPixel();
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment