Skip to content
Snippets Groups Projects
Commit d80796c0 authored by Guillaume Pasero's avatar Guillaume Pasero
Browse files

REFAC: work on CrossValidation in LibSVMMachineLearningModel

parent 1a4bd002
No related branches found
No related tags found
No related merge requests found
......@@ -212,6 +212,19 @@ public:
return static_cast<int>(m_Parameters.cache_size);
}
itkSetMacro(CVFolders, unsigned int);
itkGetMacro(CVFolders, unsigned int);
itkGetMacro(InitialCrossValidationAccuracy, double);
itkGetMacro(FinalCrossValidationAccuracy, double);
itkSetMacro(CoarseOptimizationNumberOfSteps, unsigned int);
itkGetMacro(CoarseOptimizationNumberOfSteps, unsigned int);
itkSetMacro(FineOptimizationNumberOfSteps, unsigned int);
itkGetMacro(FineOptimizationNumberOfSteps, unsigned int);
protected:
/** Constructor */
LibSVMMachineLearningModel();
......@@ -237,7 +250,7 @@ private:
void DeleteModel(void);
double CrossValidation(unsigned int nbFolders);
double CrossValidation(void);
void OptimizeParameters(void);
......@@ -253,6 +266,21 @@ private:
/** Do parameters optimization, default : false */
bool m_ParameterOptimization;
/** Number of Cross Validation folders*/
unsigned int m_CVFolders;
/** Initial cross validation accuracy */
double m_InitialCrossValidationAccuracy;
/** Final cross validationa accuracy */
double m_FinalCrossValidationAccuracy;
/** Number of steps for the coarse search */
unsigned int m_CoarseOptimizationNumberOfSteps;
/** Number of steps for the fine search */
unsigned int m_FineOptimizationNumberOfSteps;
/** Temporary array to store cross-validation results */
std::vector<double> m_TmpTarget;
......
......@@ -23,6 +23,8 @@
#include <fstream>
#include "otbLibSVMMachineLearningModel.h"
#include "otbSVMCrossValidationCostFunction.h"
#include "otbExhaustiveExponentialOptimizer.h"
#include "otbMacro.h"
namespace otb
......@@ -46,6 +48,11 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue>
this->SetCacheSize(40); // MB
this->m_ParameterOptimization = false;
this->m_IsRegressionSupported = true;
this->SetCVFolders(5);
this->m_InitialCrossValidationAccuracy = 0.;
this->m_FinalCrossValidationAccuracy = 0.;
this->m_CoarseOptimizationNumberOfSteps = 5;
this->m_FineOptimizationNumberOfSteps = 5;
this->m_Parameters.nr_weight = 0;
this->m_Parameters.weight_label = ITK_NULLPTR;
......@@ -375,7 +382,7 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue>
template <class TInputValue, class TOutputValue>
double
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::CrossValidation(unsigned int nbFolders)
::CrossValidation(void)
{
double accuracy = 0.0;
// Get the length of the problem
......@@ -384,7 +391,7 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue>
return accuracy;
// Do cross validation
svm_cross_validation(&m_Problem, &m_Parameters, nbFolders, &m_TmpTarget[0]);
svm_cross_validation(&m_Problem, &m_Parameters, m_CVFolders, &m_TmpTarget[0]);
// Evaluate accuracy
double total_correct = 0.;
......@@ -406,12 +413,9 @@ void
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::OptimizeParameters()
{
typedef SVMCrossValidationCostFunction<SVMModelType> CrossValidationFunctionType;
typedef SVMCrossValidationCostFunction<this> CrossValidationFunctionType;
typename CrossValidationFunctionType::Pointer crossValidationFunction = CrossValidationFunctionType::New();
crossValidationFunction->SetModel(this->GetModel());
crossValidationFunction->SetNumberOfCrossValidationFolders(m_NumberOfCrossValidationFolders);
crossValidationFunction->SetModel(this);
typename CrossValidationFunctionType::ParametersType initialParameters, coarseBestParameters, fineBestParameters;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment