Skip to content
Snippets Groups Projects
Commit 3d4f6bdd authored by Guillaume Pasero's avatar Guillaume Pasero
Browse files

RM: remove otb::SVMModel (has been ported into LibSVMMachineLearningModel)

parent cbd4bc0f
No related branches found
No related tags found
No related merge requests found
/*
* Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
*
* This file is part of Orfeo Toolbox
*
* https://www.orfeo-toolbox.org/
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef otbSVMModel_h
#define otbSVMModel_h
#include "itkObjectFactory.h"
#include "itkDataObject.h"
#include "itkVariableLengthVector.h"
#include "itkTimeProbe.h"
#include "svm.h"
namespace otb
{
/** \class SVMModel
* \brief Class for SVM models.
*
* \TODO update documentation
*
* The basic functionality of the SVMModel framework base class is to
* generate the models used in SVM classification. It requires input
* images and a training image to be provided by the user.
* This object supports data handling of multiband images. The object
* accepts the input image in vector format only, where each pixel is a
* vector and each element of the vector corresponds to an entry from
* 1 particular band of a multiband dataset. A single band image is treated
* as a vector image with a single element for every vector. The classified
* image is treated as a single band scalar image.
*
* A membership function represents a specific knowledge about
* a class. In other words, it should tell us how "likely" is that a
* measurement vector (pattern) belong to the class.
*
* As the method name indicates, you can have more than one membership
* function. One for each classes. The order you put the membership
* calculator becomes the class label for the class that is represented
* by the membership calculator.
*
*
* \ingroup ClassificationFilters
*
* \ingroup OTBSVMLearning
*/
template <class TValue, class TLabel>
class ITK_EXPORT SVMModel : public itk::DataObject
{
public:
/** Standard class typedefs. */
typedef SVMModel Self;
typedef itk::DataObject Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Value type */
typedef TValue ValueType;
/** Label Type */
typedef TLabel LabelType;
typedef std::vector<ValueType> MeasurementType;
typedef std::pair<MeasurementType, LabelType> SampleType;
typedef std::vector<SampleType> SamplesVectorType;
/** Cache vector type */
typedef std::vector<struct svm_node *> CacheVectorType;
/** Distances vector */
typedef itk::VariableLengthVector<double> ProbabilitiesVectorType;
typedef itk::VariableLengthVector<double> DistancesVectorType;
typedef struct svm_node * NodeCacheType;
/** Run-time type information (and related methods). */
itkNewMacro(Self);
itkTypeMacro(SVMModel, itk::DataObject);
/** Get the number of classes. */
unsigned int GetNumberOfClasses(void) const
{
if (m_Model) return (unsigned int) (m_Model->nr_class);
return 0;
}
/** Get the number of hyperplane. */
unsigned int GetNumberOfHyperplane(void) const
{
if (m_Model) return (unsigned int) (m_Model->nr_class * (m_Model->nr_class - 1) / 2);
return 0;
}
/** Gets the model */
const struct svm_model* GetModel()
{
return m_Model;
}
/** Gets the parameters */
struct svm_parameter& GetParameters()
{
return m_Parameters;
}
/** Gets the parameters */
const struct svm_parameter& GetParameters() const
{
return m_Parameters;
}
/** Saves the model to a file */
void SaveModel(const char* model_file_name) const;
void SaveModel(const std::string& model_file_name) const
{
//implemented in term of const char * version
this->SaveModel(model_file_name.c_str());
}
/** Loads the model from a file */
void LoadModel(const char* model_file_name);
void LoadModel(const std::string& model_file_name)
{
//implemented in term of const char * version
this->LoadModel(model_file_name.c_str());
}
/** Set the SVM type to C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR */
void SetSVMType(int svmtype)
{
m_Parameters.svm_type = svmtype;
m_ModelUpToDate = false;
this->Modified();
}
/** Get the SVM type (C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR) */
int GetSVMType(void) const
{
return m_Parameters.svm_type;
}
/** Set the kernel type to LINEAR, POLY, RBF, SIGMOID
linear: u'*v
polynomial: (gamma*u'*v + coef0)^degree
radial basis function: exp(-gamma*|u-v|^2)
sigmoid: tanh(gamma*u'*v + coef0)*/
void SetKernelType(int kerneltype)
{
m_Parameters.kernel_type = kerneltype;
m_ModelUpToDate = false;
this->Modified();
}
/** Get the kernel type */
int GetKernelType(void) const
{
return m_Parameters.kernel_type;
}
/** Set the degree of the polynomial kernel */
void SetPolynomialKernelDegree(int degree)
{
m_Parameters.degree = degree;
m_ModelUpToDate = false;
this->Modified();
}
/** Get the degree of the polynomial kernel */
int GetPolynomialKernelDegree(void) const
{
return m_Parameters.degree;
}
/** Set the gamma parameter for poly/rbf/sigmoid kernels */
virtual void SetKernelGamma(double gamma)
{
m_Parameters.gamma = gamma;
m_ModelUpToDate = false;
this->Modified();
}
/** Get the gamma parameter for poly/rbf/sigmoid kernels */
double GetKernelGamma(void) const
{
return m_Parameters.gamma;
}
/** Set the coef0 parameter for poly/sigmoid kernels */
void SetKernelCoef0(double coef0)
{
m_Parameters.coef0 = coef0;
m_ModelUpToDate = false;
this->Modified();
}
/** Get the coef0 parameter for poly/sigmoid kernels */
double GetKernelCoef0(void) const
{
//return m_Parameters.coef0;
return m_Parameters.coef0;
}
/** Set the Nu parameter for the training */
void SetNu(double nu)
{
m_Parameters.nu = nu;
m_ModelUpToDate = false;
this->Modified();
}
/** Set the Nu parameter for the training */
double GetNu(void) const
{
//return m_Parameters.nu;
return m_Parameters.nu;
}
/** Set the cache size in MB for the training */
void SetCacheSize(int cSize)
{
m_Parameters.cache_size = static_cast<double>(cSize);
m_ModelUpToDate = false;
this->Modified();
}
/** Get the cache size in MB for the training */
int GetCacheSize(void) const
{
return static_cast<int>(m_Parameters.cache_size);
}
/** Set the C parameter for the training for C_SVC, EPSILON_SVR and NU_SVR */
void SetC(double c)
{
m_Parameters.C = c;
m_ModelUpToDate = false;
this->Modified();
}
/** Get the C parameter for the training for C_SVC, EPSILON_SVR and NU_SVR */
double GetC(void) const
{
return m_Parameters.C;
}
/** Set the tolerance for the stopping criterion for the training*/
void SetEpsilon(double eps)
{
m_Parameters.eps = eps;
m_ModelUpToDate = false;
this->Modified();
}
/** Get the tolerance for the stopping criterion for the training*/
double GetEpsilon(void) const
{
return m_Parameters.eps;
}
/* Set the value of p for EPSILON_SVR */
void SetP(double p)
{
m_Parameters.p = p;
m_ModelUpToDate = false;
this->Modified();
}
/* Get the value of p for EPSILON_SVR */
double GetP(void) const
{
return m_Parameters.p;
}
/** Use the shrinking heuristics for the training */
void DoShrinking(bool s)
{
m_Parameters.shrinking = static_cast<int>(s);
m_ModelUpToDate = false;
this->Modified();
}
/** Get Use the shrinking heuristics for the training boolea */
bool GetDoShrinking(void) const
{
return static_cast<bool>(m_Parameters.shrinking);
}
/** Do probability estimates */
void DoProbabilityEstimates(bool prob)
{
m_Parameters.probability = static_cast<int>(prob);
m_ModelUpToDate = false;
this->Modified();
}
/** Get Do probability estimates boolean */
bool GetDoProbabilityEstimates(void) const
{
return static_cast<bool>(m_Parameters.probability);
}
/** Test if the model has probabilities */
bool HasProbabilities(void) const
{
return static_cast<bool>(svm_check_probability_model(m_Model));
}
/** Return number of support vectors */
int GetNumberOfSupportVectors(void) const
{
if (m_Model) return m_Model->l;
return 0;
}
/** Return rho values */
double * GetRho(void) const
{
if (m_Model) return m_Model->rho;
return ITK_NULLPTR;
}
/** Return the support vectors */
svm_node ** GetSupportVectors(void)
{
if (m_Model) return m_Model->SV;
return ITK_NULLPTR;
}
/** Set the support vectors and changes the l number of support vectors accordind to sv.*/
void SetSupportVectors(svm_node ** sv, int nbOfSupportVector);
/** Return the alphas values (SV Coef) */
double ** GetAlpha(void)
{
if (m_Model) return m_Model->sv_coef;
return ITK_NULLPTR;
}
/** Set the alphas values (SV Coef) */
void SetAlpha(double ** alpha, int nbOfSupportVector);
/** Return the labels lists */
int * GetLabels()
{
if (m_Model) return m_Model->label;
return ITK_NULLPTR;
}
/** Get the number of SV per classes */
int * GetNumberOfSVPerClasse()
{
if (m_Model) return m_Model->nSV;
return ITK_NULLPTR;
}
struct svm_problem& GetProblem()
{
return m_Problem;
}
/** Allocate the problem */
void BuildProblem();
/** Check consistency (potentially throws exception) */
void ConsistencyCheck();
/** Estimate the model */
void Train();
/** Cross validation (returns the accuracy) */
double CrossValidation(unsigned int nbFolders);
/** Predict (Please note that due to caching this method is not
* thread safe. If you want to run multiple concurrent instances of
* this method, please consider using the GetCopy() method to clone the
* model.)*/
LabelType EvaluateLabel(const MeasurementType& measure) const;
/** Evaluate hyperplan distances (Please note that due to caching this method is not
* thread safe. If you want to run multiple concurrent instances of
* this method, please consider using the GetCopy() method to clone the
* model.)**/
DistancesVectorType EvaluateHyperplanesDistances(const MeasurementType& measure) const;
/** Evaluate probabilities of each class. Returns a probability vector ordered
* by increasing class label value
* (Please note that due to caching this method is not thread safe.
* If you want to run multiple concurrent instances of
* this method, please consider using the GetCopy() method to clone the
* model.)**/
ProbabilitiesVectorType EvaluateProbabilities(const MeasurementType& measure) const;
/** Add a new sample to the list */
void AddSample(const MeasurementType& measure, const LabelType& label);
/** Clear all samples */
void ClearSamples();
/** Set the samples vector */
void SetSamples(const SamplesVectorType& samples);
/** Reset all the model, leaving it in the same state that just
* before constructor call */
void Reset();
protected:
/** Constructor */
SVMModel();
/** Destructor */
~SVMModel() ITK_OVERRIDE;
/** Display infos */
void PrintSelf(std::ostream& os, itk::Indent indent) const ITK_OVERRIDE;
/** Delete any allocated problem */
void DeleteProblem();
/** Delete any allocated model */
void DeleteModel();
/** Initializes default parameters */
void Initialize() ITK_OVERRIDE;
private:
SVMModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
/** Container to hold the SVM model itself */
struct svm_model* m_Model;
/** True if model is up-to-date */
mutable bool m_ModelUpToDate;
/** Container of the SVM problem */
struct svm_problem m_Problem;
/** Container of the SVM parameters */
struct svm_parameter m_Parameters;
/** true if problem is up-to-date */
bool m_ProblemUpToDate;
/** Contains the samples */
SamplesVectorType m_Samples;
}; // class SVMModel
} // namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbSVMModel.txx"
#endif
#endif
/*
* Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
*
* This file is part of Orfeo Toolbox
*
* https://www.orfeo-toolbox.org/
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef otbSVMModel_txx
#define otbSVMModel_txx
#include "otbSVMModel.h"
#include "otbMacro.h"
#include <algorithm>
namespace otb
{
// TODO: Check memory allocation in this class
template <class TValue, class TLabel>
SVMModel<TValue, TLabel>::SVMModel()
{
// Default parameters
this->SetSVMType(C_SVC);
this->SetKernelType(LINEAR);
this->SetPolynomialKernelDegree(3);
this->SetKernelGamma(1.); // 1/k
this->SetKernelCoef0(1.);
this->SetNu(0.5);
this->SetCacheSize(40);
this->SetC(1);
this->SetEpsilon(1e-3);
this->SetP(0.1);
this->DoShrinking(true);
this->DoProbabilityEstimates(false);
m_Parameters.nr_weight = 0;
m_Parameters.weight_label = ITK_NULLPTR;
m_Parameters.weight = ITK_NULLPTR;
m_Model = ITK_NULLPTR;
this->Initialize();
}
template <class TValue, class TLabel>
SVMModel<TValue, TLabel>::~SVMModel()
{
this->DeleteModel();
this->DeleteProblem();
}
template <class TValue, class TLabel>
void
SVMModel<TValue, TLabel>::Initialize()
{
// Initialize model
/*
if (!m_Model)
{
m_Model = new struct svm_model;
m_Model->l = 0;
m_Model->nr_class = 0;
m_Model->SV = NULL;
m_Model->sv_coef = NULL;
m_Model->rho = NULL;
m_Model->label = NULL;
m_Model->probA = NULL;
m_Model->probB = NULL;
m_Model->nSV = NULL;
m_ModelUpToDate = false;
} */
m_ModelUpToDate = false;
// Initialize problem
m_Problem.l = 0;
m_Problem.y = ITK_NULLPTR;
m_Problem.x = ITK_NULLPTR;
m_ProblemUpToDate = false;
}
template <class TValue, class TLabel>
void
SVMModel<TValue, TLabel>::Reset()
{
this->DeleteProblem();
this->DeleteModel();
// Clear samples
m_Samples.clear();
// Initialize values
this->Initialize();
}
template <class TValue, class TLabel>
void
SVMModel<TValue, TLabel>::DeleteModel()
{
if(m_Model)
{
svm_free_and_destroy_model(&m_Model);
}
m_Model = ITK_NULLPTR;
}
template <class TValue, class TLabel>
void
SVMModel<TValue, TLabel>::DeleteProblem()
{
// Deallocate any existing problem
if (m_Problem.y)
{
delete[] m_Problem.y;
m_Problem.y = ITK_NULLPTR;
}
if (m_Problem.x)
{
for (int i = 0; i < m_Problem.l; ++i)
{
if (m_Problem.x[i])
{
delete[] m_Problem.x[i];
}
}
delete[] m_Problem.x;
m_Problem.x = ITK_NULLPTR;
}
m_Problem.l = 0;
m_ProblemUpToDate = false;
}
template <class TValue, class TLabel>
void
SVMModel<TValue, TLabel>::AddSample(const MeasurementType& measure, const LabelType& label)
{
SampleType newSample(measure, label);
m_Samples.push_back(newSample);
m_ProblemUpToDate = false;
}
template <class TValue, class TLabel>
void
SVMModel<TValue, TLabel>::ClearSamples()
{
m_Samples.clear();
m_ProblemUpToDate = false;
}
template <class TValue, class TLabel>
void
SVMModel<TValue, TLabel>::SetSamples(const SamplesVectorType& samples)
{
m_Samples = samples;
m_ProblemUpToDate = false;
}
template <class TValue, class TLabel>
void
SVMModel<TValue, TLabel>::BuildProblem()
{
// Check if problem is up-to-date
if (m_ProblemUpToDate)
{
return;
}
// Get number of samples
int probl = m_Samples.size();
if (probl < 1)
{
itkExceptionMacro(<< "No samples, can not build SVM problem.");
}
otbMsgDebugMacro(<< "Rebuilding problem ...");
// Get the size of the samples
long int elements = m_Samples[0].first.size() + 1;
// Deallocate any previous problem
this->DeleteProblem();
// Allocate the problem
m_Problem.l = probl;
m_Problem.y = new double[probl];
m_Problem.x = new struct svm_node*[probl];
for (int i = 0; i < probl; ++i)
{
// Initialize labels to 0
m_Problem.y[i] = 0;
m_Problem.x[i] = new struct svm_node[elements];
// Initialize elements (value = 0; index = -1)
for (unsigned int j = 0; j < static_cast<unsigned int>(elements); ++j)
{
m_Problem.x[i][j].index = -1;
m_Problem.x[i][j].value = 0;
}
}
// Iterate on the samples
typename SamplesVectorType::const_iterator sIt = m_Samples.begin();
int sampleIndex = 0;
int maxElementIndex = 0;
while (sIt != m_Samples.end())
{
// Get the sample measurement and label
MeasurementType measure = sIt->first;
LabelType label = sIt->second;
// Set the label
m_Problem.y[sampleIndex] = label;
int elementIndex = 0;
// Populate the svm nodes
for (typename MeasurementType::const_iterator eIt = measure.begin();
eIt != measure.end() && elementIndex < elements; ++eIt, ++elementIndex)
{
m_Problem.x[sampleIndex][elementIndex].index = elementIndex + 1;
m_Problem.x[sampleIndex][elementIndex].value = (*eIt);
}
// Get the max index
if (elementIndex > maxElementIndex)
{
maxElementIndex = elementIndex;
}
++sampleIndex;
++sIt;
}
// Compute the kernel gamma from maxElementIndex if necessary
if (this->GetKernelGamma() == 0)
{
this->SetKernelGamma(1.0 / static_cast<double>(maxElementIndex));
}
// problem is up-to-date
m_ProblemUpToDate = true;
}
template <class TValue, class TLabel>
double
SVMModel<TValue, TLabel>::CrossValidation(unsigned int nbFolders)
{
// Build problem
this->BuildProblem();
// Check consistency
this->ConsistencyCheck();
// Get the length of the problem
int length = m_Problem.l;
// Temporary memory to store cross validation results
double *target = new double[length];
// Do cross validation
svm_cross_validation(&m_Problem, &m_Parameters, nbFolders, target);
// Evaluate accuracy
int i;
double total_correct = 0.;
for (i = 0; i < length; ++i)
{
if (target[i] == m_Problem.y[i])
{
++total_correct;
}
}
double accuracy = total_correct / length;
// Free temporary memory
delete[] target;
// return accuracy value
return accuracy;
}
template <class TValue, class TLabel>
void
SVMModel<TValue, TLabel>::ConsistencyCheck()
{
if (m_Parameters.svm_type == ONE_CLASS && this->GetDoProbabilityEstimates())
{
otbMsgDebugMacro(<< "Disabling SVM probability estimates for ONE_CLASS SVM type.");
this->DoProbabilityEstimates(false);
}
const char* error_msg = svm_check_parameter(&m_Problem, &m_Parameters);
if (error_msg)
{
throw itk::ExceptionObject(__FILE__, __LINE__, error_msg, ITK_LOCATION);
}
}
template <class TValue, class TLabel>
void
SVMModel<TValue, TLabel>::Train()
{
// If the model is already up-to-date, return
if (m_ModelUpToDate)
{
return;
}
// Build problem
this->BuildProblem();
// Check consistency
this->ConsistencyCheck();
// train the model
m_Model = svm_train(&m_Problem, &m_Parameters);
// Set the model as up-to-date
m_ModelUpToDate = true;
}
template <class TValue, class TLabel>
typename SVMModel<TValue, TLabel>::LabelType
SVMModel<TValue, TLabel>::EvaluateLabel(const MeasurementType& measure) const
{
// Check if model is up-to-date
if (!m_ModelUpToDate)
{
itkExceptionMacro(<< "Model is not up-to-date, can not predict label");
}
// Check probability prediction
bool predict_probability = svm_check_probability_model(m_Model);
if (this->GetSVMType() == ONE_CLASS)
{
predict_probability = 0;
}
// Get type and number of classes
int svm_type = svm_get_svm_type(m_Model);
int nr_class = svm_get_nr_class(m_Model);
// Allocate space for labels
double *prob_estimates = ITK_NULLPTR;
// Eventually allocate space for probabilities
if (predict_probability)
{
if (svm_type == NU_SVR || svm_type == EPSILON_SVR)
{
otbMsgDevMacro(
<<
"Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma), sigma="
<< svm_get_svr_probability(m_Model));
}
else
{
prob_estimates = new double[nr_class];
}
}
// Allocate nodes (/TODO if performances problems are related to too
// many allocations, a cache approach can be set)
struct svm_node * x = new struct svm_node[measure.size() + 1];
int valueIndex = 0;
// Fill the node
for (typename MeasurementType::const_iterator mIt = measure.begin(); mIt != measure.end(); ++mIt, ++valueIndex)
{
x[valueIndex].index = valueIndex + 1;
x[valueIndex].value = (*mIt);
}
// terminate node
x[measure.size()].index = -1;
x[measure.size()].value = 0;
LabelType label = 0;
if (predict_probability && (svm_type == C_SVC || svm_type == NU_SVC))
{
label = static_cast<LabelType>(svm_predict_probability(m_Model, x, prob_estimates));
}
else
{
label = static_cast<LabelType>(svm_predict(m_Model, x));
}
// Free allocated memory
delete[] x;
if (prob_estimates)
{
delete[] prob_estimates;
}
return label;
}
template <class TValue, class TLabel>
typename SVMModel<TValue, TLabel>::DistancesVectorType
SVMModel<TValue, TLabel>::EvaluateHyperplanesDistances(const MeasurementType& measure) const
{
// Check if model is up-to-date
if (!m_ModelUpToDate)
{
itkExceptionMacro(<< "Model is not up-to-date, can not predict label");
}
// Allocate nodes (/TODO if performances problems are related to too
// many allocations, a cache approach can be set)
struct svm_node * x = new struct svm_node[measure.size() + 1];
int valueIndex = 0;
// Fill the node
for (typename MeasurementType::const_iterator mIt = measure.begin(); mIt != measure.end(); ++mIt, ++valueIndex)
{
x[valueIndex].index = valueIndex + 1;
x[valueIndex].value = (*mIt);
}
// terminate node
x[measure.size()].index = -1;
x[measure.size()].value = 0;
// Initialize distances vector
DistancesVectorType distances(m_Model->nr_class*(m_Model->nr_class - 1) / 2);
// predict distances vector
svm_predict_values(m_Model, x, (double*) (distances.GetDataPointer()));
// Free allocated memory
delete[] x;
return (distances);
}
template <class TValue, class TLabel>
typename SVMModel<TValue, TLabel>::ProbabilitiesVectorType
SVMModel<TValue, TLabel>::EvaluateProbabilities(const MeasurementType& measure) const
{
// Check if model is up-to-date
if (!m_ModelUpToDate)
{
itkExceptionMacro(<< "Model is not up-to-date, can not predict probabilities");
}
if (!this->HasProbabilities())
{
throw itk::ExceptionObject(__FILE__, __LINE__,
"Model does not support probability estimates", ITK_LOCATION);
}
// Get number of classes
int nr_class = svm_get_nr_class(m_Model);
// Allocate nodes (/TODO if performances problems are related to too
// many allocations, a cache approach can be set)
struct svm_node * x = new struct svm_node[measure.size() + 1];
int valueIndex = 0;
// Fill the node
for (typename MeasurementType::const_iterator mIt = measure.begin(); mIt != measure.end(); ++mIt, ++valueIndex)
{
x[valueIndex].index = valueIndex + 1;
x[valueIndex].value = (*mIt);
}
// Termination node
x[measure.size()].index = -1;
x[measure.size()].value = 0;
double* dec_values = new double[nr_class];
svm_predict_probability(m_Model, x, dec_values);
// Reorder values in increasing class label
int* labels = m_Model->label;
std::vector<int> orderedLabels(nr_class);
std::copy(labels, labels + nr_class, orderedLabels.begin());
std::sort(orderedLabels.begin(), orderedLabels.end());
ProbabilitiesVectorType probabilities(nr_class);
for (int i = 0; i < nr_class; ++i)
{
// svm_predict_probability is such that "dec_values[i]" corresponds to label "labels[i]"
std::vector<int>::iterator it = std::find(orderedLabels.begin(), orderedLabels.end(), labels[i]);
probabilities[it - orderedLabels.begin()] = dec_values[i];
}
// Free allocated memory
delete[] x;
delete[] dec_values;
return probabilities;
}
template <class TValue, class TLabel>
void
SVMModel<TValue, TLabel>::SaveModel(const char* model_file_name) const
{
if (svm_save_model(model_file_name, m_Model) != 0)
{
itkExceptionMacro(<< "Problem while saving SVM model "
<< std::string(model_file_name));
}
}
template <class TValue, class TLabel>
void
SVMModel<TValue, TLabel>::LoadModel(const char* model_file_name)
{
this->DeleteModel();
m_Model = svm_load_model(model_file_name);
if (m_Model == ITK_NULLPTR)
{
itkExceptionMacro(<< "Problem while loading SVM model "
<< std::string(model_file_name));
}
m_Parameters = m_Model->param;
m_ModelUpToDate = true;
}
template <class TValue, class TLabel>
void
SVMModel<TValue, TLabel>::PrintSelf(std::ostream& os, itk::Indent indent) const
{
Superclass::PrintSelf(os, indent);
}
template <class TValue, class TLabel>
void
SVMModel<TValue, TLabel>::SetSupportVectors(svm_node ** sv, int nbOfSupportVector)
{
if (!m_Model)
{
itkExceptionMacro( "Internal SVM model is empty!");
}
// erase the old SV
// delete just the first element, it destoyes the whole pointers (cf SV filling with x_space)
delete[] (m_Model->SV[0]);
for (int n = 0; n < m_Model->l; ++n)
{
m_Model->SV[n] = ITK_NULLPTR;
}
delete[] (m_Model->SV);
m_Model->SV = ITK_NULLPTR;
m_Model->SV = new struct svm_node*[m_Model->l];
// copy new SV values
svm_node **SV = m_Model->SV;
// Compute the total number of SV elements.
unsigned int elements = 0;
for (int p = 0; p < nbOfSupportVector; ++p)
{
//std::cout << p << " ";
const svm_node *tempNode = sv[p];
//std::cout << p << " ";
while (tempNode->index != -1)
{
tempNode++;
++elements;
}
++elements; // for -1 values
}
if (m_Model->l > 0)
{
SV[0] = new struct svm_node[elements];
memcpy(SV[0], sv[0], sizeof(svm_node*) * elements);
}
svm_node *x_space = SV[0];
int j = 0;
for (int i = 0; i < m_Model->l; ++i)
{
// SV
SV[i] = &x_space[j];
const svm_node *p = sv[i];
svm_node * pCpy = SV[i];
while (p->index != -1)
{
pCpy->index = p->index;
pCpy->value = p->value;
++p;
++pCpy;
++j;
}
pCpy->index = -1;
++j;
}
if (m_Model->l > 0)
{
delete[] SV[0];
}
}
template <class TValue, class TLabel>
void
SVMModel<TValue, TLabel>::SetAlpha(double ** alpha, int itkNotUsed(nbOfSupportVector))
{
if (!m_Model)
{
itkExceptionMacro( "Internal SVM model is empty!");
}
// Erase the old sv_coef
for (int i = 0; i < m_Model->nr_class - 1; ++i)
{
delete[] m_Model->sv_coef[i];
}
delete[] m_Model->sv_coef;
// copy new sv_coef values
m_Model->sv_coef = new double*[m_Model->nr_class - 1];
for (int i = 0; i < m_Model->nr_class - 1; ++i)
m_Model->sv_coef[i] = new double[m_Model->l];
for (int i = 0; i < m_Model->l; ++i)
{
// sv_coef
for (int k = 0; k < m_Model->nr_class - 1; ++k)
{
m_Model->sv_coef[k][i] = alpha[k][i];
}
}
}
} // end namespace otb
#endif
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment