Commit 7c130ac4 authored by Guillaume Pasero's avatar Guillaume Pasero

Merging feature Classifier quality index (Jira 847) into develop

parents b0a44d57 62a1e76a
......@@ -59,6 +59,7 @@ public:
typedef ClassificationFilterType::ValueType ValueType;
typedef ClassificationFilterType::LabelType LabelType;
typedef otb::MachineLearningModelFactory<ValueType, LabelType> MachineLearningModelFactoryType;
typedef ClassificationFilterType::ConfidenceImageType ConfidenceImageType;
private:
void DoInit()
......@@ -94,6 +95,21 @@ private:
SetParameterDescription( "out", "Output image containing class labels");
SetParameterOutputImagePixelType( "out", ImagePixelType_uint8);
AddParameter(ParameterType_OutputImage, "confmap", "Confidence map");
SetParameterDescription( "confmap", "Confidence map of the produced classification. The confidence index depends on the model : \n"
" - LibSVM : difference between the two highest probabilities (needs a model with probability estimates, so that classes probabilities can be computed for each sample)\n"
" - OpenCV\n"
" * Boost : sum of votes\n"
" * DecisionTree : (not supported)\n"
" * GradientBoostedTree : (not supported)\n"
" * KNearestNeighbors : number of neighbors with the same label\n"
" * NeuralNetwork : difference between the two highest responses\n"
" * NormalBayes : (not supported)\n"
" * RandomForest : proportion of decision trees that classified the sample to the second class (only works for 2-class models)\n"
" * SVM : distance to margin (only works for 2-class models)\n");
SetParameterOutputImagePixelType( "confmap", ImagePixelType_double);
MandatoryOff("confmap");
AddRAMParameter();
// Doc example parameter settings
......@@ -171,6 +187,21 @@ private:
}
SetParameterOutputImage<OutputImageType>("out", m_ClassificationFilter->GetOutput());
// output confidence map
if (IsParameterEnabled("confmap") && HasValue("confmap"))
{
m_ClassificationFilter->SetUseConfidenceMap(true);
if (m_Model->HasConfidenceIndex())
{
SetParameterOutputImage<ConfidenceImageType>("confmap",m_ClassificationFilter->GetOutputConfidence());
}
else
{
otbAppLogWARNING("Confidence map requested but the classifier doesn't support it!");
this->DisableParameter("confmap");
}
}
}
ClassificationFilterType::Pointer m_ClassificationFilter;
......
......@@ -42,6 +42,9 @@ namespace Wrapper
AddParameter(ParameterType_Empty, "classifier.libsvm.opt", "Parameters optimization");
MandatoryOff("classifier.libsvm.opt");
SetParameterDescription("classifier.libsvm.opt", "SVM parameters optimization flag.");
AddParameter(ParameterType_Empty, "classifier.libsvm.prob", "Probability estimation");
MandatoryOff("classifier.libsvm.prob");
SetParameterDescription("classifier.libsvm.prob", "Probability estimation flag.");
}
......@@ -56,6 +59,10 @@ namespace Wrapper
{
libSVMClassifier->SetParameterOptimization(true);
}
if (IsParameterEnabled("classifier.libsvm.prob"))
{
libSVMClassifier->SetDoProbabilityEstimates(true);
}
libSVMClassifier->SetC(GetParameterFloat("classifier.libsvm.c"));
switch (GetParameterInt("classifier.libsvm.k"))
......
......@@ -77,7 +77,7 @@ set(rf_output_format ".rf")
set(knn_output_format ".knn")
# Training algorithms parameters
set(libsvm_parameters "-classifier.libsvm.opt" "true")
set(libsvm_parameters "-classifier.libsvm.opt" "true" "-classifier.libsvm.prob" "true")
set(svm_parameters "-classifier.svm.opt" "true")
set(boost_parameters "")
set(dt_parameters "")
......@@ -90,6 +90,7 @@ set(knn_parameters "")
# Validation depending on mode
set(ascii_comparison --compare-ascii ${NOTOL})
set(raster_comparison --compare-image ${NOTOL})
set(raster_comparison_two --compare-n-images ${NOTOL} 2)
# Reference ffiles depending on modes
set(ascii_ref_path ${OTBAPP_BASELINE_FILES})
......@@ -102,6 +103,7 @@ endif()
if(OTB_USE_OPENCV)
list(APPEND classifierList "SVM" "BOOST" "DT" "GBT" "ANN" "BAYES" "RF" "KNN")
endif()
set(classifier_with_confmap "LIBSVM" "BOOST" "KNN" "ANN")
# Loop on classifiers
foreach(classifier ${classifierList})
......@@ -110,6 +112,7 @@ foreach(classifier ${classifierList})
# Derive output file name
set(OUTMODELFILE cl${classifier}_ModelQB1${${lclassifier}_output_format})
set(OUTRASTER cl${classifier}LabeledImageQB1${raster_output_format})
set(OUTCONFMAP cl${classifier}ConfidenceMapQB1${raster_output_format})
otb_test_application(
NAME apTvClTrainMethod${classifier}ImagesClassifierQB1
......@@ -160,7 +163,9 @@ foreach(classifier ${classifierList})
#set_tests_properties(apTvClTrainMethod${classifier}ImagesClassifierQB1_InXML1 PROPERTIES DEPENDS apTvClTrainMethod${classifier}ImagesClassifierQB1_OutXML1)
otb_test_application(
list(FIND classifier_with_confmap ${classifier} _classifier_has_confmap)
if(${_classifier_has_confmap} EQUAL -1)
otb_test_application(
NAME apTvClMethod${classifier}ImageClassifierQB1
APP ImageClassifier
OPTIONS -in ${INPUTDATA}/Classification/QB_1_ortho${raster_input_format}
......@@ -172,6 +177,23 @@ foreach(classifier ${classifierList})
${raster_ref_path}/${OUTRASTER}
${TEMP}/${OUTRASTER}
)
else()
otb_test_application(
NAME apTvClMethod${classifier}ImageClassifierQB1
APP ImageClassifier
OPTIONS -in ${INPUTDATA}/Classification/QB_1_ortho${raster_input_format}
-model ${ascii_ref_path}/${OUTMODELFILE}
-imstat ${ascii_ref_path}/clImageStatisticsQB1${stat_input_format}
-out ${TEMP}/${OUTRASTER} ${raster_output_option}
-confmap ${TEMP}/${OUTCONFMAP}
VALID ${raster_comparison_two}
${raster_ref_path}/${OUTRASTER}
${TEMP}/${OUTRASTER}
${raster_ref_path}/${OUTCONFMAP}
${TEMP}/${OUTCONFMAP}
)
endif()
endforeach()
......
......@@ -307,6 +307,12 @@ public:
return static_cast<bool>(m_Parameters.probability);
}
/** Test if the model has probabilities */
bool HasProbabilities(void) const
{
return static_cast<bool>(svm_check_probability_model(m_Model));
}
/** Return number of support vectors */
int GetNumberOfSupportVectors(void) const
{
......
......@@ -464,7 +464,7 @@ SVMModel<TValue, TLabel>::EvaluateProbabilities(const MeasurementType& measure)
itkExceptionMacro(<< "Model is not up-to-date, can not predict probabilities");
}
if (svm_check_probability_model(m_Model) == 0)
if (!this->HasProbabilities())
{
throw itk::ExceptionObject(__FILE__, __LINE__,
"Model does not support probability estimates", ITK_LOCATION);
......
......@@ -40,19 +40,17 @@ public:
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
// Input related typedefs
typedef TInputValue InputValueType;
typedef itk::VariableLengthVector<InputValueType> InputSampleType;
typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType;
// Target related typedefs
typedef TTargetValue TargetValueType;
typedef itk::FixedArray<TargetValueType,1> TargetSampleType;
typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType;
typedef typename Superclass::InputValueType InputValueType;
typedef typename Superclass::InputSampleType InputSampleType;
typedef typename Superclass::InputListSampleType InputListSampleType;
typedef typename Superclass::TargetValueType TargetValueType;
typedef typename Superclass::TargetSampleType TargetSampleType;
typedef typename Superclass::TargetListSampleType TargetListSampleType;
typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
/** Run-time type information (and related methods). */
itkNewMacro(Self);
itkTypeMacro(BoostMachineLearningModel, itk::MachineLearningModel);
itkTypeMacro(BoostMachineLearningModel, MachineLearningModel);
/** Setters/Getters to the Boost type
* It can be CvBoost::DISCRETE, CvBoost::REAL, CvBoost::LOGIT, CvBoost::GENTLE
......@@ -122,7 +120,7 @@ protected:
/** Train the machine learning model */
virtual void TrainClassification();
/** Predict values using the model */
virtual TargetSampleType PredictClassification(const InputSampleType& input) const;
virtual TargetSampleType PredictClassification(const InputSampleType& input, ConfidenceValueType *quality=NULL) const;
private:
BoostMachineLearningModel(const Self &); //purposely not implemented
......
......@@ -37,6 +37,7 @@ BoostMachineLearningModel<TInputValue,TOutputValue>
m_SplitCrit(CvBoost::DEFAULT),
m_MaxDepth(1)
{
this->m_ConfidenceIndex = true;
}
......@@ -76,7 +77,7 @@ template <class TInputValue, class TOutputValue>
typename BoostMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
BoostMachineLearningModel<TInputValue,TOutputValue>
::PredictClassification(const InputSampleType & input) const
::PredictClassification(const InputSampleType & input, ConfidenceValueType *quality) const
{
//convert listsample to Mat
cv::Mat sample;
......@@ -91,6 +92,12 @@ BoostMachineLearningModel<TInputValue,TOutputValue>
target[0] = static_cast<TOutputValue>(result);
if (quality != NULL)
{
(*quality) = static_cast<ConfidenceValueType>(
m_BoostModel->predict(sample,missing,cv::Range::all(),false,true));
}
return target;
}
......
......@@ -40,19 +40,17 @@ public:
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
// Input related typedefs
typedef TInputValue InputValueType;
typedef itk::VariableLengthVector<InputValueType> InputSampleType;
typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType;
// Target related typedefs
typedef TTargetValue TargetValueType;
typedef itk::FixedArray<TargetValueType,1> TargetSampleType;
typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType;
typedef typename Superclass::InputValueType InputValueType;
typedef typename Superclass::InputSampleType InputSampleType;
typedef typename Superclass::InputListSampleType InputListSampleType;
typedef typename Superclass::TargetValueType TargetValueType;
typedef typename Superclass::TargetSampleType TargetSampleType;
typedef typename Superclass::TargetListSampleType TargetListSampleType;
typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
/** Run-time type information (and related methods). */
itkNewMacro(Self);
itkTypeMacro(DecisionTreeMachineLearningModel, itk::MachineLearningModel);
itkTypeMacro(DecisionTreeMachineLearningModel, MachineLearningModel);
/** Setters/Getters to the maximum depth of the tree.
* The maximum possible depth of the tree. That is the training algorithms attempts to split a node
......@@ -183,7 +181,7 @@ protected:
/** Train the machine learning model */
virtual void TrainClassification();
/** Predict values using the model */
virtual TargetSampleType PredictClassification(const InputSampleType& input) const;
virtual TargetSampleType PredictClassification(const InputSampleType& input, ConfidenceValueType *quality=NULL) const;
private:
DecisionTreeMachineLearningModel(const Self &); //purposely not implemented
......
......@@ -83,7 +83,7 @@ template <class TInputValue, class TOutputValue>
typename DecisionTreeMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
DecisionTreeMachineLearningModel<TInputValue,TOutputValue>
::PredictClassification(const InputSampleType & input) const
::PredictClassification(const InputSampleType & input, ConfidenceValueType *quality) const
{
//convert listsample to Mat
cv::Mat sample;
......@@ -96,6 +96,14 @@ DecisionTreeMachineLearningModel<TInputValue,TOutputValue>
target[0] = static_cast<TOutputValue>(result);
if (quality != NULL)
{
if (!this->m_ConfidenceIndex)
{
itkExceptionMacro("Confidence index not available for this classifier !");
}
}
return target;
}
......
......@@ -40,19 +40,17 @@ public:
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
// Input related typedefs
typedef TInputValue InputValueType;
typedef itk::VariableLengthVector<InputValueType> InputSampleType;
typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType;
// Target related typedefs
typedef TTargetValue TargetValueType;
typedef itk::FixedArray<TargetValueType,1> TargetSampleType;
typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType;
typedef typename Superclass::InputValueType InputValueType;
typedef typename Superclass::InputSampleType InputSampleType;
typedef typename Superclass::InputListSampleType InputListSampleType;
typedef typename Superclass::TargetValueType TargetValueType;
typedef typename Superclass::TargetSampleType TargetSampleType;
typedef typename Superclass::TargetListSampleType TargetListSampleType;
typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
/** Run-time type information (and related methods). */
itkNewMacro(Self);
itkTypeMacro(GradientBoostedTreeMachineLearningModel, itk::MachineLearningModel);
itkTypeMacro(GradientBoostedTreeMachineLearningModel, MachineLearningModel);
/** Type of the loss function used for training.
* It must be one of the following types: CvGBTrees::SQUARED_LOSS, CvGBTrees::ABSOLUTE_LOSS,
......@@ -130,7 +128,7 @@ protected:
/** Train the machine learning model */
virtual void TrainClassification();
/** Predict values using the model */
virtual TargetSampleType PredictClassification(const InputSampleType& input) const;
virtual TargetSampleType PredictClassification(const InputSampleType& input, ConfidenceValueType *quality=NULL) const;
private:
GradientBoostedTreeMachineLearningModel(const Self &); //purposely not implemented
......
......@@ -80,7 +80,7 @@ template <class TInputValue, class TOutputValue>
typename GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue>
::PredictClassification(const InputSampleType & input) const
::PredictClassification(const InputSampleType & input, ConfidenceValueType *quality) const
{
//convert listsample to Mat
cv::Mat sample;
......@@ -93,6 +93,14 @@ GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue>
target[0] = static_cast<TOutputValue>(result);
if (quality != NULL)
{
if (!this->m_ConfidenceIndex)
{
itkExceptionMacro("Confidence index not available for this classifier !");
}
}
return target;
}
......
......@@ -20,6 +20,7 @@
#include "itkImageToImageFilter.h"
#include "otbMachineLearningModel.h"
#include "otbImage.h"
namespace otb
{
......@@ -68,6 +69,9 @@ public:
typedef MachineLearningModel<ValueType, LabelType> ModelType;
typedef typename ModelType::Pointer ModelPointerType;
typedef otb::Image<double> ConfidenceImageType;
typedef typename ConfidenceImageType::Pointer ConfidenceImagePointerType;
/** Set/Get the model */
itkSetObjectMacro(Model, ModelType);
itkGetObjectMacro(Model, ModelType);
......@@ -76,6 +80,10 @@ public:
itkSetMacro(DefaultLabel, LabelType);
itkGetMacro(DefaultLabel, LabelType);
/** Set/Get the confidence map flag */
itkSetMacro(UseConfidenceMap, bool);
itkGetMacro(UseConfidenceMap, bool);
/**
* If set, only pixels within the mask will be classified.
* All pixels with a value greater than 0 in the mask, will be classified.
......@@ -89,6 +97,11 @@ public:
*/
const MaskImageType * GetInputMask(void);
/**
* Get the output confidence map
*/
ConfidenceImageType * GetOutputConfidence(void);
protected:
/** Constructor */
ImageClassificationFilter();
......@@ -110,7 +123,8 @@ private:
ModelPointerType m_Model;
/** Default label for invalid pixels (when using a mask) */
LabelType m_DefaultLabel;
/** Flag to produce the confidence map (if the model supports it) */
bool m_UseConfidenceMap;
};
} // End namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
......
......@@ -34,6 +34,11 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
this->SetNumberOfIndexedInputs(2);
this->SetNumberOfRequiredInputs(1);
m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue();
this->SetNumberOfRequiredOutputs(2);
this->SetNthOutput(0,TOutputImage::New());
this->SetNthOutput(1,ConfidenceImageType::New());
m_UseConfidenceMap = false;
}
template <class TInputImage, class TOutputImage, class TMaskImage>
......@@ -57,6 +62,19 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
return static_cast<const MaskImageType *>(this->itk::ProcessObject::GetInput(1));
}
template <class TInputImage, class TOutputImage, class TMaskImage>
typename ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
::ConfidenceImageType *
ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
::GetOutputConfidence()
{
if (this->GetNumberOfOutputs() < 2)
{
return 0;
}
return static_cast<ConfidenceImageType *>(this->itk::ProcessObject::GetOutput(1));
}
template <class TInputImage, class TOutputImage, class TMaskImage>
void
ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
......@@ -77,6 +95,7 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
InputImageConstPointerType inputPtr = this->GetInput();
MaskImageConstPointerType inputMaskPtr = this->GetInputMask();
OutputImagePointerType outputPtr = this->GetOutput();
ConfidenceImagePointerType confidencePtr = this->GetOutputConfidence();
// Progress reporting
itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
......@@ -85,6 +104,7 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType;
typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
typedef itk::ImageRegionIterator<ConfidenceImageType> ConfidenceMapIteratorType;
InputIteratorType inIt(inputPtr, outputRegionForThread);
OutputIteratorType outIt(outputPtr, outputRegionForThread);
......@@ -97,7 +117,17 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
maskIt.GoToBegin();
}
// setup iterator for confidence map
bool computeConfidenceMap(m_UseConfidenceMap && m_Model->HasConfidenceIndex());
ConfidenceMapIteratorType confidenceIt;
if (computeConfidenceMap)
{
confidenceIt = ConfidenceMapIteratorType(confidencePtr,outputRegionForThread);
confidenceIt.GoToBegin();
}
bool validPoint = true;
double confidenceIndex = 0.0;
// Walk the part of the image
for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd(); ++inIt, ++outIt)
......@@ -112,12 +142,25 @@ ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
if (validPoint)
{
// Classifify
outIt.Set(m_Model->Predict(inIt.Get())[0]);
if (computeConfidenceMap)
{
outIt.Set(m_Model->Predict(inIt.Get(),&confidenceIndex)[0]);
}
else
{
outIt.Set(m_Model->Predict(inIt.Get())[0]);
}
}
else
{
// else, set default value
outIt.Set(m_DefaultLabel);
confidenceIndex = 0.0;
}
if (computeConfidenceMap)
{
confidenceIt.Set(confidenceIndex);
++confidenceIt;
}
progress.CompletedPixel();
}
......
......@@ -39,19 +39,17 @@ public:
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
// Input related typedefs
typedef TInputValue InputValueType;
typedef itk::VariableLengthVector<InputValueType> InputSampleType;
typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType;
// Target related typedefs
typedef TTargetValue TargetValueType;
typedef itk::FixedArray<TargetValueType,1> TargetSampleType;
typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType;
typedef typename Superclass::InputValueType InputValueType;
typedef typename Superclass::InputSampleType InputSampleType;
typedef typename Superclass::InputListSampleType InputListSampleType;
typedef typename Superclass::TargetValueType TargetValueType;
typedef typename Superclass::TargetSampleType TargetSampleType;
typedef typename Superclass::TargetListSampleType TargetListSampleType;
typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
/** Run-time type information (and related methods). */
itkNewMacro(Self);
itkTypeMacro(KNearestNeighborsMachineLearningModel, itk::MachineLearningModel);
itkTypeMacro(KNearestNeighborsMachineLearningModel, MachineLearningModel);
/** Setters/Getters to the number of neighbors to use
* Default is 32
......@@ -95,7 +93,7 @@ protected:
/** Train the machine learning model */
virtual void TrainClassification();
/** Predict values using the model */
virtual TargetSampleType PredictClassification(const InputSampleType& input) const;
virtual TargetSampleType PredictClassification(const InputSampleType& input, ConfidenceValueType *quality=NULL) const;
private:
KNearestNeighborsMachineLearningModel(const Self &); //purposely not implemented
......
......@@ -35,6 +35,7 @@ KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue>
m_K(32),
m_IsRegression(false)
{
this->m_ConfidenceIndex = true;
}
......@@ -66,18 +67,36 @@ template <class TInputValue, class TTargetValue>
typename KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue>
::TargetSampleType
KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue>
::PredictClassification(const InputSampleType & input) const
::PredictClassification(const InputSampleType & input, ConfidenceValueType *quality) const
{
//convert listsample to Mat
cv::Mat sample;
otb::SampleToMat<InputSampleType>(input, sample);
double result = m_KNearestModel->find_nearest(sample, m_K);
float result;
if (quality != NULL)
{
cv::Mat nearest(1,m_K,CV_32FC1);
result = m_KNearestModel->find_nearest(sample, m_K,0,0,&nearest,0);
unsigned int accuracy = 0;
for (int k=0 ; k < m_K ; ++k)
{
if (nearest.at<float>(0,k) == result)
{
accuracy++;
}
}
(*quality) = static_cast<ConfidenceValueType>(accuracy);
}
else
{
result = m_KNearestModel->find_nearest(sample, m_K);
}
TargetSampleType target;
target[0] = static_cast<TTargetValue>(result);
return target;
}
......
......@@ -40,15 +40,13 @@ public:
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
// Input related typedefs
typedef TInputValue InputValueType;
typedef itk::VariableLengthVector<InputValueType> InputSampleType;
typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType;
// Target related typedefs
typedef TTargetValue TargetValueType;
typedef itk::FixedArray<TargetValueType,1> TargetSampleType;
typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType;
typedef typename Superclass::InputValueType InputValueType;
typedef typename Superclass::InputSampleType InputSampleType;
typedef typename Superclass::InputListSampleType InputListSampleType;
typedef typename Superclass::TargetValueType TargetValueType;
typedef typename Superclass::TargetSampleType TargetSampleType;
typedef typename Superclass::TargetListSampleType TargetListSampleType;
typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
// LibSVM related typedefs
typedef otb::Functor::VariableLengthVectorToMeasurementVectorFunctor<InputSampleType> MeasurementVectorFunctorType;
......@@ -59,7 +57,7 @@ public:
/** Run-time type information (and related methods). */
itkNewMacro(Self);
itkTypeMacro(SVMMachineLearningModel, itk::MachineLearningModel);
itkTypeMacro(SVMMachineLearningModel, MachineLearningModel);
/** Save the model to file */
virtual void Save(const std::string &filename, const std::string & name="");
......@@ -108,7 +106,7 @@ protected:
/** Train the machine learning model */
virtual void TrainClassification();
/** Predict values using the model */
virtual TargetSampleType PredictClassification(const InputSampleType& input) const;
virtual TargetSampleType PredictClassification(const InputSampleType& input, ConfidenceValueType *quality=NULL) const;
private:
LibSVMMachineLearningModel(const Self &); //purposely not implemented
......
......@@ -74,13 +74,15 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue>
m_SVMestimator->SetTrainingSampleList(this->GetTargetListSample());
m_SVMestimator->Update();
this->m_ConfidenceIndex = m_DoProbabilityEstimates;
}
template <class TInputValue, class TOutputValue>
typename LibSVMMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::PredictClassification(const InputSampleType & input) const
::PredictClassification(const InputSampleType & input, ConfidenceValueType *quality) const
{
TargetSampleType target;
......@@ -88,6 +90,31 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue>
target = m_SVMestimator->GetModel()->EvaluateLabel(mfunctor(input));
if (quality != NULL)
{
if (!this->m_ConfidenceIndex)
{
itkExceptionMacro("Confidence index not available for this classifier !");
}
typename SVMEstimatorType::ModelType::ProbabilitiesVectorType probaVector =
m_SVMestimator->GetModel()->EvaluateProbabilities(mfunctor(input));
double maxProb = 0.0;
double secProb = 0.0;
for (unsigned int i=0 ; i<probaVector.Size() ; ++i)
{
if (maxProb < probaVector[i])
{
secProb = maxProb;
maxProb = probaVector[i];
}
else if (secProb < probaVector[i])
{
secProb = probaVector[i];
}
}
(*quality) = static_cast<ConfidenceValueType>(maxProb - secProb);
}
return target;
}
......@@ -105,6 +132,8 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue>
::Load(const std::string & filename, const std::string & itkNotUsed(name))
{
m_SVMestimator->GetModel()->LoadModel(filename.c_str());
this->m_ConfidenceIndex = m_SVMestimator->GetModel()->HasProbabilities();
}
template <class TInputValue, class TOutputValue>
......
......@@ -60,7 +60,7 @@ namespace otb
*
* \ingroup OTBSupervised
*/
template <class TInputValue, class TTargetValue>
template <class TInputValue, class TTargetValue, class TConfidenceValue = double >
class ITK_EXPORT MachineLearningModel
: public itk::Object
{
......@@ -87,6 +87,9 @@ public:
typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType;
//@}
/**\name Confidence value typedef */
typedef TConfidenceValue ConfidenceValueType;
/**\name Standard macros */
//@{
/** Run-time type information (and related methods). */
......@@ -97,7 +100,7 @@ public:
void Train();
/** Predict values using the model */
TargetSampleType Predict(const InputSampleType& input) const;
TargetSampleType Predict(const InputSampleType& input, ConfidenceValueType *quality = NULL) const;
/** Classify all samples in InputListSample and fill TargetListSample with the associated label */