Commit e051ac1f authored by Julien Michel's avatar Julien Michel

Merge branch 'classifier_probability_output' into 'develop'

Add class probability output for RF classifiers

See merge request !286
parents 61048143 3c9831ae
......@@ -64,6 +64,7 @@ public:
typedef ClassificationFilterType::LabelType LabelType;
typedef otb::MachineLearningModelFactory<ValueType, LabelType> MachineLearningModelFactoryType;
typedef ClassificationFilterType::ConfidenceImageType ConfidenceImageType;
typedef ClassificationFilterType::ProbaImageType ProbaImageType;
protected:
......@@ -111,6 +112,7 @@ private:
SetDefaultParameterInt("nodatalabel", 0);
MandatoryOff("nodatalabel");
AddParameter(ParameterType_OutputImage, "out", "Output Image");
SetParameterDescription( "out", "Output image containing class labels");
SetDefaultOutputPixelType( "out", ImagePixelType_uint8);
......@@ -129,8 +131,16 @@ private:
SetDefaultOutputPixelType( "confmap", ImagePixelType_double);
MandatoryOff("confmap");
AddParameter(ParameterType_OutputImage,"probamap", "Probability map");
SetParameterDescription("probamap","Probability of each class for each pixel. This is an image having a number of bands equal to the number of classes in the model. This is only implemented for the Shark Random Forest classifier at this point.");
SetDefaultOutputPixelType("probamap",ImagePixelType_uint16);
MandatoryOff("probamap");
AddRAMParameter();
AddParameter(ParameterType_Int, "nbclasses", "Number of classes in the model");
SetDefaultParameterInt("nbclasses", 20);
SetParameterDescription("nbclasses","The number of classes is needed for the probamap output in order to set the number of output bands.");
// Doc example parameter settings
SetDocExampleParameterValue("in", "QB_1_ortho.tif");
SetDocExampleParameterValue("imstat", "EstimateImageStatisticsQB1.xml");
......@@ -173,7 +183,7 @@ private:
// Classify
m_ClassificationFilter = ClassificationFilterType::New();
m_ClassificationFilter->SetModel(m_Model);
m_ClassificationFilter->SetDefaultLabel(GetParameterInt("nodatalabel"));
// Normalize input image if asked
......@@ -208,9 +218,9 @@ private:
m_ClassificationFilter->SetInputMask(inMask);
}
SetParameterOutputImage<OutputImageType>("out", m_ClassificationFilter->GetOutput());
// output confidence map
if (IsParameterEnabled("confmap") && HasValue("confmap"))
{
......@@ -225,6 +235,21 @@ private:
this->DisableParameter("confmap");
}
}
if(IsParameterEnabled("probamap") && HasValue("probamap"))
{
m_ClassificationFilter->SetUseProbaMap(true);
if(m_Model->HasProbaIndex())
{
m_ClassificationFilter->SetNumberOfClasses(GetParameterInt("nbclasses"));
SetParameterOutputImage<ProbaImageType>("probamap",m_ClassificationFilter->GetOutputProba());
}
else
{
otbAppLogWARNING("Probability map requested but the classifier doesn't support it!");
this->DisableParameter("probamap");
}
}
}
ClassificationFilterType::Pointer m_ClassificationFilter;
......
......@@ -115,6 +115,7 @@ set(sharkkm_parameters "")
set(ascii_comparison --compare-ascii ${EPSILON_6})
set(raster_comparison --compare-image ${NOTOL})
set(raster_comparison_two --compare-n-images ${NOTOL} 2)
set(raster_comparison_three --compare-n-images ${NOTOL} 3)
# Reference ffiles depending on modes
set(ascii_ref_path ${OTBAPP_BASELINE_FILES})
......@@ -136,6 +137,7 @@ if(OTB_USE_SHARK)
endif()
set(classifier_with_confmap "LIBSVM" "BOOST" "KNN" "ANN" "RF")
set(classifier_with_probamap "SHARKRF")
# This is a black list for classifier that can not have a baseline
# because they are using randomness and seed can not be set
......@@ -156,6 +158,7 @@ foreach(classifier ${classifierList})
set(OUTRASTER cl${classifier}LabeledImageQB1${raster_output_format})
set(OUTCONFMAP cl${classifier}ConfidenceMapQB1${raster_output_format})
set(OUTPROBAMAP cl${classifier}ProbabilityMapQB1${raster_output_format})
list(FIND classifier_without_baseline ${classifier} _classifier_has_baseline)
if(${_classifier_has_baseline} EQUAL -1)
......@@ -198,6 +201,7 @@ foreach(classifier ${classifierList})
set_tests_properties(apTvClTrainMethod${classifier}ImagesClassifierQB1_InXML1 PROPERTIES DEPENDS apTvClTrainMethod${classifier}ImagesClassifierQB1)
list(FIND classifier_with_confmap ${classifier} _classifier_has_confmap)
list(FIND classifier_with_probamap ${classifier} _classifier_has_probamap)
if(${_classifier_has_confmap} EQUAL -1)
otb_test_application(
NAME apTvClMethod${classifier}ImageClassifierQB1
......@@ -212,21 +216,44 @@ foreach(classifier ${classifierList})
${TEMP}/${OUTRASTER}
)
else()
otb_test_application(
NAME apTvClMethod${classifier}ImageClassifierQB1
APP ImageClassifier
OPTIONS -in ${INPUTDATA}/Classification/QB_1_ortho${raster_input_format}
-model ${INPUTDATA}/Classification/${OUTMODELFILE}
-imstat ${INPUTDATA}/Classification/clImageStatisticsQB1${stat_input_format}
-out ${TEMP}/${OUTRASTER} ${raster_output_option}
-confmap ${TEMP}/${OUTCONFMAP}
VALID ${raster_comparison_two}
${raster_ref_path}/${OUTRASTER}
${TEMP}/${OUTRASTER}
${raster_ref_path}/${OUTCONFMAP}
${TEMP}/${OUTCONFMAP}
)
if(${_classifier_has_probamap} EQUAL -1)
otb_test_application(
NAME apTvClMethod${classifier}ImageClassifierQB1
APP ImageClassifier
OPTIONS -in ${INPUTDATA}/Classification/QB_1_ortho${raster_input_format}
-model ${INPUTDATA}/Classification/${OUTMODELFILE}
-imstat ${INPUTDATA}/Classification/clImageStatisticsQB1${stat_input_format}
-out ${TEMP}/${OUTRASTER} ${raster_output_option}
-confmap ${TEMP}/${OUTCONFMAP}
VALID ${raster_comparison_two}
${raster_ref_path}/${OUTRASTER}
${TEMP}/${OUTRASTER}
${raster_ref_path}/${OUTCONFMAP}
${TEMP}/${OUTCONFMAP}
)
else()
message(${classifier})
otb_test_application(
NAME apTvClMethod${classifier}ImageClassifierQB1
APP ImageClassifier
OPTIONS -in ${INPUTDATA}/Classification/QB_1_ortho${raster_input_format}
-model ${INPUTDATA}/Classification/${OUTMODELFILE}
-imstat ${INPUTDATA}/Classification/clImageStatisticsQB1${stat_input_format}
-out ${TEMP}/${OUTRASTER} ${raster_output_option}
-confmap ${TEMP}/${OUTCONFMAP}
-nbclasses 4
-probamap ${TEMP}/${OUTPROBAMAP}
VALID ${raster_comparison_three}
${raster_ref_path}/${OUTRASTER}
${TEMP}/${OUTRASTER}
${raster_ref_path}/${OUTCONFMAP}
${TEMP}/${OUTCONFMAP}
${raster_ref_path}/${OUTPROBAMAP}
${TEMP}/${OUTPROBAMAP}
)
endif()
endif()
endforeach()
......
......@@ -84,6 +84,8 @@ public:
typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType;
typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
typedef typename Superclass::ProbaSampleType ProbaSampleType;
typedef typename Superclass::ProbaListSampleType ProbaListSampleType;
/// Neural network related typedefs
typedef shark::ConcatenatedModel<shark::RealVector> ModelType;
typedef shark::LinearModel<shark::RealVector,NeuronType> LayerType;
......@@ -162,14 +164,16 @@ protected:
virtual TargetSampleType DoPredict(
const InputSampleType& input,
ConfidenceValueType * quality = nullptr) const override;
ConfidenceValueType * quality = nullptr,
ProbaSampleType * proba = nullptr) const override;
virtual void DoPredictBatch(
const InputListSampleType *,
const unsigned int & startIndex,
const unsigned int & size,
TargetListSampleType *,
ConfidenceListSampleType * quality = nullptr) const override;
ConfidenceListSampleType * quality = nullptr,
ProbaListSampleType * proba = nullptr) const override;
private:
/** Internal Network */
......
......@@ -375,7 +375,7 @@ AutoencoderModel<TInputValue,NeuronType>
template <class TInputValue, class NeuronType>
typename AutoencoderModel<TInputValue,NeuronType>::TargetSampleType
AutoencoderModel<TInputValue,NeuronType>
::DoPredict(const InputSampleType & value, ConfidenceValueType * /*quality*/) const
::DoPredict(const InputSampleType & value, ConfidenceValueType * /*quality*/, ProbaSampleType * /*proba*/) const
{
shark::RealVector samples(value.Size());
for(size_t i = 0; i < value.Size();i++)
......@@ -408,7 +408,8 @@ AutoencoderModel<TInputValue,NeuronType>
const unsigned int & startIndex,
const unsigned int & size,
TargetListSampleType * targets,
ConfidenceListSampleType * /*quality*/) const
ConfidenceListSampleType * /*quality*/,
ProbaListSampleType * /*proba*/) const
{
std::vector<shark::RealVector> features;
Shark::ListSampleRangeToSharkVector(input, features,startIndex,size);
......
......@@ -79,6 +79,9 @@ public:
typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType;
typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
typedef typename Superclass::ProbaSampleType ProbaSampleType;
typedef typename Superclass::ProbaListSampleType ProbaListSampleType;
itkNewMacro(Self);
itkTypeMacro(PCAModel, DimensionalityReductionModel);
......@@ -101,14 +104,16 @@ protected:
virtual TargetSampleType DoPredict(
const InputSampleType& input,
ConfidenceValueType * quality = nullptr) const override;
ConfidenceValueType * quality = nullptr,
ProbaSampleType * proba = nullptr) const override;
virtual void DoPredictBatch(
const InputListSampleType *,
const unsigned int & startIndex,
const unsigned int & size,
TargetListSampleType *,
ConfidenceListSampleType * quality = nullptr) const override;
ConfidenceListSampleType * quality = nullptr,
ProbaListSampleType* proba = nullptr) const override;
private:
shark::LinearModel<> m_Encoder;
......
......@@ -148,7 +148,7 @@ PCAModel<TInputValue>::Load(const std::string & filename, const std::string & /*
template <class TInputValue>
typename PCAModel<TInputValue>::TargetSampleType
PCAModel<TInputValue>::DoPredict(const InputSampleType & value, ConfidenceValueType * /*quality*/) const
PCAModel<TInputValue>::DoPredict(const InputSampleType & value, ConfidenceValueType * /*quality*/, ProbaSampleType * /*proba*/) const
{
shark::RealVector samples(value.Size());
for(size_t i = 0; i < value.Size();i++)
......@@ -173,7 +173,7 @@ PCAModel<TInputValue>::DoPredict(const InputSampleType & value, ConfidenceValueT
template <class TInputValue>
void PCAModel<TInputValue>
::DoPredictBatch(const InputListSampleType *input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * targets, ConfidenceListSampleType * /*quality*/) const
::DoPredictBatch(const InputListSampleType *input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * targets, ConfidenceListSampleType * /*quality*/,ProbaListSampleType * /*proba*/) const
{
std::vector<shark::RealVector> features;
Shark::ListSampleRangeToSharkVector(input, features,startIndex,size);
......
......@@ -64,7 +64,8 @@ public:
typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType;
typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
typedef typename Superclass::ProbaSampleType ProbaSampleType;
typedef typename Superclass::ProbaListSampleType ProbaListSampleType;
typedef SOMMap<
itk::VariableLengthVector<TInputValue>,
itk::Statistics::EuclideanDistanceMetric<
......@@ -118,7 +119,8 @@ private:
virtual TargetSampleType DoPredict(
const InputSampleType& input,
ConfidenceValueType * quality = nullptr) const override;
ConfidenceValueType * quality = nullptr,
ProbaSampleType * proba = nullptr) const override;
/** Map size (width, height) */
SizeType m_MapSize;
......
......@@ -212,7 +212,8 @@ template <class TInputValue, unsigned int MapDimension>
typename SOMModel<TInputValue, MapDimension>::TargetSampleType
SOMModel<TInputValue, MapDimension>::DoPredict(
const InputSampleType & value,
ConfidenceValueType * /*quality*/) const
ConfidenceValueType * /*quality*/,
ProbaSampleType * /*proba*/) const
{
TargetSampleType target;
target.SetSize(this->m_Dimension);
......
......@@ -24,6 +24,7 @@
#include "itkImageToImageFilter.h"
#include "otbMachineLearningModel.h"
#include "otbImage.h"
#include "otbVectorImage.h"
namespace otb
{
......@@ -75,6 +76,11 @@ public:
typedef otb::Image<double> ConfidenceImageType;
typedef typename ConfidenceImageType::Pointer ConfidenceImagePointerType;
/**Output type for Proba */
typedef otb::VectorImage<double> ProbaImageType;
typedef typename ProbaImageType::Pointer ProbaImagePointerType;
typedef itk::VariableLengthVector<double> ProbaSampleType;
/** Set/Get the model */
itkSetObjectMacro(Model, ModelType);
itkGetObjectMacro(Model, ModelType);
......@@ -87,10 +93,16 @@ public:
itkSetMacro(UseConfidenceMap, bool);
itkGetMacro(UseConfidenceMap, bool);
/** Set/Get the proba map flag */
itkSetMacro(UseProbaMap, bool);
itkGetMacro(UseProbaMap, bool);
itkSetMacro(BatchMode, bool);
itkGetMacro(BatchMode, bool);
itkBooleanMacro(BatchMode);
itkSetMacro(NumberOfClasses, unsigned int);
itkGetMacro(NumberOfClasses, unsigned int);
/**
* If set, only pixels within the mask will be classified.
* All pixels with a value greater than 0 in the mask, will be classified.
......@@ -108,7 +120,7 @@ public:
* Get the output confidence map
*/
ConfidenceImageType * GetOutputConfidence(void);
ProbaImageType * GetOutputProba(void);
protected:
/** Constructor */
ImageClassificationFilter();
......@@ -124,6 +136,12 @@ protected:
/**PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const override;
void GenerateOutputInformation() override
{
Superclass::GenerateOutputInformation();
// Define the number of output bands
this->GetOutputProba()->SetNumberOfComponentsPerPixel(m_NumberOfClasses);
}
private:
ImageClassificationFilter(const Self &) = delete;
void operator =(const Self&) = delete;
......@@ -134,7 +152,9 @@ private:
LabelType m_DefaultLabel;
/** Flag to produce the confidence map (if the model supports it) */
bool m_UseConfidenceMap;
bool m_UseProbaMap;
bool m_BatchMode;
unsigned int m_NumberOfClasses;
};
} // End namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
......
......@@ -99,6 +99,9 @@ public:
typedef typename MLMTargetTraits<TConfidenceValue>::SampleType ConfidenceSampleType;
typedef itk::Statistics::ListSample<ConfidenceSampleType> ConfidenceListSampleType;
typedef itk::VariableLengthVector<double> ProbaSampleType;
typedef itk::Statistics::ListSample<ProbaSampleType> ProbaListSampleType;
/**\name Standard macros */
//@{
/** Run-time type information (and related methods). */
......@@ -114,7 +117,7 @@ public:
* quality value, or NULL
* \return The predicted label
*/
TargetSampleType Predict(const InputSampleType& input, ConfidenceValueType *quality = nullptr) const;
TargetSampleType Predict(const InputSampleType& input, ConfidenceValueType *quality = nullptr, ProbaSampleType *proba = nullptr) const;
/**\name Set and get the dimension of the model for dimensionality reduction models */
//@{
......@@ -130,7 +133,7 @@ public:
* Note that this method will be multi-threaded if OTB is built
* with OpenMP.
*/
typename TargetListSampleType::Pointer PredictBatch(const InputListSampleType * input, ConfidenceListSampleType * quality = nullptr) const;
typename TargetListSampleType::Pointer PredictBatch(const InputListSampleType * input, ConfidenceListSampleType * quality = nullptr, ProbaListSampleType * proba = nullptr) const;
/**\name Classification model file manipulation */
//@{
......@@ -152,6 +155,8 @@ public:
/** Query capacity to produce a confidence index */
bool HasConfidenceIndex() const {return m_ConfidenceIndex;}
/** Query capacity to produce probability values */
bool HasProbaIndex() const {return m_ProbaIndex;}
/**\name Input list of samples accessors */
//@{
......@@ -208,6 +213,9 @@ protected:
/** flag that tells if the model support confidence index output */
bool m_ConfidenceIndex;
/** flag that tells if the model support probability output */
bool m_ProbaIndex;
/** Is DoPredictBatch multi-threaded ? */
bool m_IsDoPredictBatchMultiThreaded;
......@@ -230,7 +238,7 @@ private:
* Also set m_IsDoPredictBatchMultiThreaded to true if internal
* implementation allows for parallel batch prediction.
*/
virtual void DoPredictBatch(const InputListSampleType * input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * target, ConfidenceListSampleType * quality = nullptr) const;
virtual void DoPredictBatch(const InputListSampleType * input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * target, ConfidenceListSampleType * quality = nullptr, ProbaListSampleType * proba = nullptr) const;
/** Actual implementation of single sample prediction
* \param input sample to predict
......@@ -238,7 +246,7 @@ private:
* or NULL
* \return The predicted label
*/
virtual TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType * quality= nullptr) const = 0;
virtual TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType * quality= nullptr, ProbaSampleType *proba=nullptr) const = 0;
MachineLearningModel(const Self &) = delete;
void operator =(const Self&) = delete;
......
......@@ -38,6 +38,7 @@ MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue>
m_RegressionMode(false),
m_IsRegressionSupported(false),
m_ConfidenceIndex(false),
m_ProbaIndex(false),
m_IsDoPredictBatchMultiThreaded(false),
m_Dimension(0)
{}
......@@ -68,10 +69,10 @@ template <class TInputValue, class TOutputValue, class TConfidenceValue>
typename MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue>
::TargetSampleType
MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue>
::Predict(const InputSampleType& input, ConfidenceValueType *quality) const
::Predict(const InputSampleType& input, ConfidenceValueType *quality, ProbaSampleType *proba) const
{
// Call protected specialization entry point
return this->DoPredict(input,quality);
return this->DoPredict(input,quality,proba);
}
......@@ -79,8 +80,9 @@ template <class TInputValue, class TOutputValue, class TConfidenceValue>
typename MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue>
::TargetListSampleType::Pointer
MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue>
::PredictBatch(const InputListSampleType * input, ConfidenceListSampleType * quality) const
::PredictBatch(const InputListSampleType * input, ConfidenceListSampleType * quality, ProbaListSampleType *proba) const
{
//std::cout << "Enter batch predict" << std::endl;
typename TargetListSampleType::Pointer targets = TargetListSampleType::New();
targets->Resize(input->Size());
......@@ -89,16 +91,19 @@ MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue>
quality->Clear();
quality->Resize(input->Size());
}
if(proba!=ITK_NULLPTR)
{
proba->Clear();
proba->Resize(input->Size());
}
if(m_IsDoPredictBatchMultiThreaded)
{
// Simply calls DoPredictBatch
this->DoPredictBatch(input,0,input->Size(),targets,quality);
return targets;
this->DoPredictBatch(input,0,input->Size(),targets,quality,proba);
return targets;
}
else
{
#ifdef _OPENMP
// OpenMP threading here
unsigned int nb_threads(0), threadId(0), nb_batches(0);
......@@ -120,11 +125,11 @@ MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue>
batch_size+=input->Size()%nb_batches;
}
this->DoPredictBatch(input,batch_start,batch_size,targets,quality);
this->DoPredictBatch(input,batch_start,batch_size,targets,quality,proba);
}
}
#else
this->DoPredictBatch(input,0,input->Size(),targets,quality);
this->DoPredictBatch(input,0,input->Size(),targets,quality,proba);
#endif
return targets;
}
......@@ -135,29 +140,42 @@ MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue>
template <class TInputValue, class TOutputValue, class TConfidenceValue>
void
MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue>
::DoPredictBatch(const InputListSampleType * input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * targets, ConfidenceListSampleType * quality) const
::DoPredictBatch(const InputListSampleType * input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * targets, ConfidenceListSampleType * quality, ProbaListSampleType * proba) const
{
assert(input != nullptr);
assert(targets != nullptr);
assert(input->Size()==targets->Size()&&"Input sample list and target label list do not have the same size.");
assert(((quality==nullptr)||(quality->Size()==input->Size()))&&"Quality samples list is not null and does not have the same size as input samples list");
assert((proba==nullptr)||(input->Size()==proba->Size())&&"Proba sample list and target label list do not have the same size.");
if(startIndex+size>input->Size())
{
itkExceptionMacro(<<"requested range ["<<startIndex<<", "<<startIndex+size<<"[ partially outside input sample list range.[0,"<<input->Size()<<"[");
}
if(quality != nullptr)
if (proba != nullptr)
{
for(unsigned int id = startIndex;id<startIndex+size;++id)
{
ProbaSampleType prob;
ConfidenceValueType confidence = 0;
const TargetSampleType target = this->DoPredict(input->GetMeasurementVector(id),&confidence);
const TargetSampleType target = this->DoPredict(input->GetMeasurementVector(id),&confidence, &prob);
quality->SetMeasurementVector(id,confidence);
proba->SetMeasurementVector(id,prob);
targets->SetMeasurementVector(id,target);
}
}
else if(quality != ITK_NULLPTR)
{
for(unsigned int id = startIndex;id<startIndex+size;++id)
{
ConfidenceValueType confidence = 0;
const TargetSampleType target = this->DoPredict(input->GetMeasurementVector(id),&confidence);
quality->SetMeasurementVector(id,confidence);
targets->SetMeasurementVector(id,target);
}
}
else
{
for(unsigned int id = startIndex;id<startIndex+size;++id)
......@@ -176,6 +194,6 @@ void
// Call superclass implementation
Superclass::PrintSelf(os,indent);
}
}
}
#endif
......@@ -53,7 +53,7 @@ public:
typedef typename Superclass::TargetSampleType TargetSampleType;
typedef typename Superclass::TargetListSampleType TargetListSampleType;
typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
typedef typename Superclass::ProbaSampleType ProbaSampleType;
/** Run-time type information (and related methods). */
itkNewMacro(Self);
itkTypeMacro(BoostMachineLearningModel, MachineLearningModel);
......@@ -124,8 +124,7 @@ protected:
~BoostMachineLearningModel() override;
/** Predict values using the model */
TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr) const override;
TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override;
/** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const override;
......
......@@ -104,7 +104,7 @@ template <class TInputValue, class TOutputValue>
typename BoostMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
BoostMachineLearningModel<TInputValue,TOutputValue>
::DoPredict(const InputSampleType & input, ConfidenceValueType *quality) const
::DoPredict(const InputSampleType & input, ConfidenceValueType *quality, ProbaSampleType *proba) const
{
TargetSampleType target;
......@@ -132,6 +132,8 @@ BoostMachineLearningModel<TInputValue,TOutputValue>
#endif
);
}
if (proba != nullptr && !this->m_ProbaIndex)
itkExceptionMacro("Probability per class not available for this classifier !");
target[0] = static_cast<TOutputValue>(result);
return target;
......
......@@ -65,7 +65,7 @@ public:
float predict_margin(const cv::Mat& sample,
const cv::Mat& missing =
cv::Mat()) const;
#ifdef OTB_OPENCV_3
#define OTB_CV_WRAP_PROPERTY(type,name) \
......
......@@ -53,7 +53,7 @@ public:
typedef typename Superclass::TargetSampleType TargetSampleType;
typedef typename Superclass::TargetListSampleType TargetListSampleType;
typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
typedef typename Superclass::ProbaSampleType ProbaSampleType;
/** Run-time type information (and related methods). */
itkNewMacro(Self);
itkTypeMacro(DecisionTreeMachineLearningModel, MachineLearningModel);
......@@ -179,7 +179,7 @@ protected:
~DecisionTreeMachineLearningModel() override;
/** Predict values using the model */
TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr) const override;
TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override;
/** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const override;
......
......@@ -117,7 +117,7 @@ template <class TInputValue, class TOutputValue>
typename DecisionTreeMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
DecisionTreeMachineLearningModel<TInputValue,TOutputValue>
::DoPredict(const InputSampleType & input, ConfidenceValueType *quality) const
::DoPredict(const InputSampleType & input, ConfidenceValueType *quality, ProbaSampleType *proba) const
{
TargetSampleType target;
......@@ -140,8 +140,10 @@ DecisionTreeMachineLearningModel<TInputValue,TOutputValue>
itkExceptionMacro("Confidence index not available for this classifier !");
}
}
if (proba != nullptr && !this->m_ProbaIndex)
itkExceptionMacro("Probability per class not available for this classifier !");
return target;
return target;
}
template <class TInputValue, class TOutputValue>
......
......@@ -51,7 +51,7 @@ public:
typedef typename Superclass::TargetSampleType TargetSampleType;
typedef typename Superclass::TargetListSampleType TargetListSampleType;
typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
typedef typename Superclass::ProbaSampleType ProbaSampleType;
/** Run-time type information (and related methods). */
itkNewMacro(Self);
itkTypeMacro(GradientBoostedTreeMachineLearningModel, MachineLearningModel);
......@@ -130,8 +130,7 @@ protected:
~GradientBoostedTreeMachineLearningModel() override;
/** Predict values using the model */
TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr) const override;
TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override;
/** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const override;
......
......@@ -83,7 +83,7 @@ template <class TInputValue, class TOutputValue>
typename GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue>
::DoPredict(const InputSampleType & input, ConfidenceValueType *quality) const
::DoPredict(const InputSampleType & input, ConfidenceValueType *quality, ProbaSampleType *proba) const
{
//convert listsample to Mat
cv::Mat sample;
......@@ -103,6 +103,8 @@ GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue>
itkExceptionMacro("Confidence index not available for this classifier !");
}
}
if (proba != nullptr && !this->m_ProbaIndex)
itkExceptionMacro("Probability per class not available for this classifier !");
return target;