Commit 6a930411 authored by Jordi Inglada's avatar Jordi Inglada

ENH: make shark learning an independent module

parent d2e28d3c
# Build description of the OTBSharkLearning module.
project(OTBSharkLearning)
# Libraries associated with this module: Shark itself plus the OTB
# learning-base and common libraries.
set(OTBSharkLearning_LIBRARIES
${OTBShark_LIBRARIES}
${OTBLearningBase_LIBRARIES}
${OTBCommon_LIBRARIES} )
# NOTE(review): otb_module_impl() is defined by the OTB build system, outside
# this file; presumably it finalizes the module declaration - confirm.
otb_module_impl()
# Link libraries collected for the Shark applications of this module.
# The only consumer visible in this file is the otb_create_application() call
# further down, which is currently commented out.
set(OTBAppShark_LINK_LIBS
${OTBVectorDataBase_LIBRARIES}
${OTBConversion_LIBRARIES}
${OTBStatistics_LIBRARIES}
${OTBColorMap_LIBRARIES}
${OTBBoost_LIBRARIES}
${OTBInterpolation_LIBRARIES}
${OTBVectorDataIO_LIBRARIES}
${OTBApplicationEngine_LIBRARIES}
${OTBIndices_LIBRARIES}
${OTBMathParser_LIBRARIES}
${OTBGdalAdapters_LIBRARIES}
${OTBProjection_LIBRARIES}
${OTBImageBase_LIBRARIES}
${OTBIOXML_LIBRARIES}
${OTBVectorDataManipulation_LIBRARIES}
${OTBStreaming_LIBRARIES}
${OTBImageManipulation_LIBRARIES}
${OTBObjectList_LIBRARIES}
${OTBCommon_LIBRARIES}
)
# otb_create_application(
# NAME SharkTrainImagesClassifier
# SOURCES otbSharkTrainImagesClassifier.cxx
# LINK_LIBRARIES ${${otb-module}_LIBRARIES} ${OTBAppShark_LINK_LIBS})
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbSharkLearningApplicationBase_h
#define __otbSharkLearningApplicationBase_h
#include "otbConfigure.h"
#include "otbWrapperApplication.h"
#include <iostream>
// ListSample
#include "itkListSample.h"
#include "itkVariableLengthVector.h"
//Estimator
#include "otbMachineLearningModelFactory.h"
#ifdef OTB_USE_SHARK
#include "otbSharkRandomForestsMachineLearningModel.h"
#endif
namespace otb
{
namespace Wrapper
{
/** \class SharkLearningApplicationBase
* \brief SharkLearningApplicationBase is the base class for applications that
* use Shark machine learning models.
*
* It declares a "classifier" choice parameter (see DoInit()) and offers the
* Train() and Classify() helpers to derived applications. Members that
* depend on Shark are only compiled when OTB_USE_SHARK is defined.
*
* \tparam TInputValue value type of the input samples
* \tparam TOutputValue value type of the predicted targets
*/
template <class TInputValue, class TOutputValue>
class SharkLearningApplicationBase: public Application
{
public:
/** Standard class typedefs. */
typedef SharkLearningApplicationBase Self;
typedef Application Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Standard macro */
itkTypeMacro(SharkLearningApplicationBase, otb::Application)
/** Value types of the samples and of the predicted targets. */
typedef TInputValue InputValueType;
typedef TOutputValue OutputValueType;
/** Image and pixel types built on the input value type. */
typedef otb::VectorImage<InputValueType> SampleImageType;
typedef typename SampleImageType::PixelType PixelType;
// Machine Learning models
typedef otb::MachineLearningModelFactory<
InputValueType, OutputValueType> ModelFactoryType;
typedef typename ModelFactoryType::MachineLearningModelTypePointer ModelPointerType;
typedef typename ModelFactoryType::MachineLearningModelType ModelType;
/** Sample and target types, all taken from the generic model type. */
typedef typename ModelType::InputSampleType SampleType;
typedef typename ModelType::InputListSampleType ListSampleType;
typedef typename ModelType::TargetSampleType TargetSampleType;
typedef typename ModelType::TargetListSampleType TargetListSampleType;
typedef typename ModelType::TargetValueType TargetValueType;
#ifdef OTB_USE_SHARK
/** Shark random forests model used by TrainSharkRandomForests(). */
typedef otb::SharkRandomForestsMachineLearningModel<InputValueType, OutputValueType> SharkRandomForestType;
#endif
protected:
/** Constructor: initializes m_RegressionFlag to false. */
SharkLearningApplicationBase();
/** Generic method to train and save the machine learning model. This method
* uses specific train methods depending on the chosen model.
* \param trainingListSample input samples used for training
* \param trainingLabeledListSample targets associated with the samples
* \param modelPath file the trained model is saved to */
void Train(typename ListSampleType::Pointer trainingListSample,
typename TargetListSampleType::Pointer trainingLabeledListSample,
std::string modelPath);
/** Generic method to load a model file and use it to classify a sample list
* \param validationListSample samples to classify
* \param predictedList list sample handed to the model as prediction target
* \param modelPath file the model is loaded from */
void Classify(typename ListSampleType::Pointer validationListSample,
typename TargetListSampleType::Pointer predictedList,
std::string modelPath);
/** Init method that creates all the parameters for machine learning models */
void DoInit();
/** Flag to switch between classification and regression mode.
* False by default, child classes may change it in their constructor */
bool m_RegressionFlag;
private:
/** Specific Init and Train methods for each machine learning model */
//@{
#ifdef OTB_USE_SHARK
/** Adds the "classifier.sharkrf" choice and its parameters. */
void InitSharkRandomForestsParams();
/** Trains a Shark random forests model and saves it to modelPath. */
void TrainSharkRandomForests(typename ListSampleType::Pointer trainingListSample,
typename TargetListSampleType::Pointer trainingLabeledListSample,
std::string modelPath);
#endif
//@}
};
}
}
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbSharkLearningApplicationBase.txx"
#ifdef OTB_USE_SHARK
#include "otbTrainSharkRandomForests.txx"
#endif
#endif
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbSharkLearningApplicationBase_txx
#define __otbSharkLearningApplicationBase_txx
#include "otbSharkLearningApplicationBase.h"
// only need this filter as a dummy process object
#include "otbRGBAPixelConverter.h"
namespace otb
{
namespace Wrapper
{
/** Default constructor: the regression flag starts cleared, i.e. the
 *  application works in classification mode unless a child class changes it. */
template <class TInputValue, class TOutputValue>
SharkLearningApplicationBase<TInputValue,TOutputValue>
::SharkLearningApplicationBase()
  : m_RegressionFlag(false)
{
}
/** Tag the application as a learning one and create the "classifier" choice
 *  parameter; the Shark random forests choice is only added when
 *  OTB_USE_SHARK is defined. */
template <class TInputValue, class TOutputValue>
void
SharkLearningApplicationBase<TInputValue,TOutputValue>
::DoInit()
{
  this->AddDocTag(Tags::Learning);

  // Top-level choice parameter under which every machine learning option lives.
  this->AddParameter(ParameterType_Choice,
                     "classifier",
                     "Classifier to use for the training");
  this->SetParameterDescription("classifier",
                                "Choice of the classifier to use for the training.");

#ifdef OTB_USE_SHARK
  this->InitSharkRandomForestsParams();
#endif
}
/** Load the model stored in modelPath through the model factory and run it
 *  on validationListSample, with predictedList set as the model's target list.
 *  A dummy process object is registered so that progress gets reported. */
template <class TInputValue, class TOutputValue>
void
SharkLearningApplicationBase<TInputValue,TOutputValue>
::Classify(typename ListSampleType::Pointer validationListSample,
           typename TargetListSampleType::Pointer predictedList,
           std::string modelPath)
{
  // The converter does no real work here: it only acts as a progress source.
  typedef RGBAPixelConverter<int,int> ProgressSourceType;
  ProgressSourceType::Pointer progressSource = ProgressSourceType::New();
  progressSource->SetProgress(0.0f);
  this->AddProcess(progressSource,"Classify...");
  progressSource->InvokeEvent(itk::StartEvent());

  // Ask the factory for a model able to read the file.
  ModelPointerType loadedModel =
    ModelFactoryType::CreateMachineLearningModel(modelPath, ModelFactoryType::ReadMode);
  if (loadedModel.IsNull())
    {
    otbAppLogFATAL(<< "Error when loading model " << modelPath);
    }

  // Read the model and predict every sample of the input list.
  loadedModel->Load(modelPath);
  loadedModel->SetRegressionMode(this->m_RegressionFlag);
  loadedModel->SetInputListSample(validationListSample);
  loadedModel->SetTargetListSample(predictedList);
  loadedModel->PredictAll();

  // Signal completion to the progress observers.
  progressSource->UpdateProgress(1.0f);
  progressSource->InvokeEvent(itk::EndEvent());
}
/** Train the model selected through the "classifier" parameter on the given
 *  samples and save it to modelPath. A dummy process object is registered so
 *  that progress gets reported. */
template <class TInputValue, class TOutputValue>
void
SharkLearningApplicationBase<TInputValue,TOutputValue>
::Train(typename ListSampleType::Pointer trainingListSample,
        typename TargetListSampleType::Pointer trainingLabeledListSample,
        std::string modelPath)
{
  // The converter does no real work here: it only acts as a progress source.
  typedef RGBAPixelConverter<int,int> ProgressSourceType;
  ProgressSourceType::Pointer progressSource = ProgressSourceType::New();
  progressSource->SetProgress(0.0f);
  this->AddProcess(progressSource,"Training model...");
  progressSource->InvokeEvent(itk::StartEvent());

  // Dispatch on the name of the classifier picked by the user.
  const std::string selectedClassifier = this->GetParameterString("classifier");
  const bool useSharkRandomForests = (selectedClassifier == "sharkrf");
  if (useSharkRandomForests)
    {
#ifdef OTB_USE_SHARK
    this->TrainSharkRandomForests(trainingListSample, trainingLabeledListSample, modelPath);
#else
    otbAppLogFATAL("Module Shark is not installed. You should consider turning OTB_USE_SHARK on during cmake configuration.");
#endif
    }

  // Signal completion to the progress observers.
  progressSource->UpdateProgress(1.0f);
  progressSource->InvokeEvent(itk::EndEvent());
}
}
}
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbTrainSharkRandomForests_txx
#define __otbTrainSharkRandomForests_txx
#include "otbSharkLearningApplicationBase.h"
namespace otb
{
namespace Wrapper
{
/** Add the "sharkrf" entry to the "classifier" choice parameter together with
 *  its tuning parameters and their default values:
 *  number of trees (100), node size (25), mtry (0) and out-of-bag ratio (0.66).
 *
 *  The definition is attached to SharkLearningApplicationBase, the class that
 *  declares this member. */
template <class TInputValue, class TOutputValue>
void
SharkLearningApplicationBase<TInputValue,TOutputValue>
::InitSharkRandomForestsParams()
{
  AddChoice("classifier.sharkrf", "Shark Random forests classifier");
  SetParameterDescription("classifier.sharkrf",
                          "This group of parameters allows setting Shark Random Forests classifier parameters. "
                          "See complete documentation here \\url{http://image.diku.dk/shark/doxygen_pages/html/classshark_1_1_r_f_trainer.html}.");

  //MaxNumberOfTrees
  AddParameter(ParameterType_Int, "classifier.sharkrf.nbtrees",
               "Maximum number of trees in the forest");
  SetParameterInt("classifier.sharkrf.nbtrees", 100);
  SetParameterDescription(
    "classifier.sharkrf.nbtrees",
    "The maximum number of trees in the forest. Typically, the more trees you have, the better the accuracy. "
    "However, the improvement in accuracy generally diminishes and reaches an asymptote for a certain number of trees. "
    "Also to keep in mind, increasing the number of trees increases the prediction time linearly.");

  //NodeSize
  AddParameter(ParameterType_Int, "classifier.sharkrf.nodesize", "Min size of the node for a split");
  SetParameterInt("classifier.sharkrf.nodesize", 25);
  SetParameterDescription(
    "classifier.sharkrf.nodesize",
    "If the number of samples in a node is smaller than this parameter, "
    "then the node will not be split. A reasonable value is a small percentage of the total data e.g. 1 percent.");

  //MTry
  AddParameter(ParameterType_Int, "classifier.sharkrf.mtry", "Number of features tested at each node");
  SetParameterInt("classifier.sharkrf.mtry", 0);
  SetParameterDescription(
    "classifier.sharkrf.mtry",
    "The number of features (variables) which will be tested at each node in "
    "order to compute the split. If set to zero, the square root of the number of "
    "features is used.");

  //OOB Ratio ("out of bag", as stated by the description below)
  AddParameter(ParameterType_Float, "classifier.sharkrf.oobr", "Out of bag ratio");
  SetParameterFloat("classifier.sharkrf.oobr", 0.66);
  // The two literals are concatenated: the first one must end with a space.
  SetParameterDescription("classifier.sharkrf.oobr",
                          "Set the fraction of the original training dataset to use as the out of bag sample. "
                          "A good default value is 0.66. ");
}
/** Train a Shark random forests model on the given samples, using the values
 *  of the "classifier.sharkrf.*" parameters, then save it to modelPath.
 *
 *  The definition is attached to SharkLearningApplicationBase, the class that
 *  declares this member.
 *
 *  \param trainingListSample input samples used for training
 *  \param trainingLabeledListSample targets associated with the samples
 *  \param modelPath file the trained model is saved to */
template <class TInputValue, class TOutputValue>
void
SharkLearningApplicationBase<TInputValue,TOutputValue>
::TrainSharkRandomForests(typename ListSampleType::Pointer trainingListSample,
                          typename TargetListSampleType::Pointer trainingLabeledListSample,
                          std::string modelPath)
{
  typename SharkRandomForestType::Pointer classifier = SharkRandomForestType::New();
  classifier->SetRegressionMode(this->m_RegressionFlag);
  classifier->SetInputListSample(trainingListSample);
  classifier->SetTargetListSample(trainingLabeledListSample);

  // Forward the user-facing parameters to the model.
  classifier->SetNodeSize(GetParameterInt("classifier.sharkrf.nodesize"));
  classifier->SetOobRatio(GetParameterFloat("classifier.sharkrf.oobr"));
  classifier->SetNumberOfTrees(GetParameterInt("classifier.sharkrf.nbtrees"));
  classifier->SetMTry(GetParameterInt("classifier.sharkrf.mtry"));

  classifier->Train();
  classifier->Save(modelPath);
}
} //end namespace wrapper
} //end namespace otb
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbSharkRandomForestsMachineLearningModel_h
#define __otbSharkRandomForestsMachineLearningModel_h
#include "itkLightObject.h"
#include "otbMachineLearningModel.h"
#include "shark/Algorithms/Trainers/RFTrainer.h"
namespace otb
{
/** \class SharkRandomForestsMachineLearningModel
* \brief Machine learning model built on a shark::RFClassifier that is trained
* with a shark::RFTrainer.
*
* \tparam TInputValue value type of the input samples
* \tparam TTargetValue value type of the targets
*/
template <class TInputValue, class TTargetValue>
class ITK_EXPORT SharkRandomForestsMachineLearningModel
: public MachineLearningModel <TInputValue, TTargetValue>
{
public:
/** Standard class typedefs. */
typedef SharkRandomForestsMachineLearningModel Self;
typedef MachineLearningModel<TInputValue, TTargetValue> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Sample, target and confidence types re-exported from the superclass. */
typedef typename Superclass::InputValueType InputValueType;
typedef typename Superclass::InputSampleType InputSampleType;
typedef typename Superclass::InputListSampleType InputListSampleType;
typedef typename Superclass::TargetValueType TargetValueType;
typedef typename Superclass::TargetSampleType TargetSampleType;
typedef typename Superclass::TargetListSampleType TargetListSampleType;
typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
/** Run-time type information (and related methods). */
itkNewMacro(Self);
itkTypeMacro(SharkRandomForestsMachineLearningModel, MachineLearningModel);
/** Train the machine learning model */
virtual void Train();
/** Predict values using the model.
* When quality is not NULL it receives a confidence value: the difference
* between the two highest class probabilities if ComputeMargin is set,
* the highest class probability otherwise. */
virtual TargetSampleType Predict(const InputSampleType& input, ConfidenceValueType *quality=NULL) const;
/** Save the model to file (the name argument is not used) */
virtual void Save(const std::string & filename, const std::string & name="");
/** Load the model from file (the name argument is not used) */
virtual void Load(const std::string & filename, const std::string & name="");
/** Classify all samples in InputListSample and fill TargetListSample with the associated label */
virtual void PredictAll() override;
/**\name Classification model file compatibility tests */
//@{
/** Is the input model file readable and compatible with the corresponding classifier ?
* The check is done by attempting to load the file into this model. */
virtual bool CanReadFile(const std::string &);
/** Is the input model file writable and compatible with the corresponding classifier ?
* Always answers true. */
virtual bool CanWriteFile(const std::string &);
//@}
/** Number of trees, forwarded to shark::RFTrainer::setNTrees() by Train(). */
itkGetMacro(NumberOfTrees,unsigned int);
itkSetMacro(NumberOfTrees,unsigned int);
/** MTry value, forwarded to shark::RFTrainer::setMTry() by Train(). */
itkGetMacro(MTry, unsigned int);
itkSetMacro(MTry, unsigned int);
/** Node size, forwarded to shark::RFTrainer::setNodeSize() by Train(). */
itkGetMacro(NodeSize, unsigned int);
itkSetMacro(NodeSize, unsigned int);
/** Out-of-bag ratio, forwarded to shark::RFTrainer::setOOBratio() by Train(). */
itkGetMacro(OobRatio, float);
itkSetMacro(OobRatio, float);
/** Selects how Predict() computes the confidence value (margin or maximum). */
itkGetMacro(ComputeMargin, bool);
itkSetMacro(ComputeMargin, bool);
protected:
/** Constructor */
SharkRandomForestsMachineLearningModel();
/** Destructor */
virtual ~SharkRandomForestsMachineLearningModel();
/** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const;
private:
SharkRandomForestsMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
// Trained forest and the trainer that produces it.
shark::RFClassifier m_RFModel;
shark::RFTrainer m_RFTrainer;
// Training parameters and confidence mode, see the accessors above.
unsigned int m_NumberOfTrees;
unsigned int m_MTry;
unsigned int m_NodeSize;
float m_OobRatio;
bool m_ComputeMargin;
};
} // end namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbSharkRandomForestsMachineLearningModel.txx"
#endif
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbSharkRandomForestsMachineLearningModel_txx
#define __otbSharkRandomForestsMachineLearningModel_txx
#include <fstream>
#include "itkMacro.h"
#include "otbSharkRandomForestsMachineLearningModel.h"
#include <shark/Models/Converter.h>
#include "otbSharkUtils.h"
#include <algorithm>
namespace otb
{
/** Constructor.
 *
 *  Gives every training parameter a defined default so that Train() and
 *  Predict() never read an indeterminate member when a setter was not called.
 *  The values are the defaults exposed by the "classifier.sharkrf.*"
 *  application parameters (100 trees, mtry 0, node size 25, OOB ratio 0.66);
 *  the confidence is the maximum probability unless ComputeMargin is enabled. */
template <class TInputValue, class TOutputValue>
SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::SharkRandomForestsMachineLearningModel()
  : m_NumberOfTrees(100),
    m_MTry(0),
    m_NodeSize(25),
    m_OobRatio(0.66f),
    m_ComputeMargin(false)
{
  this->m_ConfidenceIndex = true;
  // NOTE(review): Train() only builds a shark::ClassificationDataset; confirm
  // that advertising regression support is intended.
  this->m_IsRegressionSupported = true;
}
/** Destructor: nothing to release explicitly, members clean up after themselves. */
template <class TInputValue, class TOutputValue>
SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::~SharkRandomForestsMachineLearningModel() = default;
/** Train the random forest on the input and target list samples, using the
 *  MTry, NumberOfTrees, NodeSize and OobRatio values currently set. */
template <class TInputValue, class TOutputValue>
void
SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::Train()
{
  // Convert the ITK list samples into Shark-compatible containers.
  std::vector<shark::RealVector> trainingFeatures;
  std::vector<unsigned int> trainingLabels;
  Shark::ListSampleToSharkVector(this->GetInputListSample(), trainingFeatures);
  Shark::ListSampleToSharkVector(this->GetTargetListSample(), trainingLabels);
  shark::ClassificationDataset trainingData =
    shark::createLabeledDataFromRange(trainingFeatures, trainingLabels);

  // Configure the trainer from the model parameters.
  this->m_RFTrainer.setMTry(this->m_MTry);
  this->m_RFTrainer.setNTrees(this->m_NumberOfTrees);
  this->m_RFTrainer.setNodeSize(this->m_NodeSize);
  this->m_RFTrainer.setOOBratio(this->m_OobRatio);

  // Grow the forest.
  this->m_RFTrainer.train(this->m_RFModel, trainingData);
}
/** Predict the label of a single sample.
 *
 *  \param value   sample to classify
 *  \param quality optional output; when not null it receives the confidence of
 *                 the prediction: the margin between the two highest class
 *                 probabilities if ComputeMargin is set and at least two
 *                 probabilities are available, the highest probability
 *                 otherwise, and 0 when the model returns no probability.
 *  \return a target sample whose first component is the predicted label */
template <class TInputValue, class TOutputValue>
typename SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::Predict(const InputSampleType & value, ConfidenceValueType *quality) const
{
  // Copy the sample into a Shark vector.
  shark::RealVector samples;
  for(size_t i = 0; i < value.Size();i++)
    {
    samples.push_back(value[i]);
    }

  // The class probabilities are only needed for the confidence value, so
  // they are not evaluated when the caller does not ask for it.
  if (quality != NULL)
    {
    auto probas = m_RFModel(samples);
    const std::size_t nbProbas = probas.size();
    if (nbProbas == 0)
      {
      // Nothing to measure: avoid dereferencing the end iterator.
      (*quality) = static_cast<ConfidenceValueType>(0);
      }
    else if (m_ComputeMargin && nbProbas > 1)
      {
      // Bring the two highest probabilities to positions 0 and 1.
      std::nth_element(probas.begin(), probas.begin()+1,
                       probas.end(), std::greater<double>());
      (*quality) = static_cast<ConfidenceValueType>(probas[0]-probas[1]);
      }
    else
      {
      // Also covers the margin mode with a single probability, where there
      // is no runner-up to subtract.
      auto max_proba = *(std::max_element(probas.begin(),
                                          probas.end()));
      (*quality) = static_cast<ConfidenceValueType>(max_proba);
      }
    }

  // Turn the probabilistic output into a label.
  shark::ArgMaxConverter<shark::RFClassifier> amc;
  amc.decisionFunction() = m_RFModel;
  unsigned int res;
  amc.eval(samples, res);

  TargetSampleType target;
  target[0] = static_cast<TOutputValue>(res);
  return target;
}
/** Predict the label of every sample of the input list sample; the target
 *  list sample is cleared and refilled with one predicted label per sample. */
template <class TInputValue, class TOutputValue>
void
SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::PredictAll()
{
  // Convert the input list sample into a Shark data container.
  std::vector<shark::RealVector> sharkFeatures;
  Shark::ListSampleToSharkVector(this->GetInputListSample(), sharkFeatures);
  shark::Data<shark::RealVector> sharkInputs = shark::createDataFromRange(sharkFeatures);

  // Turn the probabilistic output of the forest into labels.
  shark::ArgMaxConverter<shark::RFClassifier> labelPredictor;
  labelPredictor.decisionFunction() = m_RFModel;
  auto predictedLabels = labelPredictor(sharkInputs);

  // Replace the content of the target list with the predictions.
  TargetListSampleType * outputTargets = this->GetTargetListSample();
  outputTargets->Clear();
  for(const auto& label : predictedLabels.elements())
    {
    TargetSampleType oneTarget;
    oneTarget[0] = static_cast<TOutputValue>(label);
    outputTargets->PushBack(oneTarget);
    }
}
/** Serialize the forest to a text archive.
 *
 *  \param filename path of the file to write
 *  (the second argument is not used)
 *  \throw itk::ExceptionObject when the file cannot be opened for writing */
template <class TInputValue, class TOutputValue>
void
SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::Save(const std::string & filename, const std::string & itkNotUsed(name))
{
  std::ofstream ofs(filename.c_str());
  // Report an unusable path explicitly instead of handing a failed stream
  // to the archive.
  if (!ofs)
    {
    itkExceptionMacro(<< "Unable to open model file " << filename << " for writing");
    }
  boost::archive::polymorphic_text_oarchive oa(ofs);
  m_RFModel.save(oa,0);
}
/** Read the forest back from a text archive.
 *
 *  \param filename path of the file to read
 *  (the second argument is not used)
 *  \throw itk::ExceptionObject when the file cannot be opened for reading */
template <class TInputValue, class TOutputValue>
void
SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::Load(const std::string & filename, const std::string & itkNotUsed(name))
{
  std::ifstream ifs(filename.c_str());
  // Report a missing or unreadable file explicitly instead of handing a
  // failed stream to the archive.
  if (!ifs)
    {
    itkExceptionMacro(<< "Unable to open model file " << filename << " for reading");
    }
  boost::archive::polymorphic_text_iarchive ia(ifs);
  m_RFModel.load(ia,0);
}
/** Tell whether the file can be read by this model. The answer is obtained by
 *  actually loading the file into this object: any exception raised by Load()
 *  means "no". */
template <class TInputValue, class TOutputValue>
bool
SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::CanReadFile(const std::string & file)
{
  try
    {
    this->Load(file);
    return true;
    }
  catch(...)
    {
    // Deliberate catch-all: whatever went wrong, the file is not usable.
    return false;
    }
}
/** Tell whether the file can be written by this model: no check is performed,
 *  the answer is always true. */
template <class TInputValue, class TOutputValue>
bool
SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::CanWriteFile(const std::string & itkNotUsed(file))
{
  const bool writable = true;
  return writable;
}
/** Print the object state: only the superclass information is written. */
template <class TInputValue, class TOutputValue>
void
SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::PrintSelf(std::ostream& os, itk::Indent indent) const
{
  // Nothing specific to this class is printed; delegate to the superclass.
  Superclass::PrintSelf(os, indent);
}
} //end namespace otb
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbSharkRandomForestsMachineLearningModelFactory_h
#define __otbSharkRandomForestsMachineLearningModelFactory_h
#include "itkObjectFactoryBase.h"
#include "itkImageIOBase.h"
namespace otb
{