diff --git a/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.txx index e9f9346e3c284ce6a7be0a967afca816eeb072af..eabdc21d31b5e009c5b56541d57188002c4194f5 100644 --- a/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.txx +++ b/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.txx @@ -74,6 +74,8 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue> m_SVMestimator->SetTrainingSampleList(this->GetTargetListSample()); m_SVMestimator->Update(); + + this->m_ConfidenceIndex = m_DoProbabilityEstimates; } template <class TInputValue, class TOutputValue> @@ -96,6 +98,23 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue> { itkExceptionMacro("Confidence index not available for this classifier !"); } + typename SVMEstimatorType::ModelType::ProbabilitiesVectorType probaVector = + m_SVMestimator->GetModel()->EvaluateProbabilities(mfunctor(input)); + double maxProb = 0.0; + double secProb = 0.0; + for (unsigned int i=0 ; i<probaVector.Size() ; ++i) + { + if (maxProb < probaVector[i]) + { + secProb = maxProb; + maxProb = probaVector[i]; + } + else if (secProb < probaVector[i]) + { + secProb = probaVector[i]; + } + } + (*quality) = static_cast<ConfidenceValueType>(maxProb - secProb); } return target; @@ -115,6 +134,8 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue> ::Load(const std::string & filename, const std::string & itkNotUsed(name)) { m_SVMestimator->GetModel()->LoadModel(filename.c_str()); + + this->m_ConfidenceIndex = m_SVMestimator->GetModel()->HasProbabilities(); } template <class TInputValue, class TOutputValue>