Skip to content
Snippets Groups Projects
Commit 33d0fc6e authored by Arnaud Jaen's avatar Arnaud Jaen
Browse files

MRG: OpenCV adapters

parents dfe8db9d b2d74efe
No related branches found
No related tags found
No related merge requests found
Showing
with 1837 additions and 0 deletions
# Sources of non-templated classes.
FILE(GLOB OTBMachineLearning_SRCS "*.cxx" )
ADD_LIBRARY(OTBMachineLearning ${OTBMachineLearning_SRCS})
MESSAGE("LIBS: ${OpenCV_LIBS}")
TARGET_LINK_LIBRARIES (OTBMachineLearning OTBCommon OTBLearning ${OpenCV_LIBS})
IF(OTB_LIBRARY_PROPERTIES)
SET_TARGET_PROPERTIES(OTBMachineLearning PROPERTIES ${OTB_LIBRARY_PROPERTIES})
ENDIF(OTB_LIBRARY_PROPERTIES)
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbBoostMachineLearningModel_h
#define __otbBoostMachineLearningModel_h
#include "itkLightObject.h"
#include "itkVariableLengthVector.h"
#include "itkFixedArray.h"
#include "itkListSample.h"
#include "otbMachineLearningModel.h"
//include opencv
#include <cv.h> // opencv general include file
#include <ml.h> // opencv machine learning include file
namespace otb
{
template <class TInputValue, class TTargetValue>
class ITK_EXPORT BoostMachineLearningModel
: public MachineLearningModel <TInputValue, TTargetValue>
{
public:
/** Standard class typedefs. */
typedef BoostMachineLearningModel Self;
typedef MachineLearningModel<TInputValue, TTargetValue> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
// Input related typedefs
typedef TInputValue InputValueType;
typedef itk::VariableLengthVector<InputValueType> InputSampleType;
typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType;
// Target related typedefs
typedef TTargetValue TargetValueType;
typedef itk::FixedArray<TargetValueType,1> TargetSampleType;
typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType;
/** Run-time type information (and related methods). */
itkNewMacro(Self);
itkTypeMacro(BoostMachineLearningModel, itk::MachineLearningModel);
/** Train the machine learning model */
virtual void Train();
/** Predict values using the model */
virtual TargetSampleType Predict(const InputSampleType & input) const;
/** Save the model to file */
virtual void Save(char * filename, const char * name=0);
/** Load the model from file */
virtual bool Load(char * filename, const char * name=0);
/** Determine the file type. Returns true if this ImageIO can read the
* file specified. */
virtual bool CanReadFile(const char*);
/** Determine the file type. Returns true if this ImageIO can write the
* file specified. */
virtual bool CanWriteFile(const char*);
protected:
/** Constructor */
BoostMachineLearningModel();
/** Destructor */
virtual ~BoostMachineLearningModel();
/** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const;
private:
BoostMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
CvBoost * m_BoostModel;
};
} // end namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbBoostMachineLearningModel.txx"
#endif
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbBoostMachineLearningModel_txx
#define __otbBoostMachineLearningModel_txx
#include "otbBoostMachineLearningModel.h"
#include "otbOpenCVUtils.h"
namespace otb
{
template <class TInputValue, class TOutputValue>
BoostMachineLearningModel<TInputValue,TOutputValue>
::BoostMachineLearningModel()
{
m_BoostModel = new CvBoost;
}
template <class TInputValue, class TOutputValue>
BoostMachineLearningModel<TInputValue,TOutputValue>
::~BoostMachineLearningModel()
{
delete m_BoostModel;
}
/** Train the machine learning model */
template <class TInputValue, class TOutputValue>
void
BoostMachineLearningModel<TInputValue,TOutputValue>
::Train()
{
//convert listsample to opencv matrix
cv::Mat samples;
otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);
cv::Mat labels;
otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(),labels);
CvBoostParams params;
params.boost_type = CvBoost::DISCRETE;
params.split_criteria = CvBoost::DEFAULT;
//train the Boost model
cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U );
var_type.setTo(cv::Scalar(CV_VAR_NUMERICAL) ); // all inputs are numerical
var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_CATEGORICAL;
m_BoostModel->train(samples,CV_ROW_SAMPLE,labels,cv::Mat(),cv::Mat(),var_type,cv::Mat(),params);
}
template <class TInputValue, class TOutputValue>
typename BoostMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
BoostMachineLearningModel<TInputValue,TOutputValue>
::Predict(const InputSampleType & input) const
{
//convert listsample to Mat
cv::Mat sample;
otb::SampleToMat<InputSampleType>(input,sample);
cv::Mat missing = cv::Mat(1,input.Size(), CV_8U );
missing.setTo(0);
double result = m_BoostModel->predict(sample,missing);
TargetSampleType target;
target[0] = static_cast<TOutputValue>(result);
return target;
}
template <class TInputValue, class TOutputValue>
void
BoostMachineLearningModel<TInputValue,TOutputValue>
::Save(char * filename, const char * name)
{
m_BoostModel->save(filename, name);
}
template <class TInputValue, class TOutputValue>
void
BoostMachineLearningModel<TInputValue,TOutputValue>
::Load(char * filename, const char * name)
{
m_BoostModel->load(filename, name);
}
template <class TInputValue, class TOutputValue>
bool
BoostMachineLearningModel<TInputValue,TOutputValue>
::CanReadFile(const char * file)
{
return false;
}
template <class TInputValue, class TOutputValue>
bool
BoostMachineLearningModel<TInputValue,TOutputValue>
::CanWriteFile(const char * file)
{
return false;
}
template <class TInputValue, class TOutputValue>
void
BoostMachineLearningModel<TInputValue,TOutputValue>
::PrintSelf(std::ostream& os, itk::Indent indent) const
{
// Call superclass implementation
Superclass::PrintSelf(os,indent);
}
} //end namespace otb
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbKNearestNeighborsMachineLearningModel_h
#define __otbKNearestNeighborsMachineLearningModel_h
#include "itkLightObject.h"
#include "itkVariableLengthVector.h"
#include "itkFixedArray.h"
#include "itkListSample.h"
#include "otbMachineLearningModel.h"
//include opencv
#include <opencv.hpp> // opencv general include file
#include <ml/ml.hpp> // opencv machine learning include file
namespace otb
{
template <class TInputValue, class TTargetValue>
class ITK_EXPORT KNearestNeighborsMachineLearningModel
: public MachineLearningModel <TInputValue, TTargetValue>
{
public:
/** Standard class typedefs. */
typedef KNearestNeighborsMachineLearningModel Self;
typedef MachineLearningModel<TInputValue, TTargetValue> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
// Input related typedefs
typedef TInputValue InputValueType;
typedef itk::VariableLengthVector<InputValueType> InputSampleType;
typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType;
// Target related typedefs
typedef TTargetValue TargetValueType;
typedef itk::FixedArray<TargetValueType,1> TargetSampleType;
typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType;
/** Run-time type information (and related methods). */
itkNewMacro(Self);
itkTypeMacro(KNearestNeighborsMachineLearningModel, itk::MachineLearningModel);
/** Train the machine learning model */
virtual void Train();
/** Predict values using the model */
virtual TargetSampleType Predict(const InputSampleType & input) const;
/** Save the model to file */
virtual void Save(char * filename, const char * name=0);
/** Load the model from file */
virtual void Load(char * filename, const char * name=0);
/** Determine the file type. Returns true if this ImageIO can read the
* file specified. */
virtual bool CanReadFile(const char*);
/** Determine the file type. Returns true if this ImageIO can write the
* file specified. */
virtual bool CanWriteFile(const char*);
protected:
/** Constructor */
KNearestNeighborsMachineLearningModel();
/** Destructor */
virtual ~KNearestNeighborsMachineLearningModel();
/** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const;
private:
KNearestNeighborsMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
CvKNearest * m_KNearestModel;
int m_K;
};
} // end namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbKNearestNeighborsMachineLearningModel.txx"
#endif
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbKNearestNeighborsMachineLearningModel_txx
#define __otbKNearestNeighborsMachineLearningModel_txx
#include "otbKNearestNeighborsMachineLearningModel.h"
#include "otbOpenCVUtils.h"
namespace otb
{
template <class TInputValue, class TOutputValue>
KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
::KNearestNeighborsMachineLearningModel()
{
m_KNearestModel = new CvKNearest;
m_K = 10;
}
template <class TInputValue, class TOutputValue>
KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
::~KNearestNeighborsMachineLearningModel()
{
delete m_KNearestModel;
}
/** Train the machine learning model */
template <class TInputValue, class TOutputValue>
void
KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
::Train()
{
//convert listsample to opencv matrix
cv::Mat samples;
otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);
cv::Mat labels;
otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(),labels);
//train the KNN model
m_KNearestModel->train(samples,labels,cv::Mat(),false, m_K,false);
}
template <class TInputValue, class TOutputValue>
typename KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
::Predict(const InputSampleType & input) const
{
//convert listsample to Mat
cv::Mat sample;
otb::SampleToMat<InputSampleType>(input,sample);
double result = m_KNearestModel->find_nearest(sample,m_K);
TargetSampleType target;
target[0] = static_cast<TOutputValue>(result);
return target;
}
template <class TInputValue, class TOutputValue>
void
KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
::Save(char * filename, const char * name)
{
m_KNearestModel->save(filename, name);
}
template <class TInputValue, class TOutputValue>
void
KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
::Load(char * filename, const char * name)
{
m_KNearestModel->load(filename, name);
}
template <class TInputValue, class TOutputValue>
bool
KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
::CanReadFile(const char * file)
{
return false;
}
template <class TInputValue, class TOutputValue>
bool
KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
::CanWriteFile(const char * file)
{
return false;
}
template <class TInputValue, class TOutputValue>
void
KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
::PrintSelf(std::ostream& os, itk::Indent indent) const
{
// Call superclass implementation
Superclass::PrintSelf(os,indent);
}
} //end namespace otb
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbKNearestNeighborsMachineLearningModelFactory_h
#define __otbKNearestNeighborsMachineLearningModelFactory_h
#include "itkObjectFactoryBase.h"
#include "itkImageIOBase.h"
namespace otb
{
/** \class KNearestNeighborsMachineLearningModelFactory
* \brief Creation d'un instance d'un objet KNearestNeighborsMachineLearningModel utilisant les object factory.
*/
template <class TInputValue, class TTargetValue>
class ITK_EXPORT KNearestNeighborsMachineLearningModelFactory : public itk::ObjectFactoryBase
{
public:
/** Standard class typedefs. */
typedef KNearestNeighborsMachineLearningModelFactory Self;
typedef itk::ObjectFactoryBase Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Class methods used to interface with the registered factories. */
virtual const char* GetITKSourceVersion(void) const;
virtual const char* GetDescription(void) const;
/** Method for class instantiation. */
itkFactorylessNewMacro(Self);
/** Run-time type information (and related methods). */
itkTypeMacro(KNearestNeighborsMachineLearningModelFactory, itk::ObjectFactoryBase);
/** Register one factory of this type */
static void RegisterOneFactory(void)
{
KNearestNeighborsMachineLearningModelFactory::Pointer KNNFactory = KNearestNeighborsMachineLearningModelFactory::New();
itk::ObjectFactoryBase::RegisterFactory(KNNFactory);
}
protected:
KNearestNeighborsMachineLearningModelFactory();
virtual ~KNearestNeighborsMachineLearningModelFactory();
private:
KNearestNeighborsMachineLearningModelFactory(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
};
} // end namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbKNearestNeighborsMachineLearningModelFactory.txx"
#endif
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbKNearestNeighborsMachineLearningModelFactory.h"
#include "itkCreateObjectFunction.h"
#include "otbKNearestNeighborsMachineLearningModel.h"
#include "itkVersion.h"
namespace otb
{
template <class TInputValue, class TOutputValue>
KNearestNeighborsMachineLearningModelFactory<TInputValue,TOutputValue>
::KNearestNeighborsMachineLearningModelFactory()
{
static std::string classOverride = std::string("otbMachineLearningModel");
static std::string subclass = std::string("otbKNearestNeighborsMachineLearningModel");
this->RegisterOverride(classOverride.c_str(),
subclass.c_str(),
"KNN ML Model",
1,
itk::CreateObjectFunction<KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue> >::New());
}
template <class TInputValue, class TOutputValue>
KNearestNeighborsMachineLearningModelFactory<TInputValue,TOutputValue>
::~KNearestNeighborsMachineLearningModelFactory()
{
}
template <class TInputValue, class TOutputValue>
const char*
KNearestNeighborsMachineLearningModelFactory<TInputValue,TOutputValue>
::GetITKSourceVersion(void) const
{
return ITK_SOURCE_VERSION;
}
template <class TInputValue, class TOutputValue>
const char*
KNearestNeighborsMachineLearningModelFactory<TInputValue,TOutputValue>
::GetDescription() const
{
return "KNN machine learning model factory";
}
} // end namespace otb
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbLibSVMMachineLearningModel_h
#define __otbLibSVMMachineLearningModel_h
#include "itkLightObject.h"
#include "itkVariableLengthVector.h"
#include "itkFixedArray.h"
#include "itkListSample.h"
#include "otbMachineLearningModel.h"
//include opencv
//#include <opencv.hpp> // opencv general include file
//#include <ml/ml.hpp> // opencv machine learning include file
// SVM estimator
#include "otbSVMSampleListModelEstimator.h"
// Validation
#include "otbSVMClassifier.h"
namespace otb
{
template <class TInputValue, class TTargetValue>
class ITK_EXPORT LibSVMMachineLearningModel
: public MachineLearningModel <TInputValue, TTargetValue>
{
public:
/** Standard class typedefs. */
typedef LibSVMMachineLearningModel Self;
typedef MachineLearningModel<TInputValue, TTargetValue> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
// Input related typedefs
typedef TInputValue InputValueType;
typedef itk::VariableLengthVector<InputValueType> InputSampleType;
typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType;
// Target related typedefs
typedef TTargetValue TargetValueType;
typedef itk::FixedArray<TargetValueType,1> TargetSampleType;
typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType;
// LibSVM related typedefs
typedef otb::Functor::VariableLengthVectorToMeasurementVectorFunctor<InputSampleType> MeasurementVectorFunctorType;
typedef otb::SVMSampleListModelEstimator<InputListSampleType, TargetListSampleType, MeasurementVectorFunctorType>
SVMEstimatorType;
typedef otb::SVMClassifier<InputSampleType, TargetValueType> ClassifierType;
/** Run-time type information (and related methods). */
itkNewMacro(Self);
itkTypeMacro(SVMMachineLearningModel, itk::MachineLearningModel);
/** Train the machine learning model */
virtual void Train();
/** Predict values using the model */
virtual TargetSampleType Predict(const InputSampleType & input) const;
/** Save the model to file */
virtual void Save(char * filename, const char * name=0);
/** Load the model from file */
virtual void Load(char * filename, const char * name=0);
/** Determine the file type. Returns true if this ImageIO can read the
* file specified. */
virtual bool CanReadFile(const char*);
/** Determine the file type. Returns true if this ImageIO can write the
* file specified. */
virtual bool CanWriteFile(const char*);
//Setters/Getters to SVM model
// itkGetMacro(SVMType, int);
// itkSetMacro(SVMType, int);
itkGetMacro(KernelType, int);
itkSetMacro(KernelType, int);
itkGetMacro(C, float);
itkSetMacro(C, float);
itkGetMacro(ParameterOptimization, bool);
itkSetMacro(ParameterOptimization, bool);
// itkGetMacro(Epsilon, int);
// itkSetMacro(Epsilon, int);
protected:
/** Constructor */
LibSVMMachineLearningModel();
/** Destructor */
virtual ~LibSVMMachineLearningModel();
/** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const;
private:
LibSVMMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
int m_KernelType;
float m_C;
bool m_ParameterOptimization;
typename SVMEstimatorType::Pointer m_SVMestimator;
};
} // end namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbLibSVMMachineLearningModel.txx"
#endif
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbLibSVMMachineLearningModel_txx
#define __otbLibSVMMachineLearningModel_txx
#include <fstream>
#include "otbLibSVMMachineLearningModel.h"
//#include "otbOpenCVUtils.h"
// SVM estimator
//#include "otbSVMSampleListModelEstimator.h"
namespace otb
{
template <class TInputValue, class TOutputValue>
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::LibSVMMachineLearningModel()
{
// m_SVMModel = new CvSVM;
// m_SVMType = CvSVM::C_SVC;
m_KernelType = LINEAR;
// m_TermCriteriaType = CV_TERMCRIT_ITER;
m_C = 1.0;
// m_Epsilon = 1e-6;
m_ParameterOptimization = false;
m_SVMestimator = SVMEstimatorType::New();
}
template <class TInputValue, class TOutputValue>
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::~LibSVMMachineLearningModel()
{
//delete m_SVMModel;
}
/** Train the machine learning model */
template <class TInputValue, class TOutputValue>
void
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::Train()
{
// Set up SVM's parameters
// CvSVMParams params;
// params.svm_type = m_SVMType;
// params.kernel_type = m_KernelType;
// params.term_crit = cvTermCriteria(m_TermCriteriaType, m_MaxIter, m_Epsilon);
// // Train the SVM
m_SVMestimator->SetC(m_C);
m_SVMestimator->SetKernelType(m_KernelType);
m_SVMestimator->SetParametersOptimization(m_ParameterOptimization);
m_SVMestimator->SetInputSampleList(this->GetInputListSample());
m_SVMestimator->SetTrainingSampleList(this->GetTargetListSample());
m_SVMestimator->Update();
}
template <class TInputValue, class TOutputValue>
typename LibSVMMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::Predict(const InputSampleType & input) const
{
TargetSampleType target;
otbMsgDevMacro(<< "Starting iterations ");
MeasurementVectorFunctorType mfunctor;
target = m_SVMestimator->GetModel()->EvaluateLabel(mfunctor(input));
return target;
}
template <class TInputValue, class TOutputValue>
void
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::Save(char * filename, const char * name)
{
m_SVMestimator->GetModel()->SaveModel(filename);
}
template <class TInputValue, class TOutputValue>
void
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::Load(char * filename, const char * name)
{
m_SVMestimator->GetModel()->LoadModel(filename);
}
template <class TInputValue, class TOutputValue>
bool
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::CanReadFile(const char * file)
{
//TODO: Rework.
std::ifstream ifs;
ifs.open(file);
if(!ifs)
{
std::cerr<<"Could not read file "<<file<<std::endl;
return false;
}
//Read only the first line.
std::string line;
std::getline(ifs, line);
//if (line.find(m_SVMModel->getName()) != std::string::npos)
if (line.find("svm_type") != std::string::npos)
{
std::cout<<"Reading a libSVM model !!!"<<std::endl;
return true;
}
ifs.close();
return false;
}
template <class TInputValue, class TOutputValue>
bool
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::CanWriteFile(const char * file)
{
return false;
}
template <class TInputValue, class TOutputValue>
void
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::PrintSelf(std::ostream& os, itk::Indent indent) const
{
// Call superclass implementation
Superclass::PrintSelf(os,indent);
}
} //end namespace otb
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbLibSVMMachineLearningModelFactory_h
#define __otbLibSVMMachineLearningModelFactory_h
#include "itkObjectFactoryBase.h"
#include "itkImageIOBase.h"
namespace otb
{
/** \class LibSVMMachineLearningModelFactory
* \brief Creation d'un instance d'un objet SVMMachineLearningModel utilisant les object factory.
*/
template <class TInputValue, class TTargetValue>
class ITK_EXPORT LibSVMMachineLearningModelFactory : public itk::ObjectFactoryBase
{
public:
/** Standard class typedefs. */
typedef LibSVMMachineLearningModelFactory Self;
typedef itk::ObjectFactoryBase Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Class methods used to interface with the registered factories. */
virtual const char* GetITKSourceVersion(void) const;
virtual const char* GetDescription(void) const;
/** Method for class instantiation. */
itkFactorylessNewMacro(Self);
/** Run-time type information (and related methods). */
itkTypeMacro(LibSVMMachineLearningModelFactory, itk::ObjectFactoryBase);
/** Register one factory of this type */
static void RegisterOneFactory(void)
{
LibSVMMachineLearningModelFactory::Pointer LibSVMFactory = LibSVMMachineLearningModelFactory::New();
itk::ObjectFactoryBase::RegisterFactory(LibSVMFactory);
}
protected:
LibSVMMachineLearningModelFactory();
virtual ~LibSVMMachineLearningModelFactory();
private:
LibSVMMachineLearningModelFactory(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
};
} // end namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbLibSVMMachineLearningModelFactory.txx"
#endif
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbLibSVMMachineLearningModelFactory.h"
#include "itkCreateObjectFunction.h"
#include "otbLibSVMMachineLearningModel.h"
#include "itkVersion.h"
namespace otb
{
template <class TInputValue, class TOutputValue>
LibSVMMachineLearningModelFactory<TInputValue,TOutputValue>
::LibSVMMachineLearningModelFactory()
{
static std::string classOverride = std::string("otbMachineLearningModel");
static std::string subclass = std::string("otbLibSVMMachineLearningModel");
this->RegisterOverride(classOverride.c_str(),
subclass.c_str(),
"LibSVM ML Model",
1,
itk::CreateObjectFunction<LibSVMMachineLearningModel<TInputValue,TOutputValue> >::New());
}
template <class TInputValue, class TOutputValue>
LibSVMMachineLearningModelFactory<TInputValue,TOutputValue>
::~LibSVMMachineLearningModelFactory()
{
}
template <class TInputValue, class TOutputValue>
const char*
LibSVMMachineLearningModelFactory<TInputValue,TOutputValue>
::GetITKSourceVersion(void) const
{
return ITK_SOURCE_VERSION;
}
template <class TInputValue, class TOutputValue>
const char*
LibSVMMachineLearningModelFactory<TInputValue,TOutputValue>
::GetDescription() const
{
return "LibSVM machine learning model factory";
}
} // end namespace otb
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbMachineLearningModel_h
#define __otbMachineLearningModel_h
#include "itkObject.h"
#include "itkVariableLengthVector.h"
#include "itkFixedArray.h"
#include "itkListSample.h"
namespace otb
{
template <class TInputValue, class TTargetValue>
class ITK_EXPORT MachineLearningModel
: public itk::Object
{
public:
/** Standard class typedefs. */
typedef MachineLearningModel Self;
typedef itk::Object Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
// Input related typedefs
typedef TInputValue InputValueType;
typedef itk::VariableLengthVector<InputValueType> InputSampleType;
typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType;
// Target related typedefs
typedef TTargetValue TargetValueType;
typedef itk::FixedArray<TargetValueType,1> TargetSampleType;
typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType;
/** Run-time type information (and related methods). */
itkTypeMacro(MachineLearningModel, itk::Object);
/** Train the machine learning model */
virtual void Train() = 0;
/** Predict values using the model */
virtual TargetSampleType Predict(const InputSampleType& input) const = 0;
void PredictAll();
/** Save the model to file */
virtual void Save(char * filename, const char * name=0) = 0;
/** Load the model from file */
virtual void Load(char * filename, const char * name=0) = 0;
/** Determine the file type. Returns true if this ImageIO can read the
* file specified. */
virtual bool CanReadFile(const char*) = 0;
/** Determine the file type. Returns true if this ImageIO can read the
* file specified. */
virtual bool CanWriteFile(const char*) = 0;
/** Input accessors */
itkSetObjectMacro(InputListSample,InputListSampleType);
itkGetObjectMacro(InputListSample,InputListSampleType);
/** Target accessors */
itkSetObjectMacro(TargetListSample,TargetListSampleType);
itkGetObjectMacro(TargetListSample,TargetListSampleType);
protected:
/** Constructor */
MachineLearningModel();
/** Destructor */
virtual ~MachineLearningModel();
/** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const;
/** Input list sample */
typename InputListSampleType::Pointer m_InputListSample;
/** Target list sample */
typename TargetListSampleType::Pointer m_TargetListSample;
private:
MachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
};
} // end namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbMachineLearningModel.txx"
#endif
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbMachineLearningModel_txx
#define __otbMachineLearningModel_txx
#include "otbMachineLearningModel.h"
namespace otb
{
template <class TInputValue, class TOutputValue>
MachineLearningModel<TInputValue,TOutputValue>
::MachineLearningModel()
{}
template <class TInputValue, class TOutputValue>
MachineLearningModel<TInputValue,TOutputValue>
::~MachineLearningModel()
{}
template <class TInputValue, class TOutputValue>
void
MachineLearningModel<TInputValue,TOutputValue>
::PredictAll()
{
TargetListSampleType * targets = this->GetTargetListSample();
targets->Clear();
for(typename InputListSampleType::ConstIterator sIt = this->GetInputListSample()->Begin();
sIt!=this->GetInputListSample()->End();++sIt)
{
targets->PushBack(this->Predict(sIt.GetMeasurementVector()));
}
}
template <class TInputValue, class TOutputValue>
void
MachineLearningModel<TInputValue,TOutputValue>
::PrintSelf(std::ostream& os, itk::Indent indent) const
{
// Call superclass implementation
Superclass::PrintSelf(os,indent);
}
}
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbMachineLearningModelFactory_h
#define __otbMachineLearningModelFactory_h
#include "itkObject.h"
#include "otbMachineLearningModel.h"
namespace otb
{
/** \class MachineLearningModelFactory
* \brief Creation of object instance using object factory.
*/
template <class TInputValue, class TOutputValue>
class ITK_EXPORT MachineLearningModelFactory : public itk::Object
{
public:
/** Standard class typedefs. */
typedef MachineLearningModelFactory Self;
typedef itk::Object Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Class Methods used to interface with the registered factories */
/** Run-time type information (and related methods). */
itkTypeMacro(MachineLearningModelFactory, itk::Object);
/** Convenient typedefs. */
typedef typename otb::MachineLearningModel<TInputValue,TOutputValue>::Pointer MachineLearningModelTypePointer;
/** Mode in which the files is intended to be used */
typedef enum { ReadMode, WriteMode } FileModeType;
/** Create the appropriate MachineLearningModel depending on the particulars of the file. */
static MachineLearningModelTypePointer CreateMachineLearningModel(const char* path, FileModeType mode);
/** Register Built-in factories */
static void RegisterBuiltInFactories();
protected:
MachineLearningModelFactory();
~MachineLearningModelFactory();
private:
MachineLearningModelFactory(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
};
} // end namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbMachineLearningModelFactory.txx"
#endif
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbMachineLearningModelFactory.h"
#include "itkMutexLock.h"
#include "itkMutexLockHolder.h"
//#include "otbKNearestNeighborsMachineLearningModelFactory.h"
#include "otbRandomForestsMachineLearningModelFactory.h"
#include "otbSVMMachineLearningModelFactory.h"
#include "otbLibSVMMachineLearningModelFactory.h"
namespace otb
{
template <class TInputValue, class TOutputValue>
typename MachineLearningModel<TInputValue,TOutputValue>::Pointer
MachineLearningModelFactory<TInputValue,TOutputValue>
::CreateMachineLearningModel(const char* path, FileModeType mode)
{
RegisterBuiltInFactories();
std::list<MachineLearningModelTypePointer> possibleMachineLearningModel;
std::list<LightObject::Pointer> allobjects =
itk::ObjectFactoryBase::CreateAllInstance("otbMachineLearningModel");
for(std::list<LightObject::Pointer>::iterator i = allobjects.begin();
i != allobjects.end(); ++i)
{
MachineLearningModel<TInputValue,TOutputValue> * io = dynamic_cast<MachineLearningModel<TInputValue,TOutputValue>*>(i->GetPointer());
if(io)
{
possibleMachineLearningModel.push_back(io);
}
else
{
std::cerr << "Error MachineLearningModel Factory did not return an MachineLearningModel: "
<< (*i)->GetNameOfClass()
<< std::endl;
}
}
for(typename std::list<MachineLearningModelTypePointer>::iterator k = possibleMachineLearningModel.begin();
k != possibleMachineLearningModel.end(); ++k)
{
if( mode == ReadMode )
{
if((*k)->CanReadFile(path))
{
return *k;
}
}
else if( mode == WriteMode )
{
if((*k)->CanWriteFile(path))
{
return *k;
}
}
}
return 0;
}
template <class TInputValue, class TOutputValue>
void
MachineLearningModelFactory<TInputValue,TOutputValue>
::RegisterBuiltInFactories()
{
static bool firstTime = true;
static itk::SimpleMutexLock mutex;
{
// This helper class makes sure the Mutex is unlocked
// in the event an exception is thrown.
itk::MutexLockHolder<itk::SimpleMutexLock> mutexHolder(mutex);
if (firstTime)
{
// KNN Format for OTB
//itk::ObjectFactoryBase::RegisterFactory(KNearestNeighborsMachineLearningModelFactory<TInputValue,TOutputValue>::New());
itk::ObjectFactoryBase::RegisterFactory(RandomForestsMachineLearningModelFactory<TInputValue,TOutputValue>::New());
itk::ObjectFactoryBase::RegisterFactory(LibSVMMachineLearningModelFactory<TInputValue,TOutputValue>::New());
itk::ObjectFactoryBase::RegisterFactory(SVMMachineLearningModelFactory<TInputValue,TOutputValue>::New());
firstTime = false;
}
}
}
} // end namespace otb
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbMachineLearningUtils.h"
#include <fstream>
#include <string>
#include <algorithm>
bool ReadDataFile(const char * infname, InputListSampleType * samples, TargetListSampleType * labels)
{
std::ifstream ifs;
ifs.open(infname);
if(!ifs)
{
std::cerr<<"Could not read file "<<infname<<std::endl;
return false;
}
unsigned int nbfeatures = 0;
while (!ifs.eof())
{
std::string line;
std::getline(ifs, line);
if(nbfeatures == 0)
{
nbfeatures = std::count(line.begin(),line.end(),' ')-1;
std::cout<<"Found "<<nbfeatures<<" features per samples"<<std::endl;
}
if(line.size()>1)
{
InputSampleType sample(nbfeatures);
sample.Fill(0);
std::string::size_type pos = line.find_first_of(" ", 0);
// Parse label
TargetSampleType label;
label[0] = atoi(line.substr(0, pos).c_str());
bool endOfLine = false;
unsigned int id = 0;
while(!endOfLine)
{
std::string::size_type nextpos = line.find_first_of(" ", pos+1);
if(nextpos == std::string::npos)
{
endOfLine = true;
nextpos = line.size()-1;
}
else
{
std::string feature = line.substr(pos,nextpos-pos);
std::string::size_type semicolonpos = feature.find_first_of(":");
id = atoi(feature.substr(0,semicolonpos).c_str());
sample[id] = atof(feature.substr(semicolonpos+1,feature.size()-semicolonpos).c_str());
pos = nextpos;
}
}
samples->PushBack(sample);
labels->PushBack(label);
}
}
std::cout<<"Retrieved "<<samples->Size()<<" samples"<<std::endl;
ifs.close();
return true;
}
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbMachineLearningUtils_h
#define __otbMachineLearningUtils_h
#include "otbMachineLearningModel.h"
typedef otb::MachineLearningModel<float,short> MachineLearningModelType;
typedef MachineLearningModelType::InputValueType InputValueType;
typedef MachineLearningModelType::InputSampleType InputSampleType;
typedef MachineLearningModelType::InputListSampleType InputListSampleType;
typedef MachineLearningModelType::TargetValueType TargetValueType;
typedef MachineLearningModelType::TargetSampleType TargetSampleType;
typedef MachineLearningModelType::TargetListSampleType TargetListSampleType;
bool ReadDataFile(const char * infname, InputListSampleType * samples, TargetListSampleType * labels);
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbOpenCVUtils_h
#define __otbOpenCVUtils_h
#include "ml/ml.hpp"
#include "itkPixelBuilder.h"
namespace otb
{
template <class T> void SampleToMat(const T & sample, cv::Mat& output)
{
output.create(1,sample.Size(),CV_32FC1);
// Loop on sample size
for(unsigned int i = 0; i < sample.Size();++i)
{
output.at<float>(0,i) = sample[i];
}
}
/** Converts a ListSample of VariableLengthVector to a CvMat. The user
* is responsible for freeing the output pointer with the
* cvReleaseMat function. A null pointer is resturned in case the
* conversion failed.
*/
template <class T> void ListSampleToMat(const T * listSample, cv::Mat & output) {
// Sample index
unsigned int sampleIdx = 0;
// Check for valid listSample
if(listSample != NULL && listSample->Size() > 0)
{
// Retrieve samples count
unsigned int sampleCount = listSample->Size();
// Build an iterator
typename T::ConstIterator sampleIt = listSample->Begin();
// Retrieve samples size alike
const unsigned int sampleSize = listSample->GetMeasurementVectorSize();
// Allocate CvMat
output.create(sampleCount,sampleSize,CV_32FC1);
// Fill the cv matrix
for(;sampleIt!=listSample->End();++sampleIt,++sampleIdx)
{
// Retrieve sample
typename T::MeasurementVectorType sample = sampleIt.GetMeasurementVector();
// Loop on sample size
for(unsigned int i = 0; i < sampleSize;++i)
{
output.at<float>(sampleIdx,i) = sample[i];
}
}
}
}
template <typename T> void ListSampleToMat(typename T::Pointer listSample, cv::Mat & output) {
return ListSampleToMat(listSample.GetPointer(), output);
}
template <typename T> void ListSampleToMat(typename T::ConstPointer listSample, cv::Mat & output ) {
return ListSampleToMat(listSample.GetPointer(), output);
}
template <typename T> typename T::Pointer MatToListSample(const cv::Mat & cvmat)
{
// Build output type
typename T::Pointer output = T::New();
// Get samples count
unsigned sampleCount = cvmat.rows;
// Get samples size
unsigned int sampleSize = cvmat.cols;
// Loop on samples
for(unsigned int i = 0; i < sampleCount;++i)
{
typename T::MeasurementVectorType sample;
itk::PixelBuilder<typename T::MeasurementVectorType>::Zero(sample,sampleSize);
unsigned int realSampleSize = sample.Size();
for(unsigned int j = 0; j < realSampleSize;++j)
{
// Don't forget to cast
sample[j] = static_cast<typename T::MeasurementVectorType
::ValueType>(cvmat.at<float>(i,j));
}
// PushBack the new sample
output->PushBack(sample);
}
// return the output
return output;
}
}
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbRandomForestsMachineLearningModel_h
#define __otbRandomForestsMachineLearningModel_h
#include "itkLightObject.h"
#include "itkVariableLengthVector.h"
#include "itkFixedArray.h"
#include "itkListSample.h"
#include "otbMachineLearningModel.h"
//include opencv
#include <opencv.hpp> // opencv general include file
#include <ml/ml.hpp> // opencv machine learning include file
namespace otb
{
template <class TInputValue, class TTargetValue>
class ITK_EXPORT RandomForestsMachineLearningModel
: public MachineLearningModel <TInputValue, TTargetValue>
{
public:
/** Standard class typedefs. */
typedef RandomForestsMachineLearningModel Self;
typedef MachineLearningModel<TInputValue, TTargetValue> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
// Input related typedefs
typedef TInputValue InputValueType;
typedef itk::VariableLengthVector<InputValueType> InputSampleType;
typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType;
// Target related typedefs
typedef TTargetValue TargetValueType;
typedef itk::FixedArray<TargetValueType,1> TargetSampleType;
typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType;
//opencv typedef
typedef CvRTrees RFType;
/** Run-time type information (and related methods). */
itkNewMacro(Self);
itkTypeMacro(RandomForestsMachineLearningModel, itk::MachineLearningModel);
/** Train the machine learning model */
virtual void Train();
/** Predict values using the model */
virtual TargetSampleType Predict(const InputSampleType & input) const;
/** Save the model to file */
virtual void Save(char * filename, const char * name=0);
/** Load the model from file */
virtual void Load(char * filename, const char * name=0);
/** Determine the file type. Returns true if this ImageIO can read the
* file specified. */
virtual bool CanReadFile(const char*);
/** Determine the file type. Returns true if this ImageIO can write the
* file specified. */
virtual bool CanWriteFile(const char*);
/* /\** Input accessors *\/ */
/* itkSetObjectMacro(InputListSample,InputListSampleType); */
/* itkGetObjectMacro(InputListSample,InputListSampleType); */
/* /\** Target accessors *\/ */
/* itkSetObjectMacro(TargetListSample,TargetListSampleType); */
/* itkGetObjectMacro(TargetListSample,TargetListSampleType); */
//Setters of RT parameters (documentation get from opencv doxygen 2.4)
/* the depth of the tree. A low value will likely underfit and conversely a
* high value will likely overfit. The optimal value can be obtained using cross
* validation or other suitable methods. */
itkGetMacro(MaxDepth, int);
itkSetMacro(MaxDepth, int);
/* minimum samples required at a leaf node for it to be split. A reasonable
* value is a small percentage of the total data e.g. 1%. */
itkGetMacro(MinSampleCount, int);
itkSetMacro(MinSampleCount, int);
/* Termination criteria for regression trees. If all absolute differences
* between an estimated value in a node and values of train samples in this node
* are less than this parameter then the node will not be split */
itkGetMacro(RegressionAccuracy, double);
itkSetMacro(RegressionAccuracy, bool);
itkGetMacro(ComputeSurrogateSplit, bool);
itkSetMacro(ComputeSurrogateSplit, bool);
/* Cluster possible values of a categorical variable into K \leq
* max_categories clusters to find a suboptimal split. If a discrete variable,
* on which the training procedure tries to make a split, takes more than
* max_categories values, the precise best subset estimation may take a very
* long time because the algorithm is exponential. Instead, many decision
* trees engines (including ML) try to find sub-optimal split in this case by
* clustering all the samples into max categories clusters that is some
* categories are merged together. The clustering is applied only in n>2-class
* classification problems for categorical variables with N > max_categories
* possible values. In case of regression and 2-class classification the
* optimal split can be found efficiently without employing clustering, thus
* the parameter is not used in these cases.
*/
itkGetMacro(MaxNumberOfCategories, int);
itkSetMacro(MaxNumberOfCategories, int);
/* The array of a priori class probabilities, sorted by the class label
* value. The parameter can be used to tune the decision tree preferences toward
* a certain class. For example, if you want to detect some rare anomaly
* occurrence, the training base will likely contain much more normal cases than
* anomalies, so a very good classification performance will be achieved just by
* considering every case as normal. To avoid this, the priors can be specified,
* where the anomaly probability is artificially increased (up to 0.5 or even
* greater), so the weight of the misclassified anomalies becomes much bigger,
* and the tree is adjusted properly. You can also think about this parameter as
* weights of prediction categories which determine relative weights that you
* give to misclassification. That is, if the weight of the first category is 1
* and the weight of the second category is 10, then each mistake in predicting
* the second category is equivalent to making 10 mistakes in predicting the
first category. */
std::vector<float> GetPriors() const
{
return m_Priors;
}
void SetPriors(const std::vector<float> & priors)
{
m_Priors = priors;
}
/* If true then variable importance will be calculated and then it can be
retrieved by CvRTrees::get_var_importance(). */
itkGetMacro(CalculateVariableImportance, int);
itkSetMacro(CalculateVariableImportance, int);
/* The size of the randomly selected subset of features at each tree node and
* that are used to find the best split(s). If you set it to 0 then the size will
be set to the square root of the total number of features. */
itkGetMacro(MaxNumberOfVariables, int);
itkSetMacro(MaxNumberOfVariables, int);
/* The maximum number of trees in the forest (surprise, surprise). Typically
* the more trees you have the better the accuracy. However, the improvement in
* accuracy generally diminishes and asymptotes pass a certain number of
* trees. Also to keep in mind, the number of tree increases the prediction time
linearly. */
itkGetMacro(MaxNumberOfTrees, int);
itkSetMacro(MaxNumberOfTrees, int);
/* Sufficient accuracy (OOB error) */
itkGetMacro(ForestAccuracy, float);
itkSetMacro(ForestAccuracy, float);
/* The type of the termination criteria */
itkGetMacro(TerminationCriteria, int);
itkSetMacro(TerminationCriteria, int);
// cv::Mat GetVariableImportance()
// {
// return m_RFModel->getVarImportance();
// }
float GetTrainError()
{
return m_RFModel->get_train_error();
}
protected:
/** Constructor */
RandomForestsMachineLearningModel();
/** Destructor */
virtual ~RandomForestsMachineLearningModel();
/** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const;
/* /\** Input list sample *\/ */
/* typename InputListSampleType::Pointer m_InputListSample; */
/* /\** Target list sample *\/ */
/* typename TargetListSampleType::Pointer m_TargetListSample; */
private:
RandomForestsMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
CvRTrees * m_RFModel;
int m_MaxDepth;
int m_MinSampleCount;
float m_RegressionAccuracy;
bool m_ComputeSurrogateSplit;
int m_MaxNumberOfCategories;
std::vector<float> m_Priors;
bool m_CalculateVariableImportance;
int m_MaxNumberOfVariables;
int m_MaxNumberOfTrees;
float m_ForestAccuracy;
int m_TerminationCriteria;
};
} // end namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbRandomForestsMachineLearningModel.txx"
#endif
#endif
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment