Commit 3527c8c1 authored by Arnaud Jaen's avatar Arnaud Jaen

ENH: Add BoostMachineLearningModel functionnalities.

parent e7aa93bf
......@@ -53,6 +53,47 @@ public:
itkNewMacro(Self);
itkTypeMacro(BoostMachineLearningModel, itk::MachineLearningModel);
/** Setters/Getters to the Boost type
* It can be CvBoost::DISCRETE, CvBoost::REAL, CvBoost::LOGIT, CvBoost::GENTLE
* Default is CvBoost::REAL.
* see @http://docs.opencv.org/modules/ml/doc/boosting.html#cvboostparams-cvboostparams
*/
itkGetMacro(BoostType, int);
itkSetMacro(BoostType, int);
/** Setters/Getters to the split criteria
* It can be CvBoost::DEFAULT, CvBoost::GINI, CvBoost::MISCLASS, CvBoost::SQERR
* Default is CvBoost::DEFAULT. It uses default value according to \c BoostType
* see @http://docs.opencv.org/modules/ml/doc/boosting.html#cvboost-predict
*/
itkGetMacro(SplitCrit, int);
itkSetMacro(SplitCrit, int);
/** Setters/Getters to the number of weak classifiers.
* Default is 100.
* see @http://docs.opencv.org/modules/ml/doc/boosting.html#cvboostparams-cvboostparams
*/
itkGetMacro(WeakCount, int);
itkSetMacro(WeakCount, int);
/** Setters/Getters to the threshold WeightTrimRate.
* A threshold between 0 and 1 used to save computational time.
* Samples with summary weight \leq 1 - WeightTrimRate do not participate in the next iteration of training.
* Set this parameter to 0 to turn off this functionality.
* Default is 0.95
* see @http://docs.opencv.org/modules/ml/doc/boosting.html#cvboostparams-cvboostparams
*/
itkGetMacro(WeightTrimRate, double);
itkSetMacro(WeightTrimRate, double);
/** Setters/Getters to the maximum depth of the tree.
* Default is 1
* see @http://docs.opencv.org/modules/ml/doc/decision_trees.html#CvDTreeParams::CvDTreeParams%28%29
*/
itkGetMacro(MaxDepth, int);
itkSetMacro(MaxDepth, int);
/** Train the machine learning model */
virtual void Train();
......@@ -60,18 +101,18 @@ public:
virtual TargetSampleType Predict(const InputSampleType & input) const;
/** Save the model to file */
virtual void Save(char * filename, const char * name=0);
virtual void Save(const std::string & filename, const std::string & name="");
/** Load the model from file */
virtual bool Load(char * filename, const char * name=0);
virtual void Load(const std::string & filename, const std::string & name="");
/** Determine the file type. Returns true if this ImageIO can read the
* file specified. */
virtual bool CanReadFile(const char*);
virtual bool CanReadFile(const std::string &);
/** Determine the file type. Returns true if this ImageIO can write the
* file specified. */
virtual bool CanWriteFile(const char*);
virtual bool CanWriteFile(const std::string &);
protected:
/** Constructor */
......@@ -88,6 +129,11 @@ private:
void operator =(const Self&); //purposely not implemented
CvBoost * m_BoostModel;
int m_BoostType;
int m_SplitCrit;
int m_WeakCount;
double m_WeightTrimRate;
int m_MaxDepth;
};
} // end namespace otb
......
......@@ -26,7 +26,9 @@ namespace otb
template <class TInputValue, class TOutputValue>
BoostMachineLearningModel<TInputValue,TOutputValue>
::BoostMachineLearningModel()
::BoostMachineLearningModel() :
m_BoostType(CvBoost::REAL), m_SplitCrit(CvBoost::DEFAULT), m_WeakCount(100),
m_WeightTrimRate(0.95), m_MaxDepth(1)
{
m_BoostModel = new CvBoost;
}
......@@ -52,9 +54,8 @@ BoostMachineLearningModel<TInputValue,TOutputValue>
cv::Mat labels;
otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(),labels);
CvBoostParams params;
params.boost_type = CvBoost::DISCRETE;
params.split_criteria = CvBoost::DEFAULT;
CvBoostParams params = CvBoostParams(m_BoostType, m_WeakCount, m_WeightTrimRate, m_MaxDepth, false, 0);
params.split_criteria = m_SplitCrit;
//train the Boost model
cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U );
......@@ -89,31 +90,59 @@ BoostMachineLearningModel<TInputValue,TOutputValue>
template <class TInputValue, class TOutputValue>
void
BoostMachineLearningModel<TInputValue,TOutputValue>
::Save(char * filename, const char * name)
::Save(const std::string & filename, const std::string & name)
{
m_BoostModel->save(filename, name);
if (name == "")
m_BoostModel->save(filename.c_str(), 0);
else
m_BoostModel->save(filename.c_str(), name.c_str());
}
template <class TInputValue, class TOutputValue>
void
BoostMachineLearningModel<TInputValue,TOutputValue>
::Load(char * filename, const char * name)
::Load(const std::string & filename, const std::string & name)
{
m_BoostModel->load(filename, name);
if (name == "")
m_BoostModel->load(filename.c_str(), 0);
else
m_BoostModel->load(filename.c_str(), name.c_str());
}
template <class TInputValue, class TOutputValue>
bool
BoostMachineLearningModel<TInputValue,TOutputValue>
::CanReadFile(const char * file)
::CanReadFile(const std::string & file)
{
std::ifstream ifs;
ifs.open(file.c_str());
if(!ifs)
{
std::cerr<<"Could not read file "<<file<<std::endl;
return false;
}
while (!ifs.eof())
{
std::string line;
std::getline(ifs, line);
//if (line.find(m_SVMModel->getName()) != std::string::npos)
if (line.find(CV_TYPE_NAME_ML_BOOSTING) != std::string::npos)
{
std::cout<<"Reading a "<<CV_TYPE_NAME_ML_BOOSTING<<" model !!!"<<std::endl;
return true;
}
}
ifs.close();
return false;
}
template <class TInputValue, class TOutputValue>
bool
BoostMachineLearningModel<TInputValue,TOutputValue>
::CanWriteFile(const char * file)
::CanWriteFile(const std::string & file)
{
return false;
}
......
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbBoostMachineLearningModelFactory_h
#define __otbBoostMachineLearningModelFactory_h
#include "itkObjectFactoryBase.h"
#include "itkImageIOBase.h"
namespace otb
{
/** \class BoostMachineLearningModelFactory
* \brief Creation d'un instance d'un objet SVMMachineLearningModel utilisant les object factory.
*/
template <class TInputValue, class TTargetValue>
class ITK_EXPORT BoostMachineLearningModelFactory : public itk::ObjectFactoryBase
{
public:
/** Standard class typedefs. */
typedef BoostMachineLearningModelFactory Self;
typedef itk::ObjectFactoryBase Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Class methods used to interface with the registered factories. */
virtual const char* GetITKSourceVersion(void) const;
virtual const char* GetDescription(void) const;
/** Method for class instantiation. */
itkFactorylessNewMacro(Self);
/** Run-time type information (and related methods). */
itkTypeMacro(BoostMachineLearningModelFactory, itk::ObjectFactoryBase);
/** Register one factory of this type */
static void RegisterOneFactory(void)
{
BoostMachineLearningModelFactory::Pointer Factory = BoostMachineLearningModelFactory::New();
itk::ObjectFactoryBase::RegisterFactory(Factory);
}
protected:
BoostMachineLearningModelFactory();
virtual ~BoostMachineLearningModelFactory();
private:
BoostMachineLearningModelFactory(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
};
} // end namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbBoostMachineLearningModelFactory.txx"
#endif
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbBoostMachineLearningModelFactory.h"
#include "itkCreateObjectFunction.h"
#include "otbBoostMachineLearningModel.h"
#include "itkVersion.h"
namespace otb
{
template <class TInputValue, class TOutputValue>
BoostMachineLearningModelFactory<TInputValue,TOutputValue>
::BoostMachineLearningModelFactory()
{
static std::string classOverride = std::string("otbMachineLearningModel");
static std::string subclass = std::string("otbBoostMachineLearningModel");
this->RegisterOverride(classOverride.c_str(),
subclass.c_str(),
"Boost ML Model",
1,
itk::CreateObjectFunction<BoostMachineLearningModel<TInputValue,TOutputValue> >::New());
}
template <class TInputValue, class TOutputValue>
BoostMachineLearningModelFactory<TInputValue,TOutputValue>
::~BoostMachineLearningModelFactory()
{
}
template <class TInputValue, class TOutputValue>
const char*
BoostMachineLearningModelFactory<TInputValue,TOutputValue>
::GetITKSourceVersion(void) const
{
return ITK_SOURCE_VERSION;
}
template <class TInputValue, class TOutputValue>
const char*
BoostMachineLearningModelFactory<TInputValue,TOutputValue>
::GetDescription() const
{
return "Boost machine learning model factory";
}
} // end namespace otb
......@@ -23,6 +23,7 @@
#include "otbRandomForestsMachineLearningModelFactory.h"
#include "otbSVMMachineLearningModelFactory.h"
#include "otbLibSVMMachineLearningModelFactory.h"
#include "otbBoostMachineLearningModelFactory.h"
namespace otb
......@@ -95,6 +96,7 @@ MachineLearningModelFactory<TInputValue,TOutputValue>
itk::ObjectFactoryBase::RegisterFactory(RandomForestsMachineLearningModelFactory<TInputValue,TOutputValue>::New());
itk::ObjectFactoryBase::RegisterFactory(LibSVMMachineLearningModelFactory<TInputValue,TOutputValue>::New());
itk::ObjectFactoryBase::RegisterFactory(SVMMachineLearningModelFactory<TInputValue,TOutputValue>::New());
itk::ObjectFactoryBase::RegisterFactory(BoostMachineLearningModelFactory<TInputValue,TOutputValue>::New());
firstTime = false;
}
......
......@@ -725,6 +725,19 @@ IF(OTB_USE_OPENCV)
${TEMP}/libsvm_model.txt
)
ADD_TEST(leTuBoostMachineLearningModelNew ${LEARNING_TESTS6}
otbBoostMachineLearningModelNew)
ADD_TEST(leTvBoostMachineLearningModel ${LEARNING_TESTS6}
#--compare-ascii ${NOTOL}
#${BASELINE_FILES}/BoostLearningModel.txt
#${TEMP}/BoostMachineLearningModel.txt
otbBoostMachineLearningModel
${INPUTDATA}/letter.scale
${TEMP}/BoostMachineLearningModel.txt
)
ADD_TEST(leTuImageClassificationFilterNew ${LEARNING_TESTS6}
otbImageClassificationFilterNew)
......@@ -755,7 +768,7 @@ IF(OTB_USE_OPENCV)
ADD_TEST(leTuSVMMachineLearningModelCanRead ${LEARNING_TESTS6}
otbSVMMachineLearningModelCanRead
${INPUTDATA}/svm_model_image
${INPUTDATA}/opencv_svm_model.txt
)
ADD_TEST(leTuRandomForestsMachineLearningModelCanRead ${LEARNING_TESTS6}
......@@ -764,6 +777,11 @@ IF(OTB_USE_OPENCV)
)
SET_TESTS_PROPERTIES(leTuRandomForestsMachineLearningModelCanRead
PROPERTIES DEPENDS leTvRandomForestsMachineLearningModel)
ADD_TEST(leTuBoostMachineLearningModelCanRead ${LEARNING_TESTS6}
otbBoostMachineLearningModelCanRead
${INPUTDATA}/boost_model.txt
)
ENDIF(OTB_USE_OPENCV)
......
......@@ -32,9 +32,12 @@ void RegisterTests()
REGISTER_TEST(otbKNearestNeighborsMachineLearningModel);
REGISTER_TEST(otbRandomForestsMachineLearningModelNew);
REGISTER_TEST(otbRandomForestsMachineLearningModel);
REGISTER_TEST(otbBoostMachineLearningModelNew);
REGISTER_TEST(otbBoostMachineLearningModel);
REGISTER_TEST(otbImageClassificationFilterNew);
REGISTER_TEST(otbImageClassificationFilter);
REGISTER_TEST(otbLibSVMMachineLearningModelCanRead);
REGISTER_TEST(otbSVMMachineLearningModelCanRead);
REGISTER_TEST(otbRandomForestsMachineLearningModelCanRead);
REGISTER_TEST(otbBoostMachineLearningModelCanRead);
}
......@@ -20,6 +20,7 @@
#include "otbLibSVMMachineLearningModel.h"
#include "otbSVMMachineLearningModel.h"
#include "otbRandomForestsMachineLearningModel.h"
#include "otbBoostMachineLearningModel.h"
#include <iostream>
typedef otb::MachineLearningModel<float,short> MachineLearningModelType;
......@@ -107,3 +108,30 @@ int otbRandomForestsMachineLearningModelCanRead(int argc, char* argv[])
return EXIT_SUCCESS;
}
int otbBoostMachineLearningModelCanRead(int argc, char* argv[])
{
if (argc != 2)
{
std::cerr << "Usage: " << argv[0]
<< "<model>" << std::endl;
std::cerr << "Called here with " << argc << " arguments\n";
for (int i = 1; i < argc; ++i)
{
std::cerr << " - " << argv[i] << "\n";
}
return EXIT_FAILURE;
}
std::string filename(argv[1]);
typedef otb::BoostMachineLearningModel<InputValueType, TargetValueType> BoostType;
BoostType::Pointer classifier = BoostType::New();
bool lCanRead = classifier->CanReadFile(filename);
if (lCanRead == false)
{
std::cerr << "Erreur otb::BoostMachineLearningModel : impossible to open the file " << filename << "." << std::endl;
return EXIT_FAILURE;
}
return EXIT_SUCCESS;
}
......@@ -25,6 +25,7 @@
#include "otbSVMMachineLearningModel.h"
#include "otbKNearestNeighborsMachineLearningModel.h"
#include "otbRandomForestsMachineLearningModel.h"
#include "otbBoostMachineLearningModel.h"
#include "otbConfusionMatrixCalculator.h"
......@@ -175,7 +176,7 @@ int otbSVMMachineLearningModel(int argc, char * argv[])
return EXIT_FAILURE;
}
typedef otb::LibSVMMachineLearningModel<InputValueType, TargetValueType> SVMType;
typedef otb::SVMMachineLearningModel<InputValueType, TargetValueType> SVMType;
InputListSampleType::Pointer samples = InputListSampleType::New();
TargetListSampleType::Pointer labels = TargetListSampleType::New();
......@@ -359,3 +360,57 @@ int otbRandomForestsMachineLearningModel(int argc, char * argv[])
}
}
int otbBoostMachineLearningModelNew(int argc, char * argv[])
{
typedef otb::BoostMachineLearningModel<InputValueType,TargetValueType> BoostType;
BoostType::Pointer classifier = BoostType::New();
return EXIT_SUCCESS;
}
int otbBoostMachineLearningModel(int argc, char * argv[])
{
if (argc != 3 )
{
std::cout<<"Wrong number of arguments "<<std::endl;
std::cout<<"Usage : sample file, output file "<<std::endl;
return EXIT_FAILURE;
}
typedef otb::BoostMachineLearningModel<InputValueType, TargetValueType> BoostType;
InputListSampleType::Pointer samples = InputListSampleType::New();
TargetListSampleType::Pointer labels = TargetListSampleType::New();
TargetListSampleType::Pointer predicted = TargetListSampleType::New();
if(!ReadDataFile(argv[1],samples,labels))
{
std::cout<<"Failed to read samples file "<<argv[1]<<std::endl;
return EXIT_FAILURE;
}
BoostType::Pointer classifier = BoostType::New();
classifier->SetInputListSample(samples);
classifier->SetTargetListSample(labels);
classifier->Train();
classifier->SetTargetListSample(predicted);
classifier->PredictAll();
classifier->Save(argv[2]);
ConfusionMatrixCalculatorType::Pointer cmCalculator = ConfusionMatrixCalculatorType::New();
cmCalculator->SetProducedLabels(predicted);
cmCalculator->SetReferenceLabels(labels);
cmCalculator->Compute();
std::cout<<"Confusion matrix: "<<std::endl;
std::cout<<cmCalculator->GetConfusionMatrix()<<std::endl;
std::cout<<"Kappa: "<<cmCalculator->GetKappaIndex()<<std::endl;
std::cout<<"Overall Accuracy: "<<cmCalculator->GetOverallAccuracy()<<std::endl;
return EXIT_SUCCESS;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment