Skip to content
Snippets Groups Projects
Commit d5fc257d authored by Guillaume Pasero's avatar Guillaume Pasero
Browse files

REFAC: port SVM parameter optimization

parent e859ab59
No related branches found
No related tags found
No related merge requests found
#
# Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
#
# This file is part of Orfeo Toolbox
#
# https://www.orfeo-toolbox.org/
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
set(OTBSVMLearning_SRC
otbExhaustiveExponentialOptimizer.cxx
)
add_library(OTBSVMLearning ${OTBSVMLearning_SRC})
target_link_libraries(OTBSVMLearning
${OTBVectorDataBase_LIBRARIES}
${OTBImageBase_LIBRARIES}
${OTBLibSVM_LIBRARIES}
${OTBStreaming_LIBRARIES}
${OTBCommon_LIBRARIES}
)
otb_module_target(OTBSVMLearning)
......@@ -42,9 +42,9 @@ namespace otb
*
* \ingroup Numerics Optimizers
*
* \ingroup OTBSVMLearning
* \ingroup OTBSupervised
*/
class ITK_EXPORT ExhaustiveExponentialOptimizer :
class ITK_ABI_EXPORT ExhaustiveExponentialOptimizer :
public itk::SingleValuedNonLinearOptimizer
{
public:
......@@ -59,7 +59,7 @@ public:
itkNewMacro(Self);
/** Run-time type information (and related methods). */
itkTypeMacro(ExhaustiveExponentialOptimizer, SingleValuedNonLinearOptimizer);
itkTypeMacro(ExhaustiveExponentialOptimizer,itk::SingleValuedNonLinearOptimizer);
void StartOptimization(void) ITK_OVERRIDE;
......
......@@ -225,6 +225,10 @@ public:
itkSetMacro(FineOptimizationNumberOfSteps, unsigned int);
itkGetMacro(FineOptimizationNumberOfSteps, unsigned int);
unsigned int GetNumberOfKernelParameters();
double CrossValidation(void);
protected:
/** Constructor */
LibSVMMachineLearningModel();
......@@ -250,8 +254,6 @@ private:
void DeleteModel(void);
double CrossValidation(void);
void OptimizeParameters(void);
/** Container to hold the SVM model itself */
......
......@@ -379,6 +379,38 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue>
m_Model = ITK_NULLPTR;
}
template <class TInputValue, class TOutputValue>
unsigned int
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::GetNumberOfKernelParameters()
{
unsigned int nb = 1;
switch(this->GetKernelType())
{
case LINEAR:
// C
nb = 1;
break;
case POLY:
// C, gamma and coef0
nb = 3;
break;
case RBF:
// C and gamma
nb = 2;
break;
case SIGMOID:
// C, gamma and coef0
nb = 3;
break;
default:
// C
nb = 1;
break;
}
return nb;
}
template <class TInputValue, class TOutputValue>
double
LibSVMMachineLearningModel<TInputValue,TOutputValue>
......@@ -397,7 +429,7 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue>
double total_correct = 0.;
for (int i = 0; i < length; ++i)
{
if (target[i] == m_Problem.y[i])
if (m_TmpTarget[i] == m_Problem.y[i])
{
++total_correct;
}
......@@ -413,57 +445,25 @@ void
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::OptimizeParameters()
{
typedef SVMCrossValidationCostFunction<this> CrossValidationFunctionType;
typedef SVMCrossValidationCostFunction<LibSVMMachineLearningModel<TInputValue,TOutputValue> > CrossValidationFunctionType;
typename CrossValidationFunctionType::Pointer crossValidationFunction = CrossValidationFunctionType::New();
crossValidationFunction->SetModel(this);
typename CrossValidationFunctionType::ParametersType initialParameters, coarseBestParameters, fineBestParameters;
switch (this->GetKernelType())
{
case LINEAR:
// C
initialParameters.SetSize(1);
initialParameters[0] = this->GetC();
break;
case POLY:
// C, gamma and coef0
initialParameters.SetSize(3);
initialParameters[0] = this->GetC();
initialParameters[1] = this->GetKernelGamma();
initialParameters[2] = this->GetKernelCoef0();
break;
case RBF:
// C and gamma
initialParameters.SetSize(2);
initialParameters[0] = this->GetC();
initialParameters[1] = this->GetKernelGamma();
break;
case SIGMOID:
// C, gamma and coef0
initialParameters.SetSize(3);
initialParameters[0] = this->GetC();
initialParameters[1] = this->GetKernelGamma();
initialParameters[2] = this->GetKernelCoef0();
break;
default:
// Only C
initialParameters.SetSize(1);
initialParameters[0] = this->GetC();
break;
}
unsigned int nbParams = this->GetNumberOfKernelParameters();
initialParameters.SetSize(nbParams);
initialParameters[0] = this->GetC();
if (nbParams > 1) initialParameters[1] = this->GetKernelGamma();
if (nbParams > 2) initialParameters[2] = this->GetKernelCoef0();
m_InitialCrossValidationAccuracy = crossValidationFunction->GetValue(initialParameters);
m_FinalCrossValidationAccuracy = m_InitialCrossValidationAccuracy;
otbMsgDebugMacro(<< "Initial accuracy : " << m_InitialCrossValidationAccuracy
<< ", Parameters Optimization" << m_ParametersOptimization);
<< ", Parameters Optimization" << m_ParameterOptimization);
if (m_ParametersOptimization)
if (m_ParameterOptimization)
{
otbMsgDebugMacro(<< "Model parameters optimization");
typename ExhaustiveExponentialOptimizer::Pointer coarseOptimizer = ExhaustiveExponentialOptimizer::New();
......@@ -503,38 +503,9 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue>
m_FinalCrossValidationAccuracy = fineOptimizer->GetMaximumMetricValue();
switch (this->GetKernelType())
{
case LINEAR:
// C
this->SetC(fineBestParameters[0]);
break;
case POLY:
// C, gamma and coef0
this->SetC(fineBestParameters[0]);
this->SetKernelGamma(fineBestParameters[1]);
this->SetKernelCoef0(fineBestParameters[2]);
break;
case RBF:
// C and gamma
this->SetC(fineBestParameters[0]);
this->SetKernelGamma(fineBestParameters[1]);
break;
case SIGMOID:
// C, gamma and coef0
this->SetC(fineBestParameters[0]);
this->SetKernelGamma(fineBestParameters[1]);
this->SetKernelCoef0(fineBestParameters[2]);
break;
default:
// Only C
this->SetC(fineBestParameters[0]);
break;
}
this->SetC(fineBestParameters[0]);
if (nbParams > 1) this->SetKernelGamma(fineBestParameters[1]);
if (nbParams > 2) this->SetKernelCoef0(fineBestParameters[2]);
}
}
......
......@@ -21,7 +21,6 @@
#ifndef otbSVMCrossValidationCostFunction_h
#define otbSVMCrossValidationCostFunction_h
#include "otbSVMModel.h"
#include "itkSingleValuedCostFunction.h"
namespace otb
......
......@@ -22,6 +22,7 @@
#define otbSVMCrossValidationCostFunction_txx
#include "otbSVMCrossValidationCostFunction.h"
#include "otbMacro.h"
namespace otb
{
......@@ -95,67 +96,18 @@ SVMCrossValidationCostFunction<TModel>
{
itkExceptionMacro(<< "Model is null, can not evaluate number of parameters.");
}
switch (m_Model->GetKernelType())
{
case LINEAR:
// C
return 1;
case POLY:
// C, gamma and coef0
return 3;
case RBF:
// C and gamma
return 2;
case SIGMOID:
// C, gamma and coef0
return 3;
default:
// C
return 1;
}
return m_Model->GetNumberOfKernelParameters();
}
template<class TModel>
void
SVMCrossValidationCostFunction<TModel>
::UpdateParameters(struct svm_parameter& svm_parameters, const ParametersType& parameters) const
::UpdateParameters(const ParametersType& parameters) const
{
switch (m_Model->GetKernelType())
{
case LINEAR:
// C
m_Model->SetC(parameters[0]);
break;
case POLY:
// C, gamma and coef0
m_Model->SetC(parameters[0]);
m_Model->SetKernelGamma(parameters[1]);
m_Model->SetKernelCoef0(parameters[2]);
break;
case RBF:
// C and gamma
m_Model->SetC(parameters[0]);
m_Model->SetKernelGamma(parameters[1]);
break;
case SIGMOID:
// C, gamma and coef0
m_Model->SetC(parameters[0]);
m_Model->SetKernelGamma(parameters[1]);
m_Model->SetKernelCoef0(parameters[2]);
break;
default:
m_Model->SetC(parameters[0]);
break;
}
unsigned int nbParams = m_Model->GetNumberOfKernelParameters();
m_Model->SetC(parameters[0]);
if (nbParams > 1) m_Model->SetKernelGamma(parameters[1]);
if (nbParams > 2) m_Model->SetKernelCoef0(parameters[2]);
}
} // namespace otb
......
......@@ -20,6 +20,7 @@
set(OTBSupervised_SRC
otbMachineLearningModelFactoryBase.cxx
otbExhaustiveExponentialOptimizer.cxx
)
if(OTB_USE_OPENCV)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment