Commit 7e46e6f0 authored by Cyrille Valladeau's avatar Cyrille Valladeau
Browse files

Ajout de svm_copy_model et des tests associes.

parent 3b8e4abf
......@@ -200,6 +200,9 @@ public:
/** Loads the model from a file */
void LoadModel(const char* model_file_name);
/** Copy the model */
Pointer GetCopy();
/** Set the SVM type to C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR */
void SetSVMType(int svmtype)
{
......@@ -407,6 +410,12 @@ public:
{
return m_Model->l;
}
/** Set number of support vectors */
void SetNumberOfSupportVectors(int l)
{
m_Model->l = l;
}
/** Return rho values */
double * GetRho(void)
{
......@@ -417,6 +426,11 @@ public:
{
return m_Model->SV;
}
/** Set the support vectors */
void SetSupportVectors(svm_node ** sv)
{
m_Model->SV = sv;
}
/** Return the alphas values (SV Coef) */
double ** GetAlpha (void)
{
......
......@@ -176,6 +176,18 @@ SVMModel<TInputPixel, TLabel>
}
}
template <class TInputPixel, class TLabel >
typename SVMModel<TInputPixel, TLabel>::Pointer
SVMModel<TInputPixel, TLabel>
::GetCopy()
{
Pointer modelCopy = New();
modelCopy->SetModel( svm_copy_model(m_Model) );
return modelCopy;
}
template <class TInputPixel, class TLabel >
void
SVMModel<TInputPixel, TLabel>
......
......@@ -217,8 +217,30 @@ ADD_TEST(leTvSVMImageClassificationFilter ${LEARNING_TESTS3}
${TEMP}/leSVMImageClassificationFilterOutput.tif)
ADD_TEST(leTuSVMModelGenericKernelsTest ${LEARNING_TESTS3}
otbSVMModelGenericKernelsTest
)
otbSVMModelGenericKernelsTest)
ADD_TEST(leTvSVMModelCopy ${LEARNING_TESTS3}
--compare-ascii ${TOL} ${INPUTDATA}/svm_model
${TEMP}/svmcopymodel_test
otbSVMModelCopyTest
${INPUTDATA}/svm_model
${TEMP}/svmcopymodel_test)
ADD_TEST(leTvSVMModelCopyGenericKernel ${LEARNING_TESTS3}
--compare-ascii ${TOL} ${INPUTDATA}/svm_model_generic
${TEMP}/svmcopygeneric_test
otbSVMModelCopyGenericKernelTest
${INPUTDATA}/svm_model_generic
${TEMP}/svmcopygeneric_test)
ADD_TEST(leTvSVMModelCopyComposedKernel ${LEARNING_TESTS3}
--compare-ascii ${TOL} ${INPUTDATA}/svm_model_composed
${TEMP}/svmcopycomposed_test
otbSVMModelCopyComposedKernelTest
${INPUTDATA}/svm_model_composed
${TEMP}/svmcopycomposed_test)
# A enrichir
SET(BasicLearning_SRCS1
......@@ -254,6 +276,9 @@ otbSEMClassifierNew.cxx
otbSVMImageClassificationFilterNew.cxx
otbSVMImageClassificationFilter.cxx
otbSVMModelGenericKernelsTest.cxx
otbSVMModelCopyTest.cxx
otbSVMModelCopyGenericKernelTest.cxx
otbSVMModelCopyComposedKernelTest.cxx
)
......
......@@ -34,4 +34,7 @@ REGISTER_TEST(otbSEMClassifierNew);
REGISTER_TEST(otbSVMImageClassificationFilterNew);
REGISTER_TEST(otbSVMImageClassificationFilter);
REGISTER_TEST(otbSVMModelGenericKernelsTest);
REGISTER_TEST(otbSVMModelCopyTest);
REGISTER_TEST(otbSVMModelCopyGenericKernelTest);
REGISTER_TEST(otbSVMModelCopyComposedKernelTest);
}
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#if defined(_MSC_VER)
#pragma warning ( disable : 4786 )
#endif
#include "itkExceptionObject.h"
#include <iostream>
#include "otbSVMKernels.h"
#include "otbSVMModel.h"
int otbSVMModelCopyComposedKernelTest( int argc, char* argv[] )
{
typedef unsigned char InputPixelType;
typedef unsigned char LabelPixelType;
typedef otb::SVMModel< InputPixelType, LabelPixelType > ModelType;
ModelType::Pointer svmModel = ModelType::New();
svmModel->LoadModel(argv[1]);
ModelType::Pointer svmModelCopy;
svmModelCopy = svmModel->GetCopy();
svmModelCopy->SaveModel(argv[2]);
return EXIT_SUCCESS;
}
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#if defined(_MSC_VER)
#pragma warning ( disable : 4786 )
#endif
#include "itkExceptionObject.h"
#include <iostream>
#include "otbSVMKernels.h"
#include "otbSVMModel.h"
int otbSVMModelCopyGenericKernelTest( int argc, char* argv[] )
{
typedef unsigned char InputPixelType;
typedef unsigned char LabelPixelType;
typedef otb::SVMModel< InputPixelType, LabelPixelType > ModelType;
// Create the model to be copied
ModelType::Pointer svmModel = ModelType::New();
otb::RBFKernelFunctor lFunctor;
svmModel->SetKernelFunctor(&lFunctor);
svmModel->LoadModel(argv[1]);
// Copy the model and print it
ModelType::Pointer svmModelCopy;
svmModelCopy = svmModel->GetCopy();
svmModelCopy->SaveModel(argv[2]);
return EXIT_SUCCESS;
}
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#if defined(_MSC_VER)
#pragma warning ( disable : 4786 )
#endif
#include "itkExceptionObject.h"
#include "otbImage.h"
#include <iostream>
#include "otbSVMModel.h"
int otbSVMModelCopyTest( int argc, char* argv[] )
{
typedef unsigned char InputPixelType;
typedef unsigned char LabelPixelType;
const unsigned int Dimension = 2;
typedef otb::Image< InputPixelType, Dimension > InputImageType;
typedef otb::SVMModel< InputPixelType, LabelPixelType > ModelType;
ModelType::Pointer svmModel = ModelType::New();
svmModel->LoadModel(argv[1]);
ModelType::Pointer svmModelCopy = svmModel->GetCopy();
svmModelCopy->SaveModel(argv[2]);
return EXIT_SUCCESS;
}
......@@ -2702,7 +2702,8 @@ int svm_save_model(const char *model_file_name, const svm_model *model)
int l = model->l;
fprintf(fp, "nr_class %d\n", nr_class);
fprintf(fp, "total_sv %d\n",l);
if(model->rho)
{
fprintf(fp, "rho");
for(int i=0;i<nr_class*(nr_class-1)/2;i++)
......@@ -2914,22 +2915,9 @@ svm_model *svm_load_model(const char *model_file_name, /*otb::*/GenericKernelFun
{
if( param.kernel_type == COMPOSED )
{
/*
if( generic_kernel_functor == NULL )
{
fprintf(stderr,"composed kernel functor is not initialized\n",cmd);
return NULL;
}
*/
//Load generic parameters
delete generic_kernel_functor;
/*
ComposedKernelFunctor * composed;
composed = new ComposedKernelFunctor;
int cr = composed->load_parameters(&fp);
param.kernel_composed = composed;
*/
param.kernel_composed = new ComposedKernelFunctor;
int cr = param.kernel_composed->load_parameters(&fp);
model->delete_composed = true;
......@@ -3029,6 +3017,146 @@ svm_model *svm_load_model(const char *model_file_name, /*otb::*/GenericKernelFun
return model;
}
//************************************************//
// OTB's modifications : fonction entiere ajoutee //
//************************************************//
svm_model *svm_copy_model( const svm_model *model )
{
const svm_parameter& param = model->param;
// instanciated the copy
svm_model *modelCpy = Malloc(svm_model,1);
svm_parameter& paramCpy = modelCpy->param;
modelCpy->rho = NULL;
modelCpy->probA = NULL;
modelCpy->probB = NULL;
modelCpy->label = NULL;
modelCpy->nSV = NULL;
modelCpy->delete_composed = false;
// SVM type copy
paramCpy.svm_type = param.svm_type;
// Kernel type copy
paramCpy.kernel_type = param.kernel_type;
// Param copy
paramCpy.degree = param.degree;
paramCpy.gamma = param.gamma;
paramCpy.coef0 = param.coef0;
// Model variable
int nr_class = model->nr_class;
int l = model->l;
modelCpy->nr_class = nr_class;
modelCpy->l = l;
if(model->rho)
{
int n = model->nr_class * (model->nr_class-1)/2;
modelCpy->rho = Malloc(double,n);
for(int i=0; i<n; i++)
modelCpy->rho[i] = model->rho[i];
}
if(model->label)
{
modelCpy->label = Malloc(int,nr_class);
for(int i=0; i<nr_class; i++)
modelCpy->label[i] = model->label[i];
}
if(model->probA)
{
int n = nr_class * (nr_class-1)/2;
modelCpy->probA = Malloc(double,n);
for(int i=0; i<n; i++)
modelCpy->probA[i] = model->probA[i];
}
if(model->probB)
{
int n = nr_class * (nr_class-1)/2;
modelCpy->probB = Malloc(double,n);
for(int i=0; i<n; i++)
modelCpy->probB[i] = model->probB[i];
}
if(model->nSV)
{
modelCpy->nSV = Malloc(int,nr_class);
for(int i=0;i<nr_class;i++)
modelCpy->nSV[i] = model->nSV[i];
}
// SV copy
const double * const *sv_coef = model->sv_coef;
const svm_node * const *SV = model->SV;
modelCpy->SV = Malloc(svm_node*,l);
svm_node **SVCpy = modelCpy->SV;
modelCpy->sv_coef = Malloc(double *,nr_class-1);
for(int i=0; i<nr_class-1; i++)
modelCpy->sv_coef[i] = Malloc(double,l);
// Compute the total number of SV elements.
unsigned int elements = 0;
for (int p=0; p<l; p++)
{
const svm_node *tempNode = SV[p];
while(tempNode->index != -1)
{
tempNode++;
elements++;
}
elements++;// for -1 values
}
if(l>0)
{
modelCpy->SV[0] = Malloc(svm_node,elements);
memcpy( modelCpy->SV[0],model->SV[0],sizeof(svm_node*)*elements);
}
svm_node *x_space = modelCpy->SV[0];
int j = 0;
for(int i=0; i<l; i++)
{
// sv_coef
for(int k=0; k<nr_class-1; k++)
modelCpy->sv_coef[k][i] = sv_coef[k][i];
// SV
modelCpy->SV[i] = &x_space[j];
const svm_node *p = SV[i];
svm_node *pCpy = SVCpy[i];
while(p->index != -1)
{
pCpy->index = p->index;
pCpy->value = p->value;
p++;
pCpy++;
j++;
}
pCpy->index = -1;
j++;
}
// Generic kernel copy
if( param.kernel_type == GENERIC )
{
// copy
paramCpy.kernel_generic = param.kernel_generic;
}
if( param.kernel_type == COMPOSED )
{
// copy
paramCpy.kernel_composed = param.kernel_composed;
}
return modelCpy;
}
void svm_destroy_model(svm_model* model)
{
if(model->free_sv && model->l > 0)
......@@ -3193,6 +3321,21 @@ int svm_check_probability_model(const svm_model *model)
namespace otb
{
*/
GenericKernelFunctorBase::GenericKernelFunctorBase(const GenericKernelFunctorBase& copy)
{
this->m_MapParameters = copy.m_MapParameters;
this->m_Name = copy.m_Name;
}
GenericKernelFunctorBase&
GenericKernelFunctorBase::operator=(const GenericKernelFunctorBase& copy)
{
this->m_MapParameters = copy.m_MapParameters;
this->m_Name = copy.m_Name;
return *this;
}
int
GenericKernelFunctorBase::
load_parameters(FILE ** pfile)
......@@ -3473,6 +3616,27 @@ add(const svm_node *px, const svm_node *py) const
// ****************************************************************************************
// ************************ ComposedKernelFunctor methods ********************/
// ****************************************************************************************
ComposedKernelFunctor::ComposedKernelFunctor(const ComposedKernelFunctor& copy)
{
//this->GenericKernelFunctorBase::GenericKernelFunctorBase(copy);
Superclass::GenericKernelFunctorBase(static_cast<GenericKernelFunctorBase>(copy));
this->m_KernelFunctorList = copy.m_KernelFunctorList;
this->m_HaveToBeDeletedList = copy.m_HaveToBeDeletedList;
this->m_PonderationList = copy.m_PonderationList;
}
ComposedKernelFunctor&
ComposedKernelFunctor::operator=(const ComposedKernelFunctor& copy)
{
//this->GenericKernelFunctorBase::operator=( copy );
Superclass::operator=( static_cast<GenericKernelFunctorBase>(copy) );
this->m_KernelFunctorList = copy.m_KernelFunctorList;
this->m_HaveToBeDeletedList = copy.m_HaveToBeDeletedList;
this->m_PonderationList = copy.m_PonderationList;
return *this;
}
void
ComposedKernelFunctor
::print_parameters(void)const
......
......@@ -108,6 +108,7 @@ int svm_check_probability_model(const struct svm_model *model);
//OTB's modifications
struct svm_model *svm_load_model(const char *model_file_name, /*otb::*/GenericKernelFunctorBase * generic_kernel_functor = NULL);
struct svm_model *svm_copy_model( const svm_model *model );
#ifdef __cplusplus
}
......@@ -123,11 +124,17 @@ class GenericKernelFunctorBase
{
public:
GenericKernelFunctorBase() : m_Name("FunctorName") {};
/** Recopy constructor */
GenericKernelFunctorBase( const GenericKernelFunctorBase& copy);
GenericKernelFunctorBase& operator=(const GenericKernelFunctorBase& copy);
virtual ~GenericKernelFunctorBase() {};
typedef std::map<std::string,std::string> MapType;
typedef MapType::iterator MapIterator;
typedef MapType::const_iterator MapConstIterator;
typedef GenericKernelFunctorBase Superclass;
typedef std::map<std::string,std::string> MapType;
typedef MapType::iterator MapIterator;
typedef MapType::const_iterator MapConstIterator;
template<class T>
T GetValue(const char *option) const
......@@ -264,7 +271,14 @@ public:
}
}
};
/** Recopy constructor */
ComposedKernelFunctor( const ComposedKernelFunctor& copy );
/* ComposedKernelFunctor( const ComposedKernelFunctor& c ) : GenericKernelFunctorBase(c), */
/* m_KernelFunctorList(c.m_KernelFunctorList) */
/* m_HaveToBeDeletedList(c.m_HaveToBeDeletedList) */
/* m_PonderationList(c.m_PonderationList) {}; */
ComposedKernelFunctor& operator=(const ComposedKernelFunctor& copy);
typedef std::vector<GenericKernelFunctorBase *> KernelListType;
virtual double operator()(const svm_node *x, const svm_node *y, const svm_parameter& param)const // = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment