Commit 1c032854 authored by Jordi Inglada's avatar Jordi Inglada

Ajout de la classif SVM

parent 132a88aa
# Sources of non-templated classes.
FILE(GLOB OTBLearning_SRCS "*.cxx" )
ADD_LIBRARY(OTBLearning ${OTBLearning_SRCS})
TARGET_LINK_LIBRARIES (OTBLearning OTBCommon OTBIO ITKCommon ITKIO)
INSTALL_TARGETS(/lib/otb OTBLearning )
INSTALL_FILES(/include/otb/Learning "(\\.h|\\.txx)$")
/*=========================================================================
Program : OTB (ORFEO ToolBox)
Authors : CNES - J. Inglada
Language : C++
Date : 26 April 2006
Version :
Role : SVM Classifier
$Id$
=========================================================================*/
#ifndef __otbSVMClassifier_h
#define __otbSVMClassifier_h
#include "itkSampleClassifier.h"
#include "otbSVMModel.h"
#include "itkVectorImage.h"
namespace otb{
/** \class SVMClassifier
* \brief SVM-based classifier
*
* The first template argument is the type of the target sample data
* that this classifier will assign a class label for each measurement
* vector. The second one is the type of a membership value calculator
* for each. A membership calculator represents a specific knowledge about
* a class. In other words, it should tell us how "likely" is that a
* measurement vector (pattern) belong to the class. The third argument
* is the type of decision rule. The main role of a decision rule is
* comparing the return values of the membership calculators. However,
* decision rule can include some prior knowledge that can improve the
* result.
*
* Before you call the GenerateData method to start the classification process,
* you should plug in all necessary parts ( one or more membership
* calculators, a decision rule, and a target sample data). To plug in
* the decision rule, you use SetDecisionRule method, for the target sample
* data, SetSample method, and for the membership calculators, use
* AddMembershipCalculator method.
*
* As the method name indicates, you can have more than one membership
* calculator. One for each classes. The order you put the membership
* calculator becomes the class label for the class that is represented
* by the membership calculator.
*
* The classification result is stored in a vector of Subsample object.
* Each class has its own class sample (Subsample object) that has
* InstanceIdentifiers for all measurement vectors belong to the class.
* The InstanceIdentifiers come from the target sample data. Therefore,
* the Subsample objects act as separate class masks.
*
* <b>Recent API changes:</b>
* The static const macro to get the length of a measurement vector,
* \c MeasurementVectorSize has been removed to allow the length of a measurement
* vector to be specified at run time. Please use the function
* GetSample().GetMeasurementVectorSize() instead.
*
*/
template< class TSample >
class ITK_EXPORT SVMClassifier :
public itk::Statistics::SampleClassifier< TSample >
{
public:
/** Standard class typedef*/
typedef SVMClassifier Self;
typedef itk::Statistics::SampleClassifier< TSample > Superclass;
typedef itk::SmartPointer< Self > Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Standard macros */
itkTypeMacro(SVMClassifier, itk::Statistics::SampleClassifier);
itkNewMacro(Self) ;
/** Output type for GetClassSample method */
typedef itk::Statistics::MembershipSample< TSample > OutputType ;
/** typedefs from TSample object */
typedef typename TSample::MeasurementType MeasurementType ;
typedef typename TSample::MeasurementVectorType MeasurementVectorType ;
/** typedefs from Superclass */
typedef typename Superclass::MembershipFunctionPointerVector
MembershipFunctionPointerVector ;
typedef unsigned int ClassLabelType ;
typedef std::vector< ClassLabelType > ClassLabelVectorType ;
// /** Sets the target data that will be classified by this */
// void SetSample(const TSample* sample) ;
// /** Returns the target data */
// const TSample* GetSample() const;
// /** Sets the user given class labels for membership functions.
// * If users do not provide class labels for membership functions by calling
// * this function, then the index of the membership function vector for a
// * membership function will be used as class label of measurement vectors
// * belong to the membership function */
// void SetMembershipFunctionClassLabels( ClassLabelVectorType& labels) ;
// /** Gets the user given class labels */
// ClassLabelVectorType& GetMembershipFunctionClassLabels()
// { return m_ClassLabels ; }
// /** Returns the classification result */
OutputType* GetOutput() ;
/** Type definitions for the SVM Model. */
typedef itk::VectorImage< float, 2 > TInputImage;
typedef SVMModel< MeasurementVectorType > SVMModelType;
typedef typename SVMModelType::Pointer SVMModelPointer;
/** Set the model */
itkSetMacro(Model, SVMModelPointer);
/** Get the number of classes. */
itkGetConstReferenceMacro(Model, SVMModelPointer);
void Update() ;
protected:
SVMClassifier() ;
virtual ~SVMClassifier() {}
void PrintSelf(std::ostream& os, itk::Indent indent) const;
/** Starts the classification process */
void GenerateData() ;
void DoClassification() ;
private:
/** Target data sample pointer*/
const TSample* m_Sample ;
/** Output pointer (MembershipSample) */
typename OutputType::Pointer m_Output ;
/** User given class labels for membership functions */
ClassLabelVectorType m_ClassLabels ;
SVMModelPointer m_Model;
} ; // end of class
} // end of namespace otb
#ifndef ITK_MANUAL_INSTANTIATION
#include "otbSVMClassifier.txx"
#endif
#endif
/*=========================================================================
Program : OTB (ORFEO ToolBox)
Authors : CNES - J. Inglada
Language : C++
Date : 26 April 2006
Version :
Role : SVM Classifier
$Id$
=========================================================================*/
#ifndef __otbSVMClassifier_txx
#define __otbSVMClassifier_txx
#include "otbSVMClassifier.h"
namespace otb{
template< class TSample >
SVMClassifier< TSample >
::SVMClassifier()
{
m_Sample = 0 ;
m_Output = OutputType::New() ;
m_Model = SVMModelType::New();
}
template< class TSample >
void
SVMClassifier< TSample >
::PrintSelf(std::ostream& os, itk::Indent indent) const
{
Superclass::PrintSelf(os,indent);
// os << indent << "Sample: " ;
// if ( m_Sample != 0 )
// {
// os << m_Sample << std::endl;
// }
// else
// {
// os << "not set." << std::endl ;
// }
// os << indent << "Output: " << m_Output << std::endl;
}
// template< class TSample >
// void
// SVMClassifier< TSample >
// ::SetSample(const TSample* sample)
// {
// std::cout << "SVMClassifier::SetSample enter" << std::endl;
// if ( m_Sample != sample )
// {
// m_Sample = sample ;
// m_Output->SetSample(sample) ;
// }
// std::cout << "SVMClassifier::SetSample exit" << std::endl;
// }
// template< class TSample >
// const TSample*
// SVMClassifier< TSample >
// ::GetSample() const
// {
// return m_Sample ;
// }
// template< class TSample >
// void
// SVMClassifier< TSample >
// ::SetMembershipFunctionClassLabels(ClassLabelVectorType& labels)
// {
// m_ClassLabels = labels ;
// }
template< class TSample >
void
SVMClassifier< TSample >
::Update()
{
this->GenerateData();
}
template< class TSample >
void
SVMClassifier< TSample >
::GenerateData()
{
//std::cout << "Before Resize 0" << std::endl;
/* unsigned int i ;
typename TSample::ConstIterator iter = this->GetSample()->Begin() ;
typename TSample::ConstIterator end = this->GetSample()->End() ;
typename TSample::MeasurementVectorType measurements ;
*/
//std::cout << "Before Resize " << std::endl;
m_Output->SetSample(this->GetSample()) ;
//std::cout << "m_Output " << m_Output << std::endl;
m_Output->Resize( this->GetSample()->Size() ) ;
//std::cout << "Resize to " << this->GetSample()->Size() << std::endl;
//std::cout << "Resize to " << m_Output->GetSample()->Size() << std::endl;
//std::vector< double > discriminantScores ;
unsigned int numberOfClasses = this->GetNumberOfClasses() ;
//std::cout << "NbClass " << numberOfClasses << std::endl;
//discriminantScores.resize(numberOfClasses) ;
//unsigned int classLabel ;
m_Output->SetNumberOfClasses(numberOfClasses) ;
/*typename Superclass::DecisionRuleType::Pointer rule =
this->GetDecisionRule() ;*/
//std::cout << "Do Classif " << std::endl;
this->DoClassification();
//std::cout << "End of classif" << std::endl;
// if ( m_ClassLabels.size() != this->GetNumberOfMembershipFunctions() )
// {
// while (iter != end)
// {
// measurements = iter.GetMeasurementVector() ;
// for (i = 0 ; i < numberOfClasses ; i++)
// {
// discriminantScores[i] =
// (this->GetMembershipFunction(i))->Evaluate(measurements) ;
// }
// classLabel = rule->Evaluate(discriminantScores) ;
// m_Output->AddInstance(classLabel, iter.GetInstanceIdentifier()) ;
// ++iter ;
// }
// }
// else
// {
// while (iter != end)
// {
// measurements = iter.GetMeasurementVector() ;
// for (i = 0 ; i < numberOfClasses ; i++)
// {
// discriminantScores[i] =
// (this->GetMembershipFunction(i))->Evaluate(measurements) ;
// }
// classLabel = rule->Evaluate(discriminantScores) ;
// m_Output->AddInstance(m_ClassLabels[classLabel],
// iter.GetInstanceIdentifier()) ;
// ++iter ;
// }
// }
}
template< class TSample >
typename SVMClassifier< TSample >::OutputType*
SVMClassifier< TSample >
::GetOutput()
{
return m_Output ;
}
template< class TSample >
void
SVMClassifier< TSample >
::DoClassification()
{
typename TSample::ConstIterator iter = this->GetSample()->Begin() ;
typename TSample::ConstIterator end = this->GetSample()->End() ;
typename OutputType::ConstIterator iterO = m_Output->Begin() ;
typename OutputType::ConstIterator endO = m_Output->End() ;
typename TSample::MeasurementVectorType measurements ;
int numberOfComponentsPerSample = iter.GetMeasurementVector().Size() ;//this->GetSample().GetMeasurementVectorSize();//
int max_line_len = 1024;
struct svm_node *x;
int max_nr_attr = 64;
bool predict_probability = 1;
const struct svm_model* model = m_Model->GetModel();
// char* line = (char *) malloc(max_line_len*sizeof(char));
// x = (struct svm_node *) malloc(max_nr_attr*sizeof(struct
// svm_node));
m_Model->AllocateProblem(1, numberOfComponentsPerSample);
x = m_Model->GetXSpace();
//std::cout << "XSpace Allocated" << std::endl;
if(svm_check_probability_model(model)==0)
{
throw itk::ExceptionObject(__FILE__, __LINE__,
"Model does not support probabiliy estimates",ITK_LOCATION);
// predict_probability=0;
}
int correct = 0;
int total = 0;
double error = 0;
double sumv = 0, sumy = 0, sumvv = 0, sumyy = 0, sumvy = 0;
int svm_type=svm_get_svm_type(model);
//std::cout << "SVM Type = " << svm_type << std::endl;
int nr_class=svm_get_nr_class(model);
//std::cout << "SVM nr_class = " << nr_class << std::endl;
int *labels=(int *) malloc(nr_class*sizeof(int));
double *prob_estimates=NULL;
int j;
if(predict_probability)
{
if (svm_type==NU_SVR || svm_type==EPSILON_SVR)
printf("Prob. model for test data: target value = predicted value + z,\nz: Laplace distribution e^(-|z|/sigma)/(2sigma),sigma=%g\n",svm_get_svr_probability(model));
else
{
svm_get_labels(model,labels);
prob_estimates = (double *) malloc(nr_class*sizeof(double));
/*fprintf(output,"labels");
for(j=0;j<nr_class;j++)
fprintf(output," %d",labels[j]);
fprintf(output,"\n");*/
}
}
// while(1)
//std::cout << "Starting iterations " << std::endl;
while (iter != end && iterO != endO)
{
int i = 0;
int c;
double target,v;
/*if (fscanf(input,"%lf",&target)==EOF)
break;*/
// while(1)
// {
// if(i>=max_nr_attr-1) // need one more for index = -1
// {
// max_nr_attr *= 2;
// x = (struct svm_node *) realloc(x,max_nr_attr*sizeof(struct svm_node));
// }
// do {
// c = getc(input);
// if(c=='\n' || c==EOF) goto out2;
// } while(isspace(c));
// ungetc(c,input);
// fscanf(input,"%d:%lf",&x[i].index,&x[i].value);
// ++i;
// }
// out2:
// x[i++].index = -1;
measurements = iter.GetMeasurementVector() ;
//std::cout << "Loop on components " << svm_type << std::endl;
for(i=0; i<numberOfComponentsPerSample; i++)
{
//std::cout << i << " " << measurements[i] << std::endl;
//std::cout << "Index "<< x[i].index << std::endl;
//std::cout << "Value "<< x[i].value << std::endl;
x[i].index = i+1 ;
x[i].value = measurements[i];
//std::cout << "Index "<< x[i].index << std::endl;
//std::cout << "Value "<< x[i].value << std::endl;
//std::cout << "-------------------" << std::endl;
}
//std::cout << "Starting prediction" << std::endl;
if (predict_probability && (svm_type==C_SVC || svm_type==NU_SVC))
{
//std::cout << "With predict" << std::endl;
v = svm_predict_probability(model,x,prob_estimates);
//std::cout << "Value : " << v << std::endl;
/*fprintf(output,"%g ",v);
for(j=0;j<nr_class;j++)
fprintf(output,"%g ",prob_estimates[j]);
fprintf(output,"\n");*/
}
else
{
//std::cout << "Without predict" << std::endl;
v = svm_predict(model,x);
//std::cout << "Value : " << v << std::endl;
//fprintf(output,"%g\n",v);
}
unsigned int classLabel;
if(nr_class == 2)
classLabel = static_cast<unsigned int>(v+2);
else
classLabel = static_cast<unsigned int>(v);
// std::cout << "Add instance " << classLabel << std::endl;
//std::cout << "Add instance ident " << iterO.GetInstanceIdentifier() << std::endl;
m_Output->AddInstance(classLabel, iterO.GetInstanceIdentifier()) ;
//std::cout << "After add instance " << iterO.GetClassLabel() << std::endl;
++iter;
++iterO;
}
//std::cout << "End of iterations " << std::endl;
if(predict_probability)
{
free(prob_estimates);
free(labels);
}
//std::cout << "End of iterations and free" << std::endl;
// free(x);
}
} // end of namespace itk
#endif
/*=========================================================================
Program : OTB (ORFEO ToolBox)
Authors : CNES - J. Inglada
Language : C++
Date : 26 April 2006
Version :
Role : SVM Image Model Estimator
$Id$
=========================================================================*/
#ifndef _otbSVMImageModelEstimator_h
#define _otbSVMImageModelEstimator_h
#include "itkImageModelEstimatorBase.h"
#include "itkImageRegionIterator.h"
#include "otbSVMModel.h"
#include "otbSVMMembershipFunction.h"
namespace otb
{
/** \class SVMImageModelEstimator
* \brief Class for SVM model estimation from images used for classification.
*
*
* The basic functionality of the SVMImageModelEstimator framework base class is to
* generate the models used in SVM classification. It requires input
* images and a training image to be provided by the user.
* This object supports data handling of multiband images. The object
* accepts the input image in vector format only, where each pixel is a
* vector and each element of the vector corresponds to an entry from
* 1 particular band of a multiband dataset. A single band image is treated
* as a vector image with a single element for every vector. The classified
* image is treated as a single band scalar image.
*
* EstimateModels() is a pure virtual function making this an abstract class.
* The template parameter is the type of a membership function the
* ImageModelEstimator populates.
*
* A membership function represents a specific knowledge about
* a class. In other words, it should tell us how "likely" is that a
* measurement vector (pattern) belong to the class.
*
* As the method name indicates, you can have more than one membership
* function. One for each classes. The order you put the membership
* calculator becomes the class label for the class that is represented
* by the membership calculator.
*
*
* \ingroup ClassificationFilters
*/
template <class TInputImage,
class TMembershipFunction,
class TTrainingImage>
class ITK_EXPORT SVMImageModelEstimator:
public itk::ImageModelEstimatorBase<TInputImage, TMembershipFunction>
{
public:
/** Standard class typedefs. */
typedef SVMImageModelEstimator Self;
typedef itk::ImageModelEstimatorBase<TInputImage,TMembershipFunction> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Method for creation through the object factory. */
itkNewMacro(Self);
/** Run-time type information (and related methods). */
itkTypeMacro(SVMImageModelEstimator, itk::ImageModelEstimatorBase);
/** Type definition for the input image. */
typedef typename TInputImage::Pointer InputImagePointer;
/** Type definitions for the training image. */
typedef typename TTrainingImage::Pointer TrainingImagePointer;
/** Type definition for the vector associated with
* input image pixel type. */
typedef typename TInputImage::PixelType InputImagePixelType;
/** Type definitions for the vector holding
* training image pixel type. */
typedef typename TTrainingImage::PixelType TrainingImagePixelType;
/** Type definitions for the iterators for the input and training images. */
typedef
itk::ImageRegionIterator< TInputImage > InputImageIterator;
typedef
itk::ImageRegionIterator< TTrainingImage > TrainingImageIterator;
/** Type definitions for the membership function . */
typedef typename TMembershipFunction::Pointer MembershipFunctionPointer ;
/** Set the training image. */
itkSetMacro(TrainingImage,TrainingImagePointer);
/** Get the training image. */
itkGetMacro(TrainingImage,TrainingImagePointer);
/** Set the number of classes. */
itkSetMacro(NumberOfClasses, unsigned int);
/** Get the number of classes. */
itkGetConstReferenceMacro(NumberOfClasses, unsigned int);
/** Type definitions for the SVM Model. */
typedef itk::Vector< float, 3 > MeasurementVectorType ;
typedef SVMModel< MeasurementVectorType > SVMModelType;
typedef typename SVMModelType::Pointer SVMModelPointer;
/** Set the model */
itkSetMacro(Model, SVMModelPointer);
/** Get the number of classes. */
itkGetConstReferenceMacro(Model, SVMModelPointer);
/** Save the estimated model */
void SaveModel(const char* model_file_name);
protected:
SVMImageModelEstimator();
~SVMImageModelEstimator();
virtual void PrintSelf(std::ostream& os, itk::Indent indent) const;
/** Starts the image modelling process */
void GenerateData() ;
private:
SVMImageModelEstimator(const Self&); //purposely not implemented
void operator=(const Self&); //purposely not implemented
typedef vnl_matrix<double> MatrixType;
typedef vnl_vector<double> VectorType;
typedef typename TInputImage::SizeType InputImageSizeType;
/** Dimension of the each individual pixel vector. */
// itkStaticConstMacro(VectorDimension, unsigned int,
// InputImagePixelType::Dimension);
// typedef vnl_matrix_fixed<double,1,itkGetStaticConstMacro(VectorDimension)> ColumnVectorType;
unsigned int m_NumberOfClasses;
TrainingImagePointer m_TrainingImage;
/** A function that generates the
* model based on the training input data
* Achieves the goal of training the classifier. */
virtual void EstimateModels();
void BuildProblem();
SVMModelPointer m_Model;
struct svm_parameter param;
struct svm_problem prob;
//struct svm_model *model;
struct svm_node* x_space;
bool m_Done;
}; // class SVMImageModelEstimator
} // namespace otb
#ifndef ITK_MANUAL_INSTANTIATION
#include "otbSVMImageModelEstimator.txx"
#endif
#endif
/*=========================================================================
Program : OTB (ORFEO ToolBox)
Authors : CNES - J. Inglada
Language : C++
Date : 26 April 2006
Version :
Role : SVM Image Model Estimator
$Id$
=========================================================================*/
#ifndef _otbSVMImageModelEstimator_txx
#define _otbSVMImageModelEstimator_txx
#include "otbSVMImageModelEstimator.h"
#include "itkCommand.h"
#include "itkImageRegionConstIterator.h"
namespace otb
{
template<class TInputImage,
class TMembershipFunction,
class TTrainingImage>
SVMImageModelEstimator<TInputImage, TMembershipFunction, TTrainingImage>
::SVMImageModelEstimator(void):
m_NumberOfClasses( 0 )
{
// FIXME initialize SVMModel