diff --git a/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.h b/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.h index eb73f1c3d78415c7913108c9a9deca3064a15c2d..15a8df438697b6bf5031eb73e1f3e4852f57af4d 100644 --- a/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.h @@ -130,7 +130,7 @@ protected: ~GradientBoostedTreeMachineLearningModel() override; /** Predict values using the model */ - TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr) const override; + TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr, ProbaSampleType *proba=nullptr) const override; /** PrintSelf method */ void PrintSelf(std::ostream& os, itk::Indent indent) const override; diff --git a/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.hxx b/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.hxx index b5479efefa25144ed93c73c5375701d7784b3efe..132dd7bc3d12530260a2cc8f0ec2328775a94cbf 100644 --- a/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.hxx +++ b/Modules/Learning/Supervised/include/otbGradientBoostedTreeMachineLearningModel.hxx @@ -103,7 +103,7 @@ GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue> itkExceptionMacro("Confidence index not available for this classifier !"); } } - if (proba != nullptr && !m_ProbaIndex) + if (proba != nullptr && !this->m_ProbaIndex) itkExceptionMacro("Probability per class not available for this classifier !"); return target; diff --git a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.hxx b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.hxx index dcc010c02f239329a49d953fd87fc625ace82cfa..a78f668fcbff653f4016f5cdb7a98fa28760951d 100644 --- a/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.hxx +++ b/Modules/Learning/Supervised/include/otbRandomForestsMachineLearningModel.hxx @@ -192,6 +192,10 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> else (*quality) = m_RFModel->predict_confidence(sample); } + + if (proba != nullptr && !this->m_ProbaIndex) + itkExceptionMacro("Probability per class not available for this classifier !"); + return target[0]; }