From 938a62a9a466f667158a78c05eaba48a45d24bec Mon Sep 17 00:00:00 2001 From: Guillaume Pasero <guillaume.pasero@c-s.fr> Date: Mon, 10 Apr 2017 19:12:51 +0200 Subject: [PATCH] REFAC: support hyperplanes and proba output through the confidence index --- .../include/otbLibSVMMachineLearningModel.h | 25 ++++++ .../include/otbLibSVMMachineLearningModel.txx | 77 ++++++++++++------- 2 files changed, 74 insertions(+), 28 deletions(-) diff --git a/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.h b/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.h index 33e4ee9b91..a79416de81 100644 --- a/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.h @@ -48,6 +48,17 @@ public: typedef typename Superclass::TargetListSampleType TargetListSampleType; typedef typename Superclass::ConfidenceValueType ConfidenceValueType; + /** enum to choose the way confidence is computed + * CM_INDEX : compute the difference between highest and second highest probability + * CM_PROBA : returns probabilities for all classes + * The given pointer needs to store 'nbClass' values + * This mode requires that ConfidenceValueType is double + * CM_HYPER : returns hyperplanes distances* + * The given pointer needs to store 'nbClass * (nbClass-1) / 2' values + * This mode requires that ConfidenceValueType is double + */ + typedef enum {CM_INDEX,CM_PROBA,CM_HYPER} ConfidenceMode; + /** Run-time type information (and related methods). */ itkNewMacro(Self); itkTypeMacro(SVMMachineLearningModel, MachineLearningModel); @@ -225,6 +236,17 @@ public: itkSetMacro(FineOptimizationNumberOfSteps, unsigned int); itkGetMacro(FineOptimizationNumberOfSteps, unsigned int); + void SetConfidenceMode(unsigned int mode) + { + if (mode != m_ConfidenceMode) + { + m_ConfidenceMode = mode; + this->m_ConfidenceIndex = this->HasProbabilities(); + this->Modified(); + } + } + itkGetMacro(ConfidenceMode, unsigned int); + unsigned int GetNumberOfKernelParameters(); double CrossValidation(void); @@ -283,6 +305,9 @@ private: /** Number of steps for the fine search */ unsigned int m_FineOptimizationNumberOfSteps; + /** Output mode for confidence index (see enum ) */ + ConfidenceMode m_ConfidenceMode; + /** Temporary array to store cross-validation results */ std::vector<double> m_TmpTarget; diff --git a/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.txx index 6a10eb4934..dedfe1787e 100644 --- a/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.txx +++ b/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.txx @@ -53,6 +53,8 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue> this->m_FinalCrossValidationAccuracy = 0.; this->m_CoarseOptimizationNumberOfSteps = 5; this->m_FineOptimizationNumberOfSteps = 5; + this->m_ConfidenceMode = + LibSVMMachineLearningModel<TInputValue,TOutputValue>::CM_INDEX; this->m_Parameters.nr_weight = 0; this->m_Parameters.weight_label = ITK_NULLPTR; @@ -129,38 +131,49 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue> { itkExceptionMacro("Confidence index not available for this classifier !"); } - if (svm_type == C_SVC || svm_type == NU_SVC) + if (this->m_ConfidenceMode == CM_INDEX) { - // Eventually allocate space for probabilities - int nr_class = svm_get_nr_class(m_Model); - double *prob_estimates = new double[nr_class]; - // predict - target[0] = static_cast<TargetValueType>(svm_predict_probability(m_Model, x, prob_estimates)); - double maxProb = 0.0; - double secProb = 0.0; - for (unsigned int i=0 ; i< nr_class ; ++i) + if (svm_type == C_SVC || svm_type == NU_SVC) { - if (maxProb < prob_estimates[i]) + // Eventually allocate space for probabilities + int nr_class = svm_get_nr_class(m_Model); + double *prob_estimates = new double[nr_class]; + // predict + target[0] = static_cast<TargetValueType>(svm_predict_probability(m_Model, x, prob_estimates)); + double maxProb = 0.0; + double secProb = 0.0; + for (unsigned int i=0 ; i< nr_class ; ++i) { - secProb = maxProb; - maxProb = prob_estimates[i]; + if (maxProb < prob_estimates[i]) + { + secProb = maxProb; + maxProb = prob_estimates[i]; + } + else if (secProb < prob_estimates[i]) + { + secProb = prob_estimates[i]; + } } - else if (secProb < prob_estimates[i]) - { - secProb = prob_estimates[i]; - } - } - (*quality) = static_cast<ConfidenceValueType>(maxProb - secProb); + (*quality) = static_cast<ConfidenceValueType>(maxProb - secProb); - delete[] prob_estimates; + delete[] prob_estimates; + } + else + { + target[0] = static_cast<TargetValueType>(svm_predict(m_Model, x)); + // Prob. model for test data: target value = predicted value + z + // z: Laplace distribution e^(-|z|/sigma)/(2sigma) + // sigma is output as confidence index + (*quality) = svm_get_svr_probability(m_Model); + } + } + else if (this->m_ConfidenceMode == CM_PROBA) + { + target[0] = static_cast<TargetValueType>(svm_predict_probability(m_Model, x, quality)); } - else + else if (this->m_ConfidenceMode == CM_HYPER) { - target[0] = static_cast<TargetValueType>(svm_predict(m_Model, x)); - // Prob. model for test data: target value = predicted value + z - // z: Laplace distribution e^(-|z|/sigma)/(2sigma) - // sigma is output as confidence index - (*quality) = svm_get_svr_probability(m_Model); + target[0] = static_cast<TargetValueType>(svm_predict_values(m_Model, x, quality)); } } else @@ -252,10 +265,18 @@ bool LibSVMMachineLearningModel<TInputValue,TOutputValue> ::HasProbabilities(void) const { - bool ret = static_cast<bool>(svm_check_probability_model(m_Model)); - if (svm_get_svm_type(m_Model) == ONE_CLASS) + bool modelHasProba = static_cast<bool>(svm_check_probability_model(m_Model)); + int type = svm_get_svm_type(m_Model); + int cmMode = this->m_ConfidenceMode; + bool ret = false; + if (type == EPSILON_SVR || type == NU_SVR) + { + ret = (modelHasProba && cmMode == CM_INDEX); + } + else if (type == C_SVC || type == NU_SVC) { - ret = false; + ret = (modelHasProba && (cmMode == CM_INDEX || cmMode == CM_PROBA)) || + (cmMode == CM_HYPER); } return ret; } -- GitLab