Skip to content
Snippets Groups Projects
Commit 938a62a9 authored by Guillaume Pasero's avatar Guillaume Pasero
Browse files

REFAC: support hyperplanes and proba output through the confidence index

parent da152f58
No related branches found
No related tags found
No related merge requests found
......@@ -48,6 +48,17 @@ public:
typedef typename Superclass::TargetListSampleType TargetListSampleType;
typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
/** enum to choose the way confidence is computed
* CM_INDEX : compute the difference between highest and second highest probability
* CM_PROBA : returns probabilities for all classes
* The given pointer needs to store 'nbClass' values
* This mode requires that ConfidenceValueType is double
* CM_HYPER : returns hyperplanes distances*
* The given pointer needs to store 'nbClass * (nbClass-1) / 2' values
* This mode requires that ConfidenceValueType is double
*/
typedef enum {CM_INDEX,CM_PROBA,CM_HYPER} ConfidenceMode;
/** Run-time type information (and related methods). */
itkNewMacro(Self);
itkTypeMacro(SVMMachineLearningModel, MachineLearningModel);
......@@ -225,6 +236,17 @@ public:
itkSetMacro(FineOptimizationNumberOfSteps, unsigned int);
itkGetMacro(FineOptimizationNumberOfSteps, unsigned int);
void SetConfidenceMode(unsigned int mode)
{
if (mode != m_ConfidenceMode)
{
m_ConfidenceMode = mode;
this->m_ConfidenceIndex = this->HasProbabilities();
this->Modified();
}
}
itkGetMacro(ConfidenceMode, unsigned int);
unsigned int GetNumberOfKernelParameters();
double CrossValidation(void);
......@@ -283,6 +305,9 @@ private:
/** Number of steps for the fine search */
unsigned int m_FineOptimizationNumberOfSteps;
/** Output mode for confidence index (see enum ) */
ConfidenceMode m_ConfidenceMode;
/** Temporary array to store cross-validation results */
std::vector<double> m_TmpTarget;
......
......@@ -53,6 +53,8 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue>
this->m_FinalCrossValidationAccuracy = 0.;
this->m_CoarseOptimizationNumberOfSteps = 5;
this->m_FineOptimizationNumberOfSteps = 5;
this->m_ConfidenceMode =
LibSVMMachineLearningModel<TInputValue,TOutputValue>::CM_INDEX;
this->m_Parameters.nr_weight = 0;
this->m_Parameters.weight_label = ITK_NULLPTR;
......@@ -129,38 +131,49 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue>
{
itkExceptionMacro("Confidence index not available for this classifier !");
}
if (svm_type == C_SVC || svm_type == NU_SVC)
if (this->m_ConfidenceMode == CM_INDEX)
{
// Eventually allocate space for probabilities
int nr_class = svm_get_nr_class(m_Model);
double *prob_estimates = new double[nr_class];
// predict
target[0] = static_cast<TargetValueType>(svm_predict_probability(m_Model, x, prob_estimates));
double maxProb = 0.0;
double secProb = 0.0;
for (unsigned int i=0 ; i< nr_class ; ++i)
if (svm_type == C_SVC || svm_type == NU_SVC)
{
if (maxProb < prob_estimates[i])
// Eventually allocate space for probabilities
int nr_class = svm_get_nr_class(m_Model);
double *prob_estimates = new double[nr_class];
// predict
target[0] = static_cast<TargetValueType>(svm_predict_probability(m_Model, x, prob_estimates));
double maxProb = 0.0;
double secProb = 0.0;
for (unsigned int i=0 ; i< nr_class ; ++i)
{
secProb = maxProb;
maxProb = prob_estimates[i];
if (maxProb < prob_estimates[i])
{
secProb = maxProb;
maxProb = prob_estimates[i];
}
else if (secProb < prob_estimates[i])
{
secProb = prob_estimates[i];
}
}
else if (secProb < prob_estimates[i])
{
secProb = prob_estimates[i];
}
}
(*quality) = static_cast<ConfidenceValueType>(maxProb - secProb);
(*quality) = static_cast<ConfidenceValueType>(maxProb - secProb);
delete[] prob_estimates;
delete[] prob_estimates;
}
else
{
target[0] = static_cast<TargetValueType>(svm_predict(m_Model, x));
// Prob. model for test data: target value = predicted value + z
// z: Laplace distribution e^(-|z|/sigma)/(2sigma)
// sigma is output as confidence index
(*quality) = svm_get_svr_probability(m_Model);
}
}
else if (this->m_ConfidenceMode == CM_PROBA)
{
target[0] = static_cast<TargetValueType>(svm_predict_probability(m_Model, x, quality));
}
else
else if (this->m_ConfidenceMode == CM_HYPER)
{
target[0] = static_cast<TargetValueType>(svm_predict(m_Model, x));
// Prob. model for test data: target value = predicted value + z
// z: Laplace distribution e^(-|z|/sigma)/(2sigma)
// sigma is output as confidence index
(*quality) = svm_get_svr_probability(m_Model);
target[0] = static_cast<TargetValueType>(svm_predict_values(m_Model, x, quality));
}
}
else
......@@ -252,10 +265,18 @@ bool
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::HasProbabilities(void) const
{
bool ret = static_cast<bool>(svm_check_probability_model(m_Model));
if (svm_get_svm_type(m_Model) == ONE_CLASS)
bool modelHasProba = static_cast<bool>(svm_check_probability_model(m_Model));
int type = svm_get_svm_type(m_Model);
int cmMode = this->m_ConfidenceMode;
bool ret = false;
if (type == EPSILON_SVR || type == NU_SVR)
{
ret = (modelHasProba && cmMode == CM_INDEX);
}
else if (type == C_SVC || type == NU_SVC)
{
ret = false;
ret = (modelHasProba && (cmMode == CM_INDEX || cmMode == CM_PROBA)) ||
(cmMode == CM_HYPER);
}
return ret;
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment