Commit c88bcd17 authored by Guillaume Pasero's avatar Guillaume Pasero

DOC: document LearningApplicationBase

parent 1b0ea6d4
......@@ -50,6 +50,29 @@ namespace otb
namespace Wrapper
/** \class LearningApplicationBase
* \brief LearningApplicationBase is the base class for application that
* use machine learning model.
* This base class offers a DoInit() method to initialize all the parameters
* related to machine learning models. They will all be in the choice parameter
* named "classifier". The class also offers generic Train() and Classify()
* methods. The classes derived from LearningApplicationBase only need these
* 3 methods to handle the machine learning model.
* There are multiple machine learning models in OTB, some imported from OpenCV,
* and one imported from LibSVM. They all have different parameters. The
* purpose of this class is to handle the creation of all parameters related to
* machine learning models (in DoInit() ), and to dispatch the calls to
* specific train functions in function Train(). This class also handles the
* two learning modes : classification and regression. By default,
* classification mode is enabled. For regression, child classes should
* initialize the m_RegressionFlag to true in their constructor and use a
* continuous numeric type as output template TOutputValue.
* \sa TrainImagesClassifier
* \sa TrainRegression
template <class TInputValue, class TOutputValue>
class LearningApplicationBase: public Application
......@@ -98,24 +121,29 @@ public:
// using Superclass::AddParameter;
// friend void InitSVMParams(LearningApplicationBase & app);
/** Generic method to train and save the machine learning model. This method
* uses specific train methods depending on the chosen model.*/
void Train(typename ListSampleType::Pointer trainingListSample,
typename TargetListSampleType::Pointer trainingLabeledListSample,
std::string modelPath);
/** Generic method to load a model file and use it to classify a sample list*/
void Classify(typename ListSampleType::Pointer validationListSample,
typename TargetListSampleType::Pointer predictedList,
std::string modelPath);
/** Init method that creates all the parameters for machine learning models */
void DoInit();
/** Flag to switch between classification and regression mode.
* False by default, child classes may change it in their constructor */
bool m_RegressionFlag;
/** Specific Init and Train methods for each machine learning model */
void InitLibSVMParams();
......@@ -159,6 +187,7 @@ private:
typename TargetListSampleType::Pointer trainingLabeledListSample,
std::string modelPath);
......@@ -37,6 +37,7 @@ LearningApplicationBase<TInputValue,TOutputValue>
// main choice parameter that will contain all machine learning options
AddParameter(ParameterType_Choice, "classifier", "Classifier to use for the training");
SetParameterDescription("classifier", "Choice of the classifier to use for the training.");
......@@ -70,7 +71,7 @@ LearningApplicationBase<TInputValue,TOutputValue>
typename TargetListSampleType::Pointer predictedList,
std::string modelPath)
// load a machine learning model from file and predict the input sample list
ModelPointerType model = ModelFactoryType::CreateMachineLearningModel(modelPath,
......@@ -93,7 +94,9 @@ LearningApplicationBase<TInputValue,TOutputValue>
typename TargetListSampleType::Pointer trainingLabeledListSample,
std::string modelPath)
// get the name of the chosen machine learning model
const std::string modelName = GetParameterString("classifier");
// call specific train function
if (modelName == "libsvm")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment