Skip to content
Snippets Groups Projects
Commit d413ed56 authored by Jordi Inglada's avatar Jordi Inglada
Browse files

ENH: change the interface of otb::MachineLearningModel to support regression

parent 3ae2dca2
No related branches found
No related tags found
No related merge requests found
......@@ -94,10 +94,10 @@ public:
//@}
/** Train the machine learning model */
virtual void Train() = 0;
void Train();
/** Predict values using the model */
virtual TargetSampleType Predict(const InputSampleType& input) const = 0;
TargetSampleType Predict(const InputSampleType& input) const;
/** Classify all samples in InputListSample and fill TargetListSample with the associated label */
void PredictAll();
......@@ -134,6 +134,12 @@ public:
itkGetObjectMacro(TargetListSample,TargetListSampleType);
//@}
/**\name Classification vs regression mode accessors */
//@{
itkSetMacro(RegressionMode,bool);
itkGetMacro(RegressionMode,bool);
//@}
protected:
/** Constructor */
MachineLearningModel();
......@@ -150,6 +156,14 @@ protected:
/** Target list sample */
typename TargetListSampleType::Pointer m_TargetListSample;
/** Train the machine learning model */
virtual void TrainRegression() = 0;
virtual void TrainClassification() = 0;
/** Predict values using the model */
virtual TargetSampleType PredictRegression(const InputSampleType& input) const = 0;
virtual TargetSampleType PredictClassification(const InputSampleType& input) const = 0;
bool m_RegressionMode;
private:
MachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
......
......@@ -25,7 +25,7 @@ namespace otb
template <class TInputValue, class TOutputValue>
MachineLearningModel<TInputValue,TOutputValue>
::MachineLearningModel()
::MachineLearningModel() : m_RegressionMode(false)
{}
......@@ -34,6 +34,28 @@ MachineLearningModel<TInputValue,TOutputValue>
::~MachineLearningModel()
{}
template <class TInputValue, class TOutputValue>
void
MachineLearningModel<TInputValue,TOutputValue>
::Train()
{
if(m_RegressionMode)
return this->TrainRegression();
else
return this->TrainClassification();
}
template <class TInputValue, class TOutputValue>
typename MachineLearningModel<TInputValue,TOutputValue>::TargetSampleType
MachineLearningModel<TInputValue,TOutputValue>
::Predict(const typename MachineLearningModel<TInputValue,TOutputValue>::InputSampleType& input) const
{
if(m_RegressionMode)
return this->PredictRegression(input);
else
return this->PredictClassification(input);
}
template <class TInputValue, class TOutputValue>
void
MachineLearningModel<TInputValue,TOutputValue>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment