Commit e051ac1f authored by Julien Michel's avatar Julien Michel

Merge branch 'classifier_probability_output' into 'develop'

Add class probability output for RF classifiers

See merge request !286
parents 61048143 3c9831ae
......@@ -64,6 +64,7 @@ public:
typedef ClassificationFilterType::LabelType LabelType;
typedef otb::MachineLearningModelFactory<ValueType, LabelType> MachineLearningModelFactoryType;
typedef ClassificationFilterType::ConfidenceImageType ConfidenceImageType;
typedef ClassificationFilterType::ProbaImageType ProbaImageType;
protected:
......@@ -111,6 +112,7 @@ private:
SetDefaultParameterInt("nodatalabel", 0);
MandatoryOff("nodatalabel");
AddParameter(ParameterType_OutputImage, "out", "Output Image");
SetParameterDescription( "out", "Output image containing class labels");
SetDefaultOutputPixelType( "out", ImagePixelType_uint8);
......@@ -129,8 +131,16 @@ private:
SetDefaultOutputPixelType( "confmap", ImagePixelType_double);
MandatoryOff("confmap");
AddParameter(ParameterType_OutputImage,"probamap", "Probability map");
SetParameterDescription("probamap","Probability of each class for each pixel. This is an image having a number of bands equal to the number of classes in the model. This is only implemented for the Shark Random Forest classifier at this point.");
SetDefaultOutputPixelType("probamap",ImagePixelType_uint16);
MandatoryOff("probamap");
AddRAMParameter();
AddParameter(ParameterType_Int, "nbclasses", "Number of classes in the model");
SetDefaultParameterInt("nbclasses", 20);
SetParameterDescription("nbclasses","The number of classes is needed for the probamap output in order to set the number of output bands.");
// Doc example parameter settings
SetDocExampleParameterValue("in", "QB_1_ortho.tif");
SetDocExampleParameterValue("imstat", "EstimateImageStatisticsQB1.xml");
......@@ -173,7 +183,7 @@ private:
// Classify
m_ClassificationFilter = ClassificationFilterType::New();
m_ClassificationFilter->SetModel(m_Model);
m_ClassificationFilter->SetDefaultLabel(GetParameterInt("nodatalabel"));
// Normalize input image if asked
......@@ -208,9 +218,9 @@ private:
m_ClassificationFilter->SetInputMask(inMask);
}
SetParameterOutputImage<OutputImageType>("out", m_ClassificationFilter->GetOutput());
// output confidence map
if (IsParameterEnabled("confmap") && HasValue("confmap"))
{
......@@ -225,6 +235,21 @@ private:
this->DisableParameter("confmap");
}
}
if(IsParameterEnabled("probamap") && HasValue("probamap"))
{
m_ClassificationFilter->SetUseProbaMap(true);
if(m_Model->HasProbaIndex())
{
m_ClassificationFilter->SetNumberOfClasses(GetParameterInt("nbclasses"));
SetParameterOutputImage<ProbaImageType>("probamap",m_ClassificationFilter->GetOutputProba());
}
else
{
otbAppLogWARNING("Probability map requested but the classifier doesn't support it!");
this->DisableParameter("probamap");
}
}
}
ClassificationFilterType::Pointer m_ClassificationFilter;
......
......@@ -115,6 +115,7 @@ set(sharkkm_parameters "")
set(ascii_comparison --compare-ascii ${EPSILON_6})
set(raster_comparison --compare-image ${NOTOL})
set(raster_comparison_two --compare-n-images ${NOTOL} 2)
set(raster_comparison_three --compare-n-images ${NOTOL} 3)
# Reference ffiles depending on modes
set(ascii_ref_path ${OTBAPP_BASELINE_FILES})
......@@ -136,6 +137,7 @@ if(OTB_USE_SHARK)
endif()
set(classifier_with_confmap "LIBSVM" "BOOST" "KNN" "ANN" "RF")
set(classifier_with_probamap "SHARKRF")
# This is a black list for classifier that can not have a baseline
# because they are using randomness and seed can not be set
......@@ -156,6 +158,7 @@ foreach(classifier ${classifierList})
set(OUTRASTER cl${classifier}LabeledImageQB1${raster_output_format})
set(OUTCONFMAP cl${classifier}ConfidenceMapQB1${raster_output_format})
set(OUTPROBAMAP cl${classifier}ProbabilityMapQB1${raster_output_format})
list(FIND classifier_without_baseline ${classifier} _classifier_has_baseline)
if(${_classifier_has_baseline} EQUAL -1)
......@@ -198,6 +201,7 @@ foreach(classifier ${classifierList})
set_tests_properties(apTvClTrainMethod${classifier}ImagesClassifierQB1_InXML1 PROPERTIES DEPENDS apTvClTrainMethod${classifier}ImagesClassifierQB1)
list(FIND classifier_with_confmap ${classifier} _classifier_has_confmap)
list(FIND classifier_with_probamap ${classifier} _classifier_has_probamap)
if(${_classifier_has_confmap} EQUAL -1)
otb_test_application(
NAME apTvClMethod${classifier}ImageClassifierQB1
......@@ -212,21 +216,44 @@ foreach(classifier ${classifierList})
${TEMP}/${OUTRASTER}
)
else()
otb_test_application(
NAME apTvClMethod${classifier}ImageClassifierQB1
APP ImageClassifier
OPTIONS -in ${INPUTDATA}/Classification/QB_1_ortho${raster_input_format}
-model ${INPUTDATA}/Classification/${OUTMODELFILE}
-imstat ${INPUTDATA}/Classification/clImageStatisticsQB1${stat_input_format}
-out ${TEMP}/${OUTRASTER} ${raster_output_option}
-confmap ${TEMP}/${OUTCONFMAP}
VALID ${raster_comparison_two}
${raster_ref_path}/${OUTRASTER}
${TEMP}/${OUTRASTER}
${raster_ref_path}/${OUTCONFMAP}
${TEMP}/${OUTCONFMAP}
)
if(${_classifier_has_probamap} EQUAL -1)
otb_test_application(
NAME apTvClMethod${classifier}ImageClassifierQB1
APP ImageClassifier
OPTIONS -in ${INPUTDATA}/Classification/QB_1_ortho${raster_input_format}
-model ${INPUTDATA}/Classification/${OUTMODELFILE}
-imstat ${INPUTDATA}/Classification/clImageStatisticsQB1${stat_input_format}
-out ${TEMP}/${OUTRASTER} ${raster_output_option}
-confmap ${TEMP}/${OUTCONFMAP}
VALID ${raster_comparison_two}
${raster_ref_path}/${OUTRASTER}
${TEMP}/${OUTRASTER}
${raster_ref_path}/${OUTCONFMAP}
${TEMP}/${OUTCONFMAP}
)
else()
message(${classifier})
otb_test_application(
NAME apTvClMethod${classifier}ImageClassifierQB1
APP ImageClassifier
OPTIONS -in ${INPUTDATA}/Classification/QB_1_ortho${raster_input_format}
-model ${INPUTDATA}/Classification/${OUTMODELFILE}
-imstat ${INPUTDATA}/Classification/clImageStatisticsQB1${stat_input_format}
-out ${TEMP}/${OUTRASTER} ${raster_output_option}
-confmap ${TEMP}/${OUTCONFMAP}
-nbclasses 4
-probamap ${TEMP}/${OUTPROBAMAP}
VALID ${raster_comparison_three}
${raster_ref_path}/${OUTRASTER}
${TEMP}/${OUTRASTER}
${raster_ref_path}/${OUTCONFMAP}
${TEMP}/${OUTCONFMAP}
${raster_ref_path}/${OUTPROBAMAP}
${TEMP}/${OUTPROBAMAP}
)
endif()
endif()
endforeach()
......
......@@ -84,6 +84,8 @@ public:
typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType;
typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
typedef typename Superclass::ProbaSampleType ProbaSampleType;
typedef typename Superclass::ProbaListSampleType ProbaListSampleType;
/// Neural network related typedefs
typedef shark::ConcatenatedModel<shark::RealVector> ModelType;
typedef shark::LinearModel<shark::RealVector,NeuronType> LayerType;
......@@ -162,14 +164,16 @@ protected:
virtual TargetSampleType DoPredict(
const InputSampleType& input,
ConfidenceValueType * quality = nullptr) const override;
ConfidenceValueType * quality = nullptr,
ProbaSampleType * proba = nullptr) const override;
virtual void DoPredictBatch(
const InputListSampleType *,
const unsigned int & startIndex,
const unsigned int & size,
TargetListSampleType *,
ConfidenceListSampleType * quality = nullptr) const override;
ConfidenceListSampleType * quality = nullptr,
ProbaListSampleType * proba = nullptr) const override;
private:
/** Internal Network */
......
......@@ -375,7 +375,7 @@ AutoencoderModel<TInputValue,NeuronType>
template <class TInputValue, class NeuronType>
typename AutoencoderModel<TInputValue,NeuronType>::TargetSampleType
AutoencoderModel<TInputValue,NeuronType>
::DoPredict(const InputSampleType & value, ConfidenceValueType * /*quality*/) const
::DoPredict(const InputSampleType & value, ConfidenceValueType * /*quality*/, ProbaSampleType * /*proba*/) const
{
shark::RealVector samples(value.Size());
for(size_t i = 0; i < value.Size();i++)
......@@ -408,7 +408,8 @@ AutoencoderModel<TInputValue,NeuronType>
const unsigned int & startIndex,
const unsigned int & size,
TargetListSampleType * targets,
ConfidenceListSampleType * /*quality*/) const
ConfidenceListSampleType * /*quality*/,
ProbaListSampleType * /*proba*/) const
{
std::vector<shark::RealVector> features;
Shark::ListSampleRangeToSharkVector(input, features,startIndex,size);
......
......@@ -79,6 +79,9 @@ public:
typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType;
typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
typedef typename Superclass::ProbaSampleType ProbaSampleType;
typedef typename Superclass::ProbaListSampleType ProbaListSampleType;
itkNewMacro(Self);
itkTypeMacro(PCAModel, DimensionalityReductionModel);
......@@ -101,14 +104,16 @@ protected:
virtual TargetSampleType DoPredict(
const InputSampleType& input,
ConfidenceValueType * quality = nullptr) const override;
ConfidenceValueType * quality = nullptr,
ProbaSampleType * proba = nullptr) const override;
virtual void DoPredictBatch(
const InputListSampleType *,
const unsigned int & startIndex,
const unsigned int & size,
TargetListSampleType *,
ConfidenceListSampleType * quality = nullptr) const override;
ConfidenceListSampleType * quality = nullptr,
ProbaListSampleType* proba = nullptr) const override;
private:
shark::LinearModel<> m_Encoder;
......
......@@ -148,7 +148,7 @@ PCAModel<TInputValue>::Load(const std::string & filename, const std::string & /*
template <class TInputValue>
typename PCAModel<TInputValue>::TargetSampleType
PCAModel<TInputValue>::DoPredict(const InputSampleType & value, ConfidenceValueType * /*quality*/) const
PCAModel<TInputValue>::DoPredict(const InputSampleType & value, ConfidenceValueType * /*quality*/, ProbaSampleType * /*proba*/) const
{
shark::RealVector samples(value.Size());
for(size_t i = 0; i < value.Size();i++)
......@@ -173,7 +173,7 @@ PCAModel<TInputValue>::DoPredict(const InputSampleType & value, ConfidenceValueT
template <class TInputValue>
void PCAModel<TInputValue>
::DoPredictBatch(const InputListSampleType *input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * targets, ConfidenceListSampleType * /*quality*/) const
::DoPredictBatch(const InputListSampleType *input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * targets, ConfidenceListSampleType * /*quality*/,ProbaListSampleType * /*proba*/) const
{
std::vector<shark::RealVector> features;
Shark::ListSampleRangeToSharkVector(input, features,startIndex,size);
......
......@@ -64,7 +64,8 @@ public:
typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType;
typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
typedef typename Superclass::ProbaSampleType ProbaSampleType;
typedef typename Superclass::ProbaListSampleType ProbaListSampleType;
typedef SOMMap<
itk::VariableLengthVector<TInputValue>,
itk::Statistics::EuclideanDistanceMetric<
......@@ -118,7 +119,8 @@ private:
virtual TargetSampleType DoPredict(
const InputSampleType& input,
ConfidenceValueType * quality = nullptr) const override;
ConfidenceValueType * quality = nullptr,
ProbaSampleType * proba = nullptr) const override;
/** Map size (width, height) */
SizeType m_MapSize;
......
......@@ -212,7 +212,8 @@ template <class TInputValue, unsigned int MapDimension>
typename SOMModel<TInputValue, MapDimension>::TargetSampleType
SOMModel<TInputValue, MapDimension>::DoPredict(
const InputSampleType & value,
ConfidenceValueType * /*quality*/) const
ConfidenceValueType * /*quality*/,
ProbaSampleType * /*proba*/) const
{
TargetSampleType target;
target.SetSize(this->m_Dimension);
......
......@@ -24,6 +24,7 @@
#include "itkImageToImageFilter.h"
#include "otbMachineLearningModel.h"
#include "otbImage.h"
#include "otbVectorImage.h"
namespace otb
{
......@@ -75,6 +76,11 @@ public:
typedef otb::Image<double> ConfidenceImageType;
typedef typename ConfidenceImageType::Pointer ConfidenceImagePointerType;
/**Output type for Proba */
typedef otb::VectorImage<double> ProbaImageType;
typedef typename ProbaImageType::Pointer ProbaImagePointerType;
typedef itk::VariableLengthVector<double> ProbaSampleType;
/** Set/Get the model */
itkSetObjectMacro(Model, ModelType);
itkGetObjectMacro(Model, ModelType);
......@@ -87,10 +93,16 @@ public:
itkSetMacro(UseConfidenceMap, bool);
itkGetMacro(UseConfidenceMap, bool);
/** Set/Get the proba map flag */
itkSetMacro(UseProbaMap, bool);
itkGetMacro(UseProbaMap, bool);
itkSetMacro(BatchMode, bool);
itkGetMacro(BatchMode, bool);
itkBooleanMacro(BatchMode);
itkSetMacro(NumberOfClasses, unsigned int);
itkGetMacro(NumberOfClasses, unsigned int);
/**
* If set, only pixels within the mask will be classified.
* All pixels with a value greater than 0 in the mask, will be classified.
......@@ -108,7 +120,7 @@ public:
* Get the output confidence map
*/
ConfidenceImageType * GetOutputConfidence(void);
ProbaImageType * GetOutputProba(void);
protected:
/** Constructor */
ImageClassificationFilter();
......@@ -124,6 +136,12 @@ protected:
/**PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const override;
void GenerateOutputInformation() override
{
Superclass::GenerateOutputInformation();
// Define the number of output bands
this->GetOutputProba()->SetNumberOfComponentsPerPixel(m_NumberOfClasses);
}
private:
ImageClassificationFilter(const Self &) = delete;
void operator =(const Self&) = delete;
......@@ -134,7 +152,9 @@ private:
LabelType m_DefaultLabel;
/** Flag to produce the confidence map (if the model supports it) */
bool m_UseConfidenceMap;
bool m_UseProbaMap;
bool m_BatchMode;
unsigned int m_NumberOfClasses;
};
} // End namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
......
......@@ -38,11 +38,14 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
this->SetNumberOfRequiredInputs(1);
m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue();
this->SetNumberOfRequiredOutputs(2);
this->SetNumberOfRequiredOutputs(3);
this->SetNthOutput(0,TOutputImage::New());
this->SetNthOutput(1,ConfidenceImageType::New());
this->SetNthOutput(2,ProbaImageType::New());
m_UseConfidenceMap = false;
m_UseProbaMap = false;
m_BatchMode = true;
m_NumberOfClasses = 1;
}
template <class TInputImage, class TOutputImage, class TMaskImage>
......@@ -79,6 +82,19 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
return static_cast<ConfidenceImageType *>(this->itk::ProcessObject::GetOutput(1));
}
template <class TInputImage, class TOutputImage, class TMaskImage>
typename ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
::ProbaImageType *
ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
::GetOutputProba()
{
if (this->GetNumberOfOutputs() < 2)
{
return nullptr;
}
return static_cast<ProbaImageType *>(this->itk::ProcessObject::GetOutput(2));
}
template <class TInputImage, class TOutputImage, class TMaskImage>
void
ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
......@@ -100,22 +116,25 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
template <class TInputImage, class TOutputImage, class TMaskImage>
void
ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
::ClassicThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId)
::ClassicThreadedGenerateData(const OutputImageRegionType& outputRegionForThread,
itk::ThreadIdType threadId)
{
// Get the input pointers
InputImageConstPointerType inputPtr = this->GetInput();
MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
OutputImagePointerType outputPtr = this->GetOutput();
ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence();
ProbaImagePointerType probaPtr = this->GetOutputProba();
// Progress reporting
itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
itk::ProgressReporter progress(this, threadId,
outputRegionForThread.GetNumberOfPixels());
// Define iterators
typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType;
typedef itk::ImageRegionIterator<ProbaImageType> ProbaMapIteratorType;
InputIteratorType inIt(inputPtr, outputRegionForThread);
OutputIteratorType outIt(outputPtr, outputRegionForThread);
......@@ -129,7 +148,8 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
}
// setup iterator for confidence map
bool computeConfidenceMap(m_UseConfidenceMap && m_Model->HasConfidenceIndex() && !m_Model->GetRegressionMode());
bool computeConfidenceMap(m_UseConfidenceMap && m_Model->HasConfidenceIndex() &&
!m_Model->GetRegressionMode());
ConfidenceMapIteratorType confidenceIt;
if (computeConfidenceMap)
{
......@@ -137,11 +157,25 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
confidenceIt.GoToBegin();
}
// setup iterator for proba map
bool computeProbaMap(m_UseProbaMap && m_Model->HasProbaIndex() &&
!m_Model->GetRegressionMode());
ProbaMapIteratorType probaIt;
if(computeProbaMap)
{
probaIt = ProbaMapIteratorType(probaPtr,outputRegionForThread);
probaIt.GoToBegin();
}
bool validPoint = true;
double confidenceIndex = 0.0;
ProbaSampleType probaVector{m_NumberOfClasses};
probaVector.Fill(0);
// Walk the part of the image
for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd(); ++inIt, ++outIt)
for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd();
++inIt, ++outIt)
{
// Check pixel validity
if (inputMaskPtr)
......@@ -151,17 +185,22 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
}
// If point is valid
if (validPoint)
{
{
// Classifify
if (computeConfidenceMap)
{
if (computeProbaMap)
{
outIt.Set(m_Model->Predict(inIt.Get(),&confidenceIndex,
&probaVector)[0]);
}
else if (computeConfidenceMap)
{
outIt.Set(m_Model->Predict(inIt.Get(),&confidenceIndex)[0]);
}
}
else
{
{
outIt.Set(m_Model->Predict(inIt.Get())[0]);
}
}
}
else
{
// else, set default value
......@@ -173,6 +212,11 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
confidenceIt.Set(confidenceIndex);
++confidenceIt;
}
if (computeProbaMap)
{
probaIt.Set(probaVector);
++probaIt;
}
progress.CompletedPixel();
}
......@@ -181,24 +225,31 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
template <class TInputImage, class TOutputImage, class TMaskImage>
void
ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
::BatchThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId)
::BatchThreadedGenerateData(const OutputImageRegionType& outputRegionForThread,
itk::ThreadIdType threadId)
{
bool computeConfidenceMap(m_UseConfidenceMap && m_Model->HasConfidenceIndex()
bool computeConfidenceMap(m_UseConfidenceMap && m_Model->HasConfidenceIndex()
&& !m_Model->GetRegressionMode());
bool computeProbaMap(m_UseProbaMap && m_Model->HasProbaIndex()
&& !m_Model->GetRegressionMode());
// Get the input pointers
InputImageConstPointerType inputPtr = this->GetInput();
MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
OutputImagePointerType outputPtr = this->GetOutput();
ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence();
ProbaImagePointerType probaPtr = this->GetOutputProba();
// Progress reporting
itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
itk::ProgressReporter progress(this, threadId,
outputRegionForThread.GetNumberOfPixels());
// Define iterators
typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType;
typedef itk::ImageRegionIterator<ProbaImageType> ProbaMapIteratorType;
InputIteratorType inIt(inputPtr, outputRegionForThread);
OutputIteratorType outIt(outputPtr, outputRegionForThread);
......@@ -210,16 +261,12 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
maskIt.GoToBegin();
}
// typedef typename ModelType::InputValueType InputValueType;
typedef typename ModelType::InputSampleType InputSampleType;
typedef typename ModelType::InputListSampleType InputListSampleType;
typedef typename ModelType::TargetValueType TargetValueType;
// typedef typename ModelType::TargetSampleType TargetSampleType;
typedef typename ModelType::TargetListSampleType TargetListSampleType;
// typedef typename ModelType::ConfidenceValueType ConfidenceValueType;
// typedef typename ModelType::ConfidenceSampleType ConfidenceSampleType;
typedef typename ModelType::ConfidenceListSampleType ConfidenceListSampleType;
typedef typename ModelType::ProbaListSampleType ProbaListSampleType;
typename InputListSampleType::Pointer samples = InputListSampleType::New();
unsigned int num_features = inputPtr->GetNumberOfComponentsPerPixel();
samples->SetMeasurementVectorSize(num_features);
......@@ -247,11 +294,14 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
//Make the batch prediction
typename TargetListSampleType::Pointer labels;
typename ConfidenceListSampleType::Pointer confidences;
typename ProbaListSampleType::Pointer probas;
if(computeConfidenceMap)
confidences = ConfidenceListSampleType::New();
if(computeProbaMap)
probas = ProbaListSampleType::New();
// This call is threadsafe
labels = m_Model->PredictBatch(samples,confidences);
labels = m_Model->PredictBatch(samples,confidences,probas);
// Set the output values
ConfidenceMapIteratorType confidenceIt;
......@@ -261,12 +311,19 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
confidenceIt.GoToBegin();
}
ProbaMapIteratorType probaIt;
if (computeProbaMap)
{
probaIt = ProbaMapIteratorType(probaPtr,outputRegionForThread);
probaIt.GoToBegin();
}
typename TargetListSampleType::ConstIterator labIt = labels->Begin();
maskIt.GoToBegin();
for (outIt.GoToBegin(); !outIt.IsAtEnd(); ++outIt)
{
double confidenceIndex = 0.0;
TargetValueType labelValue(m_DefaultLabel);
ProbaSampleType probaValues{m_NumberOfClasses};
if (inputMaskPtr)
{
validPoint = maskIt.Get() > 0;
......@@ -278,16 +335,26 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
if(computeConfidenceMap)
{
confidenceIndex = confidences->GetMeasurementVector(labIt.GetInstanceIdentifier())[0];
confidenceIndex =
confidences->GetMeasurementVector(labIt.GetInstanceIdentifier())[0];
}