Commit f81b2169 authored by Jordi Inglada's avatar Jordi Inglada

ENH: Add class probability output for RF classifiers

A tag is added (as for confidence output) to all ML models.
Implementation only for Shark RF.
parent 63777ce1
......@@ -64,6 +64,7 @@ public:
typedef ClassificationFilterType::LabelType LabelType;
typedef otb::MachineLearningModelFactory<ValueType, LabelType> MachineLearningModelFactoryType;
typedef ClassificationFilterType::ConfidenceImageType ConfidenceImageType;
typedef ClassificationFilterType::ProbaImageType ProbaImageType;
protected:
......@@ -111,6 +112,7 @@ private:
SetDefaultParameterInt("nodatalabel", 0);
MandatoryOff("nodatalabel");
AddParameter(ParameterType_OutputImage, "out", "Output Image");
SetParameterDescription( "out", "Output image containing class labels");
SetDefaultOutputPixelType( "out", ImagePixelType_uint8);
......@@ -130,8 +132,15 @@ private:
SetDefaultOutputPixelType( "confmap", ImagePixelType_double);
MandatoryOff("confmap");
AddParameter(ParameterType_OutputImage,"probamap", "Probability map");
SetParameterDescription("probamap","");
SetDefaultOutputPixelType("probamap",ImagePixelType_uint16);
MandatoryOff("probamap");
AddRAMParameter();
AddParameter(ParameterType_Int, "classe", "number of output classes");
SetDefaultParameterInt("classe", 20);
// Doc example parameter settings
SetDocExampleParameterValue("in", "QB_1_ortho.tif");
SetDocExampleParameterValue("imstat", "EstimateImageStatisticsQB1.xml");
......@@ -174,7 +183,7 @@ private:
// Classify
m_ClassificationFilter = ClassificationFilterType::New();
m_ClassificationFilter->SetModel(m_Model);
m_ClassificationFilter->SetDefaultLabel(GetParameterInt("nodatalabel"));
// Normalize input image if asked
......@@ -209,9 +218,9 @@ private:
m_ClassificationFilter->SetInputMask(inMask);
}
SetParameterOutputImage<OutputImageType>("out", m_ClassificationFilter->GetOutput());
// output confidence map
if (IsParameterEnabled("confmap") && HasValue("confmap"))
{
......@@ -226,6 +235,21 @@ private:
this->DisableParameter("confmap");
}
}
if(IsParameterEnabled("probamap") && HasValue("probamap"))
{
m_ClassificationFilter->SetUseProbaMap(true);
if(m_Model->HasProbaIndex())
{
m_ClassificationFilter->SetNumberOfClasses(GetParameterInt("classe"));
SetParameterOutputImage<ProbaImageType>("probamap",m_ClassificationFilter->GetOutputProba());
}
else
{
otbAppLogWARNING("Probability map requested but the classifier doesn't support it!");
this->DisableParameter("probamap");
}
}
}
ClassificationFilterType::Pointer m_ClassificationFilter;
......
......@@ -115,6 +115,7 @@ set(sharkkm_parameters "")
set(ascii_comparison --compare-ascii ${EPSILON_6})
set(raster_comparison --compare-image ${NOTOL})
set(raster_comparison_two --compare-n-images ${NOTOL} 2)
set(raster_comparison_three --compare-n-images ${NOTOL} 3)
# Reference ffiles depending on modes
set(ascii_ref_path ${OTBAPP_BASELINE_FILES})
......@@ -136,6 +137,7 @@ if(OTB_USE_SHARK)
endif()
set(classifier_with_confmap "LIBSVM" "BOOST" "KNN" "ANN" "RF")
set(classifier_with_probamap "SHARKRF")
# This is a black list for classifier that can not have a baseline
# because they are using randomness and seed can not be set
......@@ -156,6 +158,7 @@ foreach(classifier ${classifierList})
set(OUTRASTER cl${classifier}LabeledImageQB1${raster_output_format})
set(OUTCONFMAP cl${classifier}ConfidenceMapQB1${raster_output_format})
set(OUTPROBAMAP cl${classifier}ProbabilityMapQB1${raster_output_format})
list(FIND classifier_without_baseline ${classifier} _classifier_has_baseline)
if(${_classifier_has_baseline} EQUAL -1)
......@@ -198,6 +201,7 @@ foreach(classifier ${classifierList})
set_tests_properties(apTvClTrainMethod${classifier}ImagesClassifierQB1_InXML1 PROPERTIES DEPENDS apTvClTrainMethod${classifier}ImagesClassifierQB1)
list(FIND classifier_with_confmap ${classifier} _classifier_has_confmap)
list(FIND classifier_with_probamap ${classifier} _classifier_has_probamap)
if(${_classifier_has_confmap} EQUAL -1)
otb_test_application(
NAME apTvClMethod${classifier}ImageClassifierQB1
......@@ -212,21 +216,44 @@ foreach(classifier ${classifierList})
${TEMP}/${OUTRASTER}
)
else()
otb_test_application(
NAME apTvClMethod${classifier}ImageClassifierQB1
APP ImageClassifier
OPTIONS -in ${INPUTDATA}/Classification/QB_1_ortho${raster_input_format}
-model ${INPUTDATA}/Classification/${OUTMODELFILE}
-imstat ${INPUTDATA}/Classification/clImageStatisticsQB1${stat_input_format}
-out ${TEMP}/${OUTRASTER} ${raster_output_option}
-confmap ${TEMP}/${OUTCONFMAP}
VALID ${raster_comparison_two}
${raster_ref_path}/${OUTRASTER}
${TEMP}/${OUTRASTER}
${raster_ref_path}/${OUTCONFMAP}
${TEMP}/${OUTCONFMAP}
)
if(${_classifier_has_probamap} EQUAL -1)
otb_test_application(
NAME apTvClMethod${classifier}ImageClassifierQB1
APP ImageClassifier
OPTIONS -in ${INPUTDATA}/Classification/QB_1_ortho${raster_input_format}
-model ${INPUTDATA}/Classification/${OUTMODELFILE}
-imstat ${INPUTDATA}/Classification/clImageStatisticsQB1${stat_input_format}
-out ${TEMP}/${OUTRASTER} ${raster_output_option}
-confmap ${TEMP}/${OUTCONFMAP}
VALID ${raster_comparison_two}
${raster_ref_path}/${OUTRASTER}
${TEMP}/${OUTRASTER}
${raster_ref_path}/${OUTCONFMAP}
${TEMP}/${OUTCONFMAP}
)
else()
message(${classifier})
otb_test_application(
NAME apTvClMethod${classifier}ImageClassifierQB1
APP ImageClassifier
OPTIONS -in ${INPUTDATA}/Classification/QB_1_ortho${raster_input_format}
-model ${INPUTDATA}/Classification/${OUTMODELFILE}
-imstat ${INPUTDATA}/Classification/clImageStatisticsQB1${stat_input_format}
-out ${TEMP}/${OUTRASTER} ${raster_output_option}
-confmap ${TEMP}/${OUTCONFMAP}
-classe 4
-probamap ${TEMP}/${OUTPROBAMAP}
VALID ${raster_comparison_three}
${raster_ref_path}/${OUTRASTER}
${TEMP}/${OUTRASTER}
${raster_ref_path}/${OUTCONFMAP}
${TEMP}/${OUTCONFMAP}
${raster_ref_path}/${OUTPROBAMAP}
${TEMP}/${OUTPROBAMAP}
)
endif()
endif()
endforeach()
......
......@@ -82,6 +82,8 @@ public:
typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType;
typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
typedef typename Superclass::ProbaSampleType ProbaSampleType;
typedef typename Superclass::ProbaListSampleType ProbaListSampleType;
/// Neural network related typedefs
typedef shark::ConcatenatedModel<shark::RealVector> ModelType;
typedef shark::LinearModel<shark::RealVector,NeuronType> LayerType;
......@@ -160,14 +162,16 @@ protected:
virtual TargetSampleType DoPredict(
const InputSampleType& input,
ConfidenceValueType * quality = nullptr) const override;
ConfidenceValueType * quality = nullptr,
ProbaSampleType * proba = nullptr) const override;
virtual void DoPredictBatch(
const InputListSampleType *,
const unsigned int & startIndex,
const unsigned int & size,
TargetListSampleType *,
ConfidenceListSampleType * quality = nullptr) const override;
ConfidenceListSampleType * quality = nullptr,
ProbaListSampleType * proba = nullptr) const override;
private:
/** Internal Network */
......
......@@ -375,7 +375,7 @@ AutoencoderModel<TInputValue,NeuronType>
template <class TInputValue, class NeuronType>
typename AutoencoderModel<TInputValue,NeuronType>::TargetSampleType
AutoencoderModel<TInputValue,NeuronType>
::DoPredict(const InputSampleType & value, ConfidenceValueType * /*quality*/) const
::DoPredict(const InputSampleType & value, ConfidenceValueType * /*quality*/, ProbaSampleType * /*proba*/) const
{
shark::RealVector samples(value.Size());
for(size_t i = 0; i < value.Size();i++)
......@@ -408,7 +408,8 @@ AutoencoderModel<TInputValue,NeuronType>
const unsigned int & startIndex,
const unsigned int & size,
TargetListSampleType * targets,
ConfidenceListSampleType * /*quality*/) const
ConfidenceListSampleType * /*quality*/,
ProbaListSampleType * /*proba*/) const
{
std::vector<shark::RealVector> features;
Shark::ListSampleRangeToSharkVector(input, features,startIndex,size);
......
......@@ -77,6 +77,9 @@ public:
typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType;
typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
typedef typename Superclass::ProbaSampleType ProbaSampleType;
typedef typename Superclass::ProbaListSampleType ProbaListSampleType;
itkNewMacro(Self);
itkTypeMacro(PCAModel, DimensionalityReductionModel);
......@@ -99,14 +102,16 @@ protected:
virtual TargetSampleType DoPredict(
const InputSampleType& input,
ConfidenceValueType * quality = nullptr) const override;
ConfidenceValueType * quality = nullptr,
ProbaSampleType * proba = nullptr) const override;
virtual void DoPredictBatch(
const InputListSampleType *,
const unsigned int & startIndex,
const unsigned int & size,
TargetListSampleType *,
ConfidenceListSampleType * quality = nullptr) const override;
ConfidenceListSampleType * quality = nullptr,
ProbaListSampleType* proba = nullptr) const override;
private:
shark::LinearModel<> m_Encoder;
......
......@@ -148,7 +148,7 @@ PCAModel<TInputValue>::Load(const std::string & filename, const std::string & /*
template <class TInputValue>
typename PCAModel<TInputValue>::TargetSampleType
PCAModel<TInputValue>::DoPredict(const InputSampleType & value, ConfidenceValueType * /*quality*/) const
PCAModel<TInputValue>::DoPredict(const InputSampleType & value, ConfidenceValueType * /*quality*/, ProbaSampleType * /*proba*/) const
{
shark::RealVector samples(value.Size());
for(size_t i = 0; i < value.Size();i++)
......@@ -173,7 +173,7 @@ PCAModel<TInputValue>::DoPredict(const InputSampleType & value, ConfidenceValueT
template <class TInputValue>
void PCAModel<TInputValue>
::DoPredictBatch(const InputListSampleType *input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * targets, ConfidenceListSampleType * /*quality*/) const
::DoPredictBatch(const InputListSampleType *input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * targets, ConfidenceListSampleType * /*quality*/,ProbaListSampleType * /*proba*/) const
{
std::vector<shark::RealVector> features;
Shark::ListSampleRangeToSharkVector(input, features,startIndex,size);
......
......@@ -64,7 +64,8 @@ public:
typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType;
typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
typedef typename Superclass::ProbaSampleType ProbaSampleType;
typedef typename Superclass::ProbaListSampleType ProbaListSampleType;
typedef SOMMap<
itk::VariableLengthVector<TInputValue>,
itk::Statistics::EuclideanDistanceMetric<
......@@ -118,7 +119,8 @@ private:
virtual TargetSampleType DoPredict(
const InputSampleType& input,
ConfidenceValueType * quality = nullptr) const override;
ConfidenceValueType * quality = nullptr,
ProbaSampleType * proba = nullptr) const override;
/** Map size (width, height) */
SizeType m_MapSize;
......
......@@ -212,7 +212,8 @@ template <class TInputValue, unsigned int MapDimension>
typename SOMModel<TInputValue, MapDimension>::TargetSampleType
SOMModel<TInputValue, MapDimension>::DoPredict(
const InputSampleType & value,
ConfidenceValueType * /*quality*/) const
ConfidenceValueType * /*quality*/,
ProbaSampleType * /*proba*/) const
{
TargetSampleType target;
target.SetSize(this->m_Dimension);
......
......@@ -24,6 +24,7 @@
#include "itkImageToImageFilter.h"
#include "otbMachineLearningModel.h"
#include "otbImage.h"
#include "otbVectorImage.h"
namespace otb
{
......@@ -75,6 +76,11 @@ public:
typedef otb::Image<double> ConfidenceImageType;
typedef typename ConfidenceImageType::Pointer ConfidenceImagePointerType;
/**Output type for Proba */
typedef otb::VectorImage<double> ProbaImageType;
typedef typename ProbaImageType::Pointer ProbaImagePointerType;
typedef itk::VariableLengthVector<double> ProbaSampleType;
/** Set/Get the model */
itkSetObjectMacro(Model, ModelType);
itkGetObjectMacro(Model, ModelType);
......@@ -87,10 +93,16 @@ public:
itkSetMacro(UseConfidenceMap, bool);
itkGetMacro(UseConfidenceMap, bool);
/** Set/Get the proba map flag */
itkSetMacro(UseProbaMap, bool);
itkGetMacro(UseProbaMap, bool);
itkSetMacro(BatchMode, bool);
itkGetMacro(BatchMode, bool);
itkBooleanMacro(BatchMode);
itkSetMacro(NumberOfClasses, unsigned int);
itkGetMacro(NumberOfClasses, unsigned int);
/**
* If set, only pixels within the mask will be classified.
* All pixels with a value greater than 0 in the mask, will be classified.
......@@ -108,7 +120,7 @@ public:
* Get the output confidence map
*/
ConfidenceImageType * GetOutputConfidence(void);
ProbaImageType * GetOutputProba(void);
protected:
/** Constructor */
ImageClassificationFilter();
......@@ -124,6 +136,12 @@ protected:
/**PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const override;
void GenerateOutputInformation() override
{
Superclass::GenerateOutputInformation();
// Define the number of output bands
this->GetOutputProba()->SetNumberOfComponentsPerPixel(m_NumberOfClasses);
}
private:
ImageClassificationFilter(const Self &) = delete;
void operator =(const Self&) = delete;
......@@ -134,7 +152,9 @@ private:
LabelType m_DefaultLabel;
/** Flag to produce the confidence map (if the model supports it) */
bool m_UseConfidenceMap;
bool m_UseProbaMap;
bool m_BatchMode;
unsigned int m_NumberOfClasses;
};
} // End namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
......
......@@ -38,11 +38,14 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
this->SetNumberOfRequiredInputs(1);
m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue();
this->SetNumberOfRequiredOutputs(2);
this->SetNumberOfRequiredOutputs(3);
this->SetNthOutput(0,TOutputImage::New());
this->SetNthOutput(1,ConfidenceImageType::New());
this->SetNthOutput(2,ProbaImageType::New());
m_UseConfidenceMap = false;
m_UseProbaMap = false;
m_BatchMode = true;
m_NumberOfClasses = 1;
}
template <class TInputImage, class TOutputImage, class TMaskImage>
......@@ -79,6 +82,20 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
return static_cast<ConfidenceImageType *>(this->itk::ProcessObject::GetOutput(1));
}
template <class TInputImage, class TOutputImage, class TMaskImage>
typename ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
::ProbaImageType *
ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
::GetOutputProba()
{
//std::cout << "Getoutprob" << std::endl;
if (this->GetNumberOfOutputs() < 2)
{
return ITK_NULLPTR;
}
return static_cast<ProbaImageType *>(this->itk::ProcessObject::GetOutput(2));
}
template <class TInputImage, class TOutputImage, class TMaskImage>
void
ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
......@@ -102,12 +119,12 @@ void
ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
::ClassicThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId)
{
// Get the input pointers
// Get the input pointers
InputImageConstPointerType inputPtr = this->GetInput();
MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
OutputImagePointerType outputPtr = this->GetOutput();
ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence();
ProbaImagePointerType probaPtr = this->GetOutputProba();
// Progress reporting
itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
......@@ -116,7 +133,8 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType;
typedef itk::ImageRegionIterator<ProbaImageType> ProbaMapIteratorType;
InputIteratorType inIt(inputPtr, outputRegionForThread);
OutputIteratorType outIt(outputPtr, outputRegionForThread);
......@@ -137,9 +155,20 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
confidenceIt.GoToBegin();
}
// setup iterator for proba map
bool computeProbaMap(m_UseProbaMap && m_Model->HasProbaIndex() && !m_Model->GetRegressionMode());
ProbaMapIteratorType probaIt;
if(computeProbaMap)
{
probaIt = ProbaMapIteratorType(probaPtr,outputRegionForThread);
probaIt.GoToBegin();
}
bool validPoint = true;
double confidenceIndex = 0.0;
ProbaSampleType probaVector{m_NumberOfClasses};
// Walk the part of the image
for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd(); ++inIt, ++outIt)
{
......@@ -153,7 +182,12 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
if (validPoint)
{
// Classifify
if (computeConfidenceMap)
if (computeProbaMap)
{
outIt.Set(m_Model->Predict(inIt.Get(),&confidenceIndex,&probaVector)[0]);
}
else if (computeConfidenceMap)
{
outIt.Set(m_Model->Predict(inIt.Get(),&confidenceIndex)[0]);
}
......@@ -173,6 +207,18 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
confidenceIt.Set(confidenceIndex);
++confidenceIt;
}
if (computeProbaMap)
{
ProbaImageType::PixelType probVect{probaVector.Size()};
probVect.Fill(0.0);
for (size_t t =0; t < probaVector.Size();t++)
{
probVect[t] = probaVector[t];
std::cout << probVect[t] << '\t' << probaVector[t] << '\n';
}
probaIt.Set(probVect);
++probaIt;
}
progress.CompletedPixel();
}
......@@ -183,14 +229,19 @@ void
ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
::BatchThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId)
{
//std::cout << "batch mode" << std::endl;
bool computeConfidenceMap(m_UseConfidenceMap && m_Model->HasConfidenceIndex()
&& !m_Model->GetRegressionMode());
bool computeProbaMap(m_UseProbaMap && m_Model->HasProbaIndex()
&& !m_Model->GetRegressionMode());
// Get the input pointers
InputImageConstPointerType inputPtr = this->GetInput();
MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
OutputImagePointerType outputPtr = this->GetOutput();
ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence();
ProbaImagePointerType probaPtr = this->GetOutputProba();
// Progress reporting
itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
......@@ -199,6 +250,7 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType;
typedef itk::ImageRegionIterator<ProbaImageType> ProbaMapIteratorType;
InputIteratorType inIt(inputPtr, outputRegionForThread);
OutputIteratorType outIt(outputPtr, outputRegionForThread);
......@@ -219,7 +271,7 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
// typedef typename ModelType::ConfidenceValueType ConfidenceValueType;
// typedef typename ModelType::ConfidenceSampleType ConfidenceSampleType;
typedef typename ModelType::ConfidenceListSampleType ConfidenceListSampleType;
typedef typename ModelType::ProbaListSampleType ProbaListSampleType;
typename InputListSampleType::Pointer samples = InputListSampleType::New();
unsigned int num_features = inputPtr->GetNumberOfComponentsPerPixel();
samples->SetMeasurementVectorSize(num_features);
......@@ -247,11 +299,14 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
//Make the batch prediction
typename TargetListSampleType::Pointer labels;
typename ConfidenceListSampleType::Pointer confidences;
typename ProbaListSampleType::Pointer probas;
if(computeConfidenceMap)
confidences = ConfidenceListSampleType::New();
if(computeProbaMap)
probas = ProbaListSampleType::New();
// This call is threadsafe
labels = m_Model->PredictBatch(samples,confidences);
labels = m_Model->PredictBatch(samples,confidences,probas);
// Set the output values
ConfidenceMapIteratorType confidenceIt;
......@@ -261,12 +316,19 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
confidenceIt.GoToBegin();
}
ProbaMapIteratorType probaIt;
if (computeProbaMap)
{
probaIt = ProbaMapIteratorType(probaPtr,outputRegionForThread);
probaIt.GoToBegin();
}
typename TargetListSampleType::ConstIterator labIt = labels->Begin();
maskIt.GoToBegin();
for (outIt.GoToBegin(); !outIt.IsAtEnd(); ++outIt)
{
double confidenceIndex = 0.0;
TargetValueType labelValue(m_DefaultLabel);
ProbaSampleType probaValues{m_NumberOfClasses};
if (inputMaskPtr)
{
validPoint = maskIt.Get() > 0;
......@@ -280,6 +342,13 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
{
confidenceIndex = confidences->GetMeasurementVector(labIt.GetInstanceIdentifier())[0];
}
if(computeProbaMap)
{
//std::cout << "imagecfilter before get" << std::endl;
probaValues = probas->GetMeasurementVector(labIt.GetInstanceIdentifier());
//std::cout << "imageclfilt after get" << std::endl;
}
++labIt;
}
......@@ -295,7 +364,11 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
confidenceIt.Set(confidenceIndex);
++confidenceIt;
}
if(computeProbaMap)
{
probaIt.Set(probaValues);
++probaIt;
}
progress.CompletedPixel();
}
}
......
......@@ -99,6 +99,9 @@ public:
typedef typename MLMTargetTraits<TConfidenceValue>::SampleType ConfidenceSampleType;
typedef itk::Statistics::ListSample<ConfidenceSampleType> ConfidenceListSampleType;
typedef itk::VariableLengthVector<double> ProbaSampleType;
typedef itk::Statistics::ListSample<ProbaSampleType> ProbaListSampleType;
/**\name Standard macros */
//@{
/** Run-time type information (and related methods). */
......@@ -114,7 +117,7 @@ public:
* quality value, or NULL
* \return The predicted label
*/
TargetSampleType Predict(const InputSampleType& input, ConfidenceValueType *quality = nullptr) const;
TargetSampleType Predict(const InputSampleType& input, ConfidenceValueType *quality = nullptr, ProbaSampleType *proba = nullptr) const;
/**\name Set and get the dimension of the model for dimensionality reduction models */
//@{
......@@ -130,7 +133,7 @@ public:
* Note that this method will be multi-threaded if OTB is built
* with OpenMP.
*/
typename TargetListSampleType::Pointer PredictBatch(const InputListSampleType * input, ConfidenceListSampleType * quality = nullptr) const;
typename TargetListSampleType::Pointer PredictBatch(const InputListSampleType * input, ConfidenceListSampleType * quality = nullptr, ProbaListSampleType * proba = nullptr) const;
/**\name Classification model file manipulation */
//@{
......@@ -152,6 +155,8 @@ public:
/** Query capacity to produce a confidence index */
bool HasConfidenceIndex() const {return m_ConfidenceIndex;}
/** Query capacity to produce probability values */
bool HasProbaIndex() const {return m_ProbaIndex;}
/**\name Input list of samples accessors */
//@{
......@@ -208,6 +213,9 @@ protected:
/** flag that tells if the model support confidence index output */
bool m_ConfidenceIndex;
/** flag that tells if the model support probability output */
bool m_ProbaIndex;
/** Is DoPredictBatch multi-threaded ? */
bool m_IsDoPredictBatchMultiThreaded;
......@@ -230,7 +238,7 @@ private:
* Also set m_IsDoPredictBatchMultiThreaded to true if internal
* implementation allows for parallel batch prediction.
*/
virtual void DoPredictBatch(const InputListSampleType * input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * target, ConfidenceListSampleType * quality = nullptr) const;
virtual void DoPredictBatch(const InputListSampleType * input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * target, ConfidenceListSampleType * quality = nullptr, ProbaListSampleType * proba = nullptr) const;
/** Actual implementation of single sample prediction
* \param input sample to predict
......@@ -238,7 +246,7 @@ private:
* or NULL
* \return The predicted label
*/
virtual TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType * quality= nullptr) const = 0;
virtual TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType * quality= nullptr, ProbaSampleType *proba=nullptr) const = 0;
MachineLearningModel(const Self &) = delete;
void operator =(const Self&) = delete;
......
......@@ -38,6 +38,7 @@ MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue>
m_RegressionMode(false),
m_IsRegressionSupported(false),
m_ConfidenceIndex(false),
m_ProbaIndex(false),
m_IsDoPredictBatchMultiThreaded(false),
m_Dimension(0)
{}
......@@ -68,10 +69,10 @@ template <class TInputValue, class TOutputValue, class TConfidenceValue>
typename MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue>
::TargetSampleType
MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue>
::Predict(const InputSampleType& input, ConfidenceValueType *quality) const
::Predict(const InputSampleType& input, ConfidenceValueType *quality, ProbaSampleType *proba) const
{
// Call protected specialization entry point
return this->DoPredict(input,quality);
return this->DoPredict(input,quality,proba);
}
......@@ -79,8 +80,9 @@ template <class TInputValue, class TOutputValue, class TConfidenceValue>
typename MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue>
::TargetListSampleType::Pointer
MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue>
::PredictBatch(const InputListSampleType * input, ConfidenceListSampleType * quality) const
::PredictBatch(const InputListSampleType * input, ConfidenceListSampleType * quality, ProbaListSampleType *proba) const
{
//std::cout << "Enter batch predict" << std::endl;
typename TargetListSampleType::Pointer targets = TargetListSampleType::New();
targets->Resize(input->Size());
......@@ -89,16 +91,19 @@ MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue>
quality->Clear();
quality->Resize(input->Size());
}
if(proba!=ITK_NULLPTR)
{
proba->Clear();
proba->Resize(input->Size());
}
if(m_IsDoPredictBatchMultiThreaded)
{
// Simply calls DoPredictBatch
this->DoPredictBatch(input,0,input->Size(),