Skip to content
Snippets Groups Projects
Commit e2d58620 authored by Manuel Grizonnet's avatar Manuel Grizonnet
Browse files

COMP: forgot to push LibSVM classes

parent 4a32fae0
No related branches found
No related tags found
No related merge requests found
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbSVMMachineLearningModel_h
#define __otbSVMMachineLearningModel_h
#include "itkLightObject.h"
#include "itkVariableLengthVector.h"
#include "itkFixedArray.h"
#include "itkListSample.h"
#include "otbMachineLearningModel.h"
//include opencv
//#include <opencv.hpp> // opencv general include file
//#include <ml/ml.hpp> // opencv machine learning include file
// SVM estimator
#include "otbSVMSampleListModelEstimator.h"
// Validation
#include "otbSVMClassifier.h"
namespace otb
{
template <class TInputValue, class TTargetValue>
class ITK_EXPORT LibSVMMachineLearningModel
: public MachineLearningModel <TInputValue, TTargetValue>
{
public:
/** Standard class typedefs. */
typedef LibSVMMachineLearningModel Self;
typedef MachineLearningModel<TInputValue, TTargetValue> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
// Input related typedefs
typedef TInputValue InputValueType;
typedef itk::VariableLengthVector<InputValueType> InputSampleType;
typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType;
// Target related typedefs
typedef TTargetValue TargetValueType;
typedef itk::FixedArray<TargetValueType,1> TargetSampleType;
typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType;
// LibSVM related typedefs
typedef otb::Functor::VariableLengthVectorToMeasurementVectorFunctor<InputSampleType> MeasurementVectorFunctorType;
typedef otb::SVMSampleListModelEstimator<InputListSampleType, TargetListSampleType, MeasurementVectorFunctorType>
SVMEstimatorType;
typedef otb::SVMClassifier<InputSampleType, TargetValueType> ClassifierType;
/** Run-time type information (and related methods). */
itkNewMacro(Self);
itkTypeMacro(SVMMachineLearningModel, itk::MachineLearningModel);
/** Train the machine learning model */
virtual void Train();
/** Predict values using the model */
virtual TargetSampleType Predict(const InputSampleType & input) const;
/** Save the model to file */
virtual void Save(char * filename, const char * name=0);
/** Load the model from file */
virtual void Load(char * filename, const char * name=0);
/** Determine the file type. Returns true if this ImageIO can read the
* file specified. */
virtual bool CanReadFile(const char*);
/** Determine the file type. Returns true if this ImageIO can write the
* file specified. */
virtual bool CanWriteFile(const char*);
//Setters/Getters to SVM model
// itkGetMacro(SVMType, int);
// itkSetMacro(SVMType, int);
itkGetMacro(KernelType, int);
itkSetMacro(KernelType, int);
itkGetMacro(C, float);
itkSetMacro(C, float);
itkGetMacro(ParameterOptimization, bool);
itkSetMacro(ParameterOptimization, bool);
// itkGetMacro(Epsilon, int);
// itkSetMacro(Epsilon, int);
protected:
/** Constructor */
LibSVMMachineLearningModel();
/** Destructor */
virtual ~LibSVMMachineLearningModel();
/** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const;
private:
LibSVMMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
int m_KernelType;
float m_C;
bool m_ParameterOptimization;
typename SVMEstimatorType::Pointer m_SVMestimator;
};
} // end namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbLibSVMMachineLearningModel.txx"
#endif
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbLibSVMMachineLearningModel_txx
#define __otbLibSVMMachineLearningModel_txx
#include "otbLibSVMMachineLearningModel.h"
//#include "otbOpenCVUtils.h"
// SVM estimator
//#include "otbSVMSampleListModelEstimator.h"
namespace otb
{
template <class TInputValue, class TOutputValue>
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::LibSVMMachineLearningModel()
{
// m_SVMModel = new CvSVM;
// m_SVMType = CvSVM::C_SVC;
m_KernelType = LINEAR;
// m_TermCriteriaType = CV_TERMCRIT_ITER;
m_C = 1.0;
// m_Epsilon = 1e-6;
m_ParameterOptimization = false;
m_SVMestimator = SVMEstimatorType::New();
}
template <class TInputValue, class TOutputValue>
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::~LibSVMMachineLearningModel()
{
//delete m_SVMModel;
}
/** Train the machine learning model */
template <class TInputValue, class TOutputValue>
void
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::Train()
{
// Set up SVM's parameters
// CvSVMParams params;
// params.svm_type = m_SVMType;
// params.kernel_type = m_KernelType;
// params.term_crit = cvTermCriteria(m_TermCriteriaType, m_MaxIter, m_Epsilon);
// // Train the SVM
m_SVMestimator->SetC(m_C);
m_SVMestimator->SetKernelType(m_KernelType);
m_SVMestimator->SetParametersOptimization(m_ParameterOptimization);
m_SVMestimator->SetInputSampleList(this->GetInputListSample());
m_SVMestimator->SetTrainingSampleList(this->GetTargetListSample());
m_SVMestimator->Update();
}
template <class TInputValue, class TOutputValue>
typename LibSVMMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::Predict(const InputSampleType & input) const
{
TargetSampleType target;
otbMsgDevMacro(<< "Starting iterations ");
MeasurementVectorFunctorType mfunctor;
target = m_SVMestimator->GetModel()->EvaluateLabel(mfunctor(input));
return target;
}
template <class TInputValue, class TOutputValue>
void
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::Save(char * filename, const char * name)
{
m_SVMestimator->GetModel()->SaveModel(filename);
}
template <class TInputValue, class TOutputValue>
void
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::Load(char * filename, const char * name)
{
m_SVMestimator->GetModel()->LoadModel(filename);
}
template <class TInputValue, class TOutputValue>
bool
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::CanReadFile(const char * file)
{
return false;
}
template <class TInputValue, class TOutputValue>
bool
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::CanWriteFile(const char * file)
{
return false;
}
template <class TInputValue, class TOutputValue>
void
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::PrintSelf(std::ostream& os, itk::Indent indent) const
{
// Call superclass implementation
Superclass::PrintSelf(os,indent);
}
} //end namespace otb
#endif
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment