Commit 8e55ea32 authored by Guillaume Pasero's avatar Guillaume Pasero

REFAC: integrate SVMModel directly into LibSVMMachineLearningModel (WIP)

parent 2ae42cdd
@@ -25,10 +25,7 @@
#include "itkFixedArray.h"
#include "otbMachineLearningModel.h"
-// SVM estimator
-#include "otbSVMSampleListModelEstimator.h"
-// Validation
-#include "otbSVMClassifier.h"
#include "svm.h"
namespace otb
{
@@ -51,13 +48,6 @@ public:
typedef typename Superclass::TargetListSampleType TargetListSampleType;
typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
-// LibSVM related typedefs
-typedef otb::Functor::VariableLengthVectorToMeasurementVectorFunctor<InputSampleType> MeasurementVectorFunctorType;
-typedef otb::SVMSampleListModelEstimator<InputListSampleType, TargetListSampleType, MeasurementVectorFunctorType>
-SVMEstimatorType;
-typedef otb::SVMClassifier<InputSampleType, TargetValueType> ClassifierType;
/** Run-time type information (and related methods). */
itkNewMacro(Self);
itkTypeMacro(LibSVMMachineLearningModel, MachineLearningModel);
@@ -80,41 +70,147 @@ public:
bool CanWriteFile(const std::string &) ITK_OVERRIDE;
//@}
//Setters/Getters to SVM model
-otbGetObjectMemberMacro(SVMestimator, SVMType, int);
-otbSetObjectMemberMacro(SVMestimator, SVMType, int);
#define otbSetSVMParameterMacro(name, alias, type) \
void Set##name (const type _arg) \
{ \
itkDebugMacro("setting " #name " to " << _arg); \
if ( this->m_Parameters.alias != _arg ) \
{ \
this->m_Parameters.alias = _arg; \
this->Modified(); \
} \
}
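/* For reference, one expansion of this macro (here for the C parameter,
 * otbSetSVMParameterMacro(C, C, double)) boils down to a standard ITK-style
 * setter that writes straight into the libsvm parameter structure:
 *
 * \code
 * void SetC (const double _arg)
 * {
 *   itkDebugMacro("setting C to " << _arg);
 *   if ( this->m_Parameters.C != _arg )
 *     {
 *     this->m_Parameters.C = _arg;
 *     this->Modified();
 *     }
 * }
 * \endcode
 */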
/** Set the SVM type to C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR */
otbSetSVMParameterMacro(SVMType, svm_type, int)
-otbGetObjectMemberMacro(SVMestimator, KernelType, int);
-otbSetObjectMemberMacro(SVMestimator, KernelType, int);
/** Get the SVM type (C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR) */
int GetSVMType(void) const
{
return m_Parameters.svm_type;
}
-otbGetObjectMemberMacro(SVMestimator, C, double);
-otbSetObjectMemberMacro(SVMestimator, C, double);
/** Set the kernel type to LINEAR, POLY, RBF, SIGMOID
linear: u'*v
polynomial: (gamma*u'*v + coef0)^degree
radial basis function: exp(-gamma*|u-v|^2)
sigmoid: tanh(gamma*u'*v + coef0)*/
otbSetSVMParameterMacro(KernelType, kernel_type, int)
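/* Illustrative sketch (not part of this commit): selecting a non-linear
 * kernel on a model instance; "model" is a hypothetical smart pointer to an
 * instantiation of this class. Setting the gamma to 0 lets BuildProblem()
 * replace it with 1/number_of_features at training time.
 *
 * \code
 * model->SetKernelType(RBF);    // exp(-gamma*|u-v|^2)
 * model->SetKernelGamma(0.0);   // 0 means "use 1/number_of_features"
 * model->SetC(10.0);
 * \endcode
 */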
-// TODO : we should harmonize this parameter name : ParameterOptimization -> ParametersOptimization
-bool GetParameterOptimization()
/** Get the kernel type */
int GetKernelType(void) const
{
-return this->m_SVMestimator->GetParametersOptimization();
return m_Parameters.kernel_type;
}
-void SetParameterOptimization(bool value)
/** Set the degree of the polynomial kernel */
otbSetSVMParameterMacro(PolynomialKernelDegree,degree,int)
/** Get the degree of the polynomial kernel */
int GetPolynomialKernelDegree(void) const
{
-this->m_SVMestimator->SetParametersOptimization(value);
-this->Modified();
return m_Parameters.degree;
}
/** Set the gamma parameter for poly/rbf/sigmoid kernels */
otbSetSVMParameterMacro(KernelGamma,gamma,double)
/** Get the gamma parameter for poly/rbf/sigmoid kernels */
double GetKernelGamma(void) const
{
return m_Parameters.gamma;
}
/** Set the coef0 parameter for poly/sigmoid kernels */
otbSetSVMParameterMacro(KernelCoef0,coef0,double)
/** Get the coef0 parameter for poly/sigmoid kernels */
double GetKernelCoef0(void) const
{
return m_Parameters.coef0;
}
/** Set the C parameter for the training for C_SVC, EPSILON_SVR and NU_SVR */
otbSetSVMParameterMacro(C,C,double)
/** Get the C parameter for the training for C_SVC, EPSILON_SVR and NU_SVR */
double GetC(void) const
{
return m_Parameters.C;
}
itkSetMacro(ParameterOptimization, bool);
itkGetMacro(ParameterOptimization, bool);
/** Do probability estimates */
void DoProbabilityEstimates(bool prob)
{
m_Parameters.probability = static_cast<int>(prob);
}
-otbGetObjectMemberMacro(SVMestimator, DoProbabilityEstimates, bool);
-void SetDoProbabilityEstimates(bool value)
/** Get the probability estimates flag */
bool GetDoProbabilityEstimates(void) const
{
-this->m_SVMestimator->DoProbabilityEstimates(value);
return static_cast<bool>(m_Parameters.probability);
}
-otbGetObjectMemberMacro(SVMestimator, Epsilon, double);
-otbSetObjectMemberMacro(SVMestimator, Epsilon, double);
/** Test if the model has probabilities */
bool HasProbabilities(void) const;
-otbGetObjectMemberMacro(SVMestimator, P, double);
-otbSetObjectMemberMacro(SVMestimator, P, double);
/** Set the tolerance for the stopping criterion for the training*/
otbSetSVMParameterMacro(Epsilon,eps,double)
-otbGetObjectMemberMacro(SVMestimator, Nu, double);
-otbSetObjectMemberMacro(SVMestimator, Nu, double);
/** Get the tolerance for the stopping criterion for the training*/
double GetEpsilon(void) const
{
return m_Parameters.eps;
}
/** Set the value of p for EPSILON_SVR */
otbSetSVMParameterMacro(P,p,double)
/** Get the value of p for EPSILON_SVR */
double GetP(void) const
{
return m_Parameters.p;
}
/** Set the Nu parameter for the training */
otbSetSVMParameterMacro(Nu,nu,double)
/** Get the Nu parameter for the training */
double GetNu(void) const
{
return m_Parameters.nu;
}
#undef otbSetSVMParameterMacro
/** Use the shrinking heuristics for the training */
void DoShrinking(bool s)
{
m_Parameters.shrinking = static_cast<int>(s);
this->Modified();
}
/** Get the flag controlling the use of the shrinking heuristics for the training */
bool GetDoShrinking(void) const
{
return static_cast<bool>(m_Parameters.shrinking);
}
/** Set the cache size in MB for the training */
void SetCacheSize(int cSize)
{
m_Parameters.cache_size = static_cast<double>(cSize);
this->Modified();
}
/** Get the cache size in MB for the training */
int GetCacheSize(void) const
{
return static_cast<int>(m_Parameters.cache_size);
}
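/* Putting the accessors above together, a typical training session could look
 * like the sketch below (illustrative only; "samples" and "labels" are
 * hypothetical, pre-filled list samples, and the sample setters are the usual
 * MachineLearningModel interface):
 *
 * \code
 * typedef otb::LibSVMMachineLearningModel<float, unsigned int> LibSVMType;
 * LibSVMType::Pointer model = LibSVMType::New();
 * model->SetSVMType(C_SVC);
 * model->SetKernelType(RBF);
 * model->SetC(100.0);
 * model->DoProbabilityEstimates(true);  // required to get a confidence index
 * model->SetInputListSample(samples);
 * model->SetTargetListSample(labels);
 * model->Train();
 * model->Save("svm_model.txt");
 * \endcode
 */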
protected:
/** Constructor */
@@ -133,7 +229,33 @@ private:
LibSVMMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
-typename SVMEstimatorType::Pointer m_SVMestimator;
void BuildProblem(void);
void ConsistencyCheck(void);
void DeleteProblem(void);
void DeleteModel(void);
double CrossValidation(unsigned int nbFolders);
void OptimizeParameters(void);
/** Container to hold the SVM model itself */
struct svm_model* m_Model;
/** Structure that stores training vectors */
struct svm_problem m_Problem;
/** Container of the SVM parameters */
struct svm_parameter m_Parameters;
/** Do parameters optimization, default : false */
bool m_ParameterOptimization;
/** Temporary array to store cross-validation results */
std::vector<double> m_TmpTarget;
};
} // end namespace otb
@@ -23,6 +23,7 @@
#include <fstream>
#include "otbLibSVMMachineLearningModel.h"
#include "otbMacro.h"
namespace otb
{
@@ -31,22 +32,38 @@ template <class TInputValue, class TOutputValue>
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::LibSVMMachineLearningModel()
{
-m_SVMestimator = SVMEstimatorType::New();
-m_SVMestimator->SetSVMType(C_SVC);
-m_SVMestimator->SetC(1.0);
-m_SVMestimator->SetKernelType(LINEAR);
-m_SVMestimator->SetParametersOptimization(false);
-m_SVMestimator->DoProbabilityEstimates(false);
-//m_SVMestimator->SetEpsilon(1e-6);
this->SetSVMType(C_SVC);
this->SetKernelType(LINEAR);
this->SetPolynomialKernelDegree(3);
this->SetKernelGamma(1.); // 1/k
this->SetKernelCoef0(1.);
this->SetNu(0.5);
this->SetC(1.0);
this->SetEpsilon(1e-3);
this->SetP(0.1);
this->DoProbabilityEstimates(false);
this->DoShrinking(true);
this->SetCacheSize(40); // MB
this->m_ParameterOptimization = false;
this->m_IsRegressionSupported = true;
-}
this->m_Parameters.nr_weight = 0;
this->m_Parameters.weight_label = ITK_NULLPTR;
this->m_Parameters.weight = ITK_NULLPTR;
this->m_Model = ITK_NULLPTR;
this->m_Problem.l = 0;
this->m_Problem.y = ITK_NULLPTR;
this->m_Problem.x = ITK_NULLPTR;
}
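/* Expressed directly on a raw libsvm structure, the defaults set above
 * correspond to the following sketch (for comparison only):
 *
 * \code
 * struct svm_parameter p;
 * p.svm_type = C_SVC;      p.kernel_type = LINEAR;
 * p.degree = 3;            p.gamma = 1.0;          p.coef0 = 1.0;
 * p.nu = 0.5;              p.C = 1.0;
 * p.eps = 1e-3;            p.p = 0.1;
 * p.probability = 0;       p.shrinking = 1;
 * p.cache_size = 40.0;     // MB
 * p.nr_weight = 0;         p.weight_label = ITK_NULLPTR;  p.weight = ITK_NULLPTR;
 * \endcode
 */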
template <class TInputValue, class TOutputValue>
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::~LibSVMMachineLearningModel()
{
-//delete m_SVMModel;
this->DeleteModel();
this->DeleteProblem();
}
/** Train the machine learning model */
@@ -55,19 +72,22 @@ void
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::Train()
{
-// Set up SVM's parameters
-// CvSVMParams params;
-// params.svm_type = m_SVMType;
-// params.kernel_type = m_KernelType;
-// params.term_crit = cvTermCriteria(m_TermCriteriaType, m_MaxIter, m_Epsilon);
this->DeleteProblem();
this->DeleteModel();
// Build problem
this->BuildProblem();
-// // Train the SVM
-m_SVMestimator->SetInputSampleList(this->GetInputListSample());
-m_SVMestimator->SetTrainingSampleList(this->GetTargetListSample());
// Check consistency
this->ConsistencyCheck();
-m_SVMestimator->Update();
// Compute cross-validation accuracy and optionally optimize the parameters
this->OptimizeParameters();
-this->m_ConfidenceIndex = this->GetDoProbabilityEstimates();
// train the model
m_Model = svm_train(&m_Problem, &m_Parameters);
this->m_ConfidenceIndex = this->HasProbabilities();
}
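/* Stripped of the OTB plumbing, Train() follows the plain libsvm protocol.
 * A minimal sketch, assuming an already filled svm_problem "prob" and
 * svm_parameter "param":
 *
 * \code
 * const char *msg = svm_check_parameter(&prob, &param);
 * if (msg != ITK_NULLPTR)
 *   {
 *   // parameters are inconsistent with the data; msg describes why
 *   }
 * struct svm_model *model = svm_train(&prob, &param);
 * // ... predict with the model, then release it
 * svm_free_and_destroy_model(&model);
 * \endcode
 */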
template <class TInputValue, class TOutputValue>
@@ -78,9 +98,23 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue>
{
TargetSampleType target;
-MeasurementVectorFunctorType mfunctor;
// Get type and number of classes
int svm_type = svm_get_svm_type(m_Model);
// Allocate nodes (TODO: if performance problems turn out to be related to too
// many allocations, a cache-based approach could be used)
struct svm_node * x = new struct svm_node[input.Size() + 1];
// Fill the node
for (unsigned int i = 0 ; i < input.Size() ; ++i)
{
x[i].index = i + 1;
x[i].value = input[i];
}
-target = m_SVMestimator->GetModel()->EvaluateLabel(mfunctor(input));
// terminate node
x[input.Size()].index = -1;
x[input.Size()].value = 0;
if (quality != ITK_NULLPTR)
{
@@ -88,25 +122,48 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue>
{
itkExceptionMacro("Confidence index not available for this classifier !");
}
-typename SVMEstimatorType::ModelType::ProbabilitiesVectorType probaVector =
-m_SVMestimator->GetModel()->EvaluateProbabilities(mfunctor(input));
-double maxProb = 0.0;
-double secProb = 0.0;
-for (unsigned int i=0 ; i<probaVector.Size() ; ++i)
if (svm_type == C_SVC || svm_type == NU_SVC)
{
-if (maxProb < probaVector[i])
// Allocate space for the per-class probability estimates
int nr_class = svm_get_nr_class(m_Model);
double *prob_estimates = new double[nr_class];
// predict
target[0] = static_cast<TargetValueType>(svm_predict_probability(m_Model, x, prob_estimates));
double maxProb = 0.0;
double secProb = 0.0;
for (int i = 0 ; i < nr_class ; ++i)
{
if (maxProb < prob_estimates[i])
{
secProb = maxProb;
-maxProb = probaVector[i];
}
-else if (secProb < probaVector[i])
-{
-secProb = probaVector[i];
maxProb = prob_estimates[i];
}
else if (secProb < prob_estimates[i])
{
secProb = prob_estimates[i];
}
}
(*quality) = static_cast<ConfidenceValueType>(maxProb - secProb);
delete[] prob_estimates;
}
-(*quality) = static_cast<ConfidenceValueType>(maxProb - secProb);
else
{
target[0] = static_cast<TargetValueType>(svm_predict(m_Model, x));
// Prob. model for test data: target value = predicted value + z
// z: Laplace distribution e^(-|z|/sigma)/(2sigma)
// sigma is output as confidence index
(*quality) = svm_get_svr_probability(m_Model);
}
}
else
{
target[0] = static_cast<TargetValueType>(svm_predict(m_Model, x));
}
// Free allocated memory
delete[] x;
return target;
}
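/* Usage sketch (illustrative): "model" is a trained instance, "sample" an
 * InputSampleType of the expected size, "LibSVMType" a hypothetical
 * instantiation of this class, and Predict() the usual MachineLearningModel
 * entry point that ends up in the method above.
 *
 * \code
 * LibSVMType::ConfidenceValueType confidence = 0;
 * LibSVMType::TargetSampleType label = model->Predict(sample, &confidence);
 * // For C_SVC/NU_SVC with probability estimates enabled, "confidence" is the
 * // gap between the best and the second best class probability; values close
 * // to 0 flag ambiguous samples.
 * \endcode
 */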
@@ -115,7 +172,10 @@ void
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::Save(const std::string & filename, const std::string & itkNotUsed(name))
{
-m_SVMestimator->GetModel()->SaveModel(filename.c_str());
if (svm_save_model(filename.c_str(), m_Model) != 0)
{
itkExceptionMacro(<< "Problem while saving SVM model " << filename);
}
}
template <class TInputValue, class TOutputValue>
@@ -123,9 +183,15 @@ void
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::Load(const std::string & filename, const std::string & itkNotUsed(name))
{
-m_SVMestimator->GetModel()->LoadModel(filename.c_str());
this->DeleteModel();
m_Model = svm_load_model(filename.c_str());
if (m_Model == ITK_NULLPTR)
{
itkExceptionMacro(<< "Problem while loading SVM model " << filename);
}
m_Parameters = m_Model->param;
-this->m_ConfidenceIndex = m_SVMestimator->GetModel()->HasProbabilities();
this->m_ConfidenceIndex = this->HasProbabilities();
}
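/* Round-trip sketch (illustrative, assuming the default model name argument):
 * the file written by Save() is a plain libsvm text model, so it can also be
 * read back by the stock libsvm tools. "model" and "LibSVMType" are the same
 * hypothetical names used in the sketches above.
 *
 * \code
 * model->Save("svm_model.txt");
 * LibSVMType::Pointer reloaded = LibSVMType::New();
 * reloaded->Load("svm_model.txt");
 * \endcode
 */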
template <class TInputValue, class TOutputValue>
@@ -174,6 +240,300 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue>
Superclass::PrintSelf(os,indent);
}
template <class TInputValue, class TOutputValue>
bool
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::HasProbabilities(void) const
{
bool ret = static_cast<bool>(svm_check_probability_model(m_Model));
if (svm_get_svm_type(m_Model) == ONE_CLASS)
{
ret = false;
}
return ret;
}
template <class TInputValue, class TOutputValue>
void
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::BuildProblem()
{
// Get number of samples
typename InputListSampleType::Pointer samples = this->GetInputListSample();
typename TargetListSampleType::Pointer target = this->GetTargetListSample();
int probl = samples->Size();
if (probl < 1)
{
itkExceptionMacro(<< "No samples, can not build SVM problem.");
}
otbMsgDebugMacro(<< "Building problem ...");
// Get the size of the samples
long int elements = samples->GetMeasurementVectorSize();
// Allocate the problem
m_Problem.l = probl;
m_Problem.y = new double[probl];
m_Problem.x = new struct svm_node*[probl];
for (int i = 0; i < probl; ++i)
{
m_Problem.x[i] = new struct svm_node[elements+1];
}
// Iterate on the samples
typename InputListSampleType::ConstIterator sIt = samples->Begin();
typename TargetListSampleType::ConstIterator tIt = target->Begin();
int sampleIndex = 0;
while (sIt != samples->End() && tIt != target->End())
{
// Set the label
m_Problem.y[sampleIndex] = tIt.GetMeasurementVector()[0];
const InputSampleType &sample = sIt.GetMeasurementVector();
for (int k = 0 ; k < elements ; ++k)
{
m_Problem.x[sampleIndex][k].index = k + 1;
m_Problem.x[sampleIndex][k].value = sample[k];
}
// terminate node
m_Problem.x[sampleIndex][elements].index = -1;
m_Problem.x[sampleIndex][elements].value = 0;
++sampleIndex;
++sIt;
++tIt;
}
// Compute the kernel gamma from number of elements if necessary
if (this->GetKernelGamma() == 0)
{
this->SetKernelGamma(1.0 / static_cast<double>(elements));
}
// allocate buffer for cross validation
m_TmpTarget.resize(m_Problem.l);
}
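/* For reference, the dense layout produced above for a single 3-band sample
 * (2.0, 0.0, 5.5) is (sketch):
 *
 * \code
 * struct svm_node row[4];
 * row[0].index = 1;  row[0].value = 2.0;
 * row[1].index = 2;  row[1].value = 0.0;   // zeros are kept (dense storage)
 * row[2].index = 3;  row[2].value = 5.5;
 * row[3].index = -1; row[3].value = 0.0;   // libsvm end-of-row marker
 * \endcode
 */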
template <class TInputValue, class TOutputValue>
void
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::ConsistencyCheck()
{
if (this->GetSVMType() == ONE_CLASS && this->GetDoProbabilityEstimates())
{
otbMsgDebugMacro(<< "Disabling SVM probability estimates for ONE_CLASS SVM type.");
this->DoProbabilityEstimates(false);
}
const char* error_msg = svm_check_parameter(&m_Problem, &m_Parameters);
if (error_msg)
{
std::string err(error_msg);
itkExceptionMacro("SVM parameter check failed : " << err);
}
}
template <class TInputValue, class TOutputValue>
void
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::DeleteProblem()
{
if (m_Problem.y)
{
delete[] m_Problem.y;
m_Problem.y = ITK_NULLPTR;
}
if (m_Problem.x)
{
for (int i = 0; i < m_Problem.l; ++i)
{
if (m_Problem.x[i])
{
delete[] m_Problem.x[i];
}
}
delete[] m_Problem.x;
m_Problem.x = ITK_NULLPTR;
}
m_Problem.l = 0;
}
template <class TInputValue, class TOutputValue>
void
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::DeleteModel(void)
{
if(m_Model)
{
svm_free_and_destroy_model(&m_Model);
}
m_Model = ITK_NULLPTR;
}
template <class TInputValue, class TOutputValue>
double
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::CrossValidation(unsigned int nbFolders)
{
double accuracy = 0.0;
// Get the length of the problem
int length = m_Problem.l;
if (length == 0 || m_TmpTarget.size() < static_cast<size_t>(length))
return accuracy;
// Do cross validation
svm_cross_validation(&m_Problem, &m_Parameters, nbFolders, &m_TmpTarget[0]);
// Evaluate accuracy
double total_correct = 0.;
for (int i = 0; i < length; ++i)
{
if (m_TmpTarget[i] == m_Problem.y[i])
{
++total_correct;
}
}
accuracy = total_correct / length;
// return accuracy value
return accuracy;
}
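/* A minimal sketch of how this accuracy can drive a coarse (C, gamma) grid
 * search from inside a method of this class, in the spirit of
 * OptimizeParameters() below (loop bounds and the 5-fold choice are arbitrary
 * illustrative values; std::pow assumes <cmath> is included):
 *
 * \code
 * double bestAcc = 0.0, bestC = this->GetC(), bestGamma = this->GetKernelGamma();
 * for (int lc = -5; lc <= 15; lc += 2)        // C     = 2^lc
 *   {
 *   for (int lg = -15; lg <= 3; lg += 2)      // gamma = 2^lg
 *     {
 *     this->SetC(std::pow(2.0, lc));
 *     this->SetKernelGamma(std::pow(2.0, lg));
 *     double acc = this->CrossValidation(5);
 *     if (acc > bestAcc)
 *       {
 *       bestAcc = acc;
 *       bestC = this->GetC();
 *       bestGamma = this->GetKernelGamma();
 *       }
 *     }
 *   }
 * this->SetC(bestC);
 * this->SetKernelGamma(bestGamma);
 * \endcode
 */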
template <class TInputValue, class TOutputValue>
void
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::OptimizeParameters()
{
typedef SVMCrossValidationCostFunction<SVMModelType> CrossValidationFunctionType;
typename CrossValidationFunctionType::Pointer crossValidationFunction = CrossValidationFunctionType::New();
crossValidationFunction->SetModel(this->GetModel());
crossValidationFunction->SetNumberOfCrossValidationFolders(m_NumberOfCrossValidationFolders);
typename CrossValidationFunctionType::ParametersType initialParameters, coarseBestParameters, fineBestParameters;
switch (this->GetKernelType())