Commit 1a4bd002 authored by Guillaume Pasero's avatar Guillaume Pasero

REFAC: adapt SVMCrossValidationCostFunction to new model

parent ab9f669c
......@@ -48,7 +48,7 @@ namespace otb
*
* \ingroup ClassificationFilters
*
* \ingroup OTBSVMLearning
* \ingroup OTBSupervised
*/
template <class TModel>
class ITK_EXPORT SVMCrossValidationCostFunction
......@@ -78,10 +78,6 @@ public:
itkSetObjectMacro(Model, SVMModelType);
itkGetObjectMacro(Model, SVMModelType);
/** Set/Get the number of cross validation folders */
itkSetMacro(NumberOfCrossValidationFolders, unsigned int);
itkGetMacro(NumberOfCrossValidationFolders, unsigned int);
/** Set/Get the derivative step */
itkSetMacro(DerivativeStep, ParametersValueType);
itkGetMacro(DerivativeStep, ParametersValueType);
......@@ -103,7 +99,7 @@ protected:
/** Update svm parameters struct according to the input parameters
*/
virtual void UpdateParameters(struct svm_parameter& svm_parameters, const ParametersType& parameters) const;
void UpdateParameters(const ParametersType& parameters) const;
private:
SVMCrossValidationCostFunction(const Self &); //purposely not implemented
......@@ -112,9 +108,6 @@ private:
/**Pointer to the SVM model to optimize */
SVMModelPointer m_Model;
/** Number of cross validation folders */
unsigned int m_NumberOfCrossValidationFolders;
/** Step used to compute the derivatives */
ParametersValueType m_DerivativeStep;
......
......@@ -27,7 +27,7 @@ namespace otb
{
template<class TModel>
SVMCrossValidationCostFunction<TModel>
::SVMCrossValidationCostFunction() : m_Model(), m_NumberOfCrossValidationFolders(10), m_DerivativeStep(0.001)
::SVMCrossValidationCostFunction() : m_Model(), m_DerivativeStep(0.001)
{}
template<class TModel>
SVMCrossValidationCostFunction<TModel>
......@@ -52,9 +52,9 @@ SVMCrossValidationCostFunction<TModel>
}
// Updates vm_parameters according to current parameters
this->UpdateParameters(m_Model->GetParameters(), parameters);
this->UpdateParameters(parameters);
return m_Model->CrossValidation(m_NumberOfCrossValidationFolders);
return m_Model->CrossValidation();
}
template<class TModel>
......@@ -129,31 +129,31 @@ SVMCrossValidationCostFunction<TModel>
{
case LINEAR:
// C
svm_parameters.C = parameters[0];
m_Model->SetC(parameters[0]);
break;
case POLY:
// C, gamma and coef0
svm_parameters.C = parameters[0];
svm_parameters.gamma = parameters[1];
svm_parameters.coef0 = parameters[2];
m_Model->SetC(parameters[0]);
m_Model->SetKernelGamma(parameters[1]);
m_Model->SetKernelCoef0(parameters[2]);
break;
case RBF:
// C and gamma
svm_parameters.C = parameters[0];
svm_parameters.gamma = parameters[1];
m_Model->SetC(parameters[0]);
m_Model->SetKernelGamma(parameters[1]);
break;
case SIGMOID:
// C, gamma and coef0
svm_parameters.C = parameters[0];
svm_parameters.gamma = parameters[1];
svm_parameters.coef0 = parameters[2];
m_Model->SetC(parameters[0]);
m_Model->SetKernelGamma(parameters[1]);
m_Model->SetKernelCoef0(parameters[2]);
break;
default:
svm_parameters.C = parameters[0];
m_Model->SetC(parameters[0]);
break;
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment