Commit 229fd4f6 authored by Jordi Inglada's avatar Jordi Inglada

MRG

parents 98ee6bb7 95ac5ea9
......@@ -94,10 +94,10 @@ public:
//@}
/** Train the machine learning model */
virtual void Train() = 0;
void Train();
/** Predict values using the model */
virtual TargetSampleType Predict(const InputSampleType& input) const = 0;
TargetSampleType Predict(const InputSampleType& input) const;
/** Classify all samples in InputListSample and fill TargetListSample with the associated label */
void PredictAll();
......@@ -150,6 +150,22 @@ protected:
/** Target list sample */
typename TargetListSampleType::Pointer m_TargetListSample;
/** Train the machine learning model */
virtual void TrainRegression()
{
itkGenericExceptionMacro(<< "Regression mode not implemented.");
}
virtual void TrainClassification() = 0;
/** Predict values using the model */
virtual TargetSampleType PredictRegression(const InputSampleType& input) const
{
itkGenericExceptionMacro(<< "Regression mode not implemented.");
(void)input;
}
virtual TargetSampleType PredictClassification(const InputSampleType& input) const = 0;
bool m_RegressionMode;
private:
MachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
......
......@@ -25,7 +25,7 @@ namespace otb
template <class TInputValue, class TOutputValue>
MachineLearningModel<TInputValue,TOutputValue>
::MachineLearningModel()
::MachineLearningModel() : m_RegressionMode(false)
{}
......@@ -34,6 +34,28 @@ MachineLearningModel<TInputValue,TOutputValue>
::~MachineLearningModel()
{}
template <class TInputValue, class TOutputValue>
void
MachineLearningModel<TInputValue,TOutputValue>
::Train()
{
if(m_RegressionMode)
return this->TrainRegression();
else
return this->TrainClassification();
}
template <class TInputValue, class TOutputValue>
typename MachineLearningModel<TInputValue,TOutputValue>::TargetSampleType
MachineLearningModel<TInputValue,TOutputValue>
::Predict(const typename MachineLearningModel<TInputValue,TOutputValue>::InputSampleType& input) const
{
if(m_RegressionMode)
return this->PredictRegression(input);
else
return this->PredictClassification(input);
}
template <class TInputValue, class TOutputValue>
void
MachineLearningModel<TInputValue,TOutputValue>
......
......@@ -94,13 +94,6 @@ public:
itkGetMacro(MaxDepth, int);
itkSetMacro(MaxDepth, int);
/** Train the machine learning model */
virtual void Train();
/** Predict values using the model */
virtual TargetSampleType Predict(const InputSampleType & input) const;
/** Save the model to file */
virtual void Save(const std::string & filename, const std::string & name="");
......@@ -126,6 +119,11 @@ protected:
/** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const;
/** Train the machine learning model */
virtual void TrainClassification();
/** Predict values using the model */
virtual TargetSampleType PredictClassification(const InputSampleType& input) const;
private:
BoostMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
......
......@@ -51,7 +51,7 @@ BoostMachineLearningModel<TInputValue,TOutputValue>
template <class TInputValue, class TOutputValue>
void
BoostMachineLearningModel<TInputValue,TOutputValue>
::Train()
::TrainClassification()
{
//convert listsample to opencv matrix
cv::Mat samples;
......@@ -76,7 +76,7 @@ template <class TInputValue, class TOutputValue>
typename BoostMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
BoostMachineLearningModel<TInputValue,TOutputValue>
::Predict(const InputSampleType & input) const
::PredictClassification(const InputSampleType & input) const
{
//convert listsample to Mat
cv::Mat sample;
......
......@@ -155,12 +155,6 @@ public:
itkGetMacro(IsRegression, bool);
itkSetMacro(IsRegression, bool);
/** Train the machine learning model */
virtual void Train();
/** Predict values using the model */
virtual TargetSampleType Predict(const InputSampleType & input) const;
/** Save the model to file */
virtual void Save(const std::string & filename, const std::string & name="");
......@@ -186,6 +180,11 @@ protected:
/** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const;
/** Train the machine learning model */
virtual void TrainClassification();
/** Predict values using the model */
virtual TargetSampleType PredictClassification(const InputSampleType& input) const;
private:
DecisionTreeMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
......
......@@ -55,7 +55,7 @@ DecisionTreeMachineLearningModel<TInputValue,TOutputValue>
template <class TInputValue, class TOutputValue>
void
DecisionTreeMachineLearningModel<TInputValue,TOutputValue>
::Train()
::TrainClassification()
{
//convert listsample to opencv matrix
cv::Mat samples;
......@@ -83,7 +83,7 @@ template <class TInputValue, class TOutputValue>
typename DecisionTreeMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
DecisionTreeMachineLearningModel<TInputValue,TOutputValue>
::Predict(const InputSampleType & input) const
::PredictClassification(const InputSampleType & input) const
{
//convert listsample to Mat
cv::Mat sample;
......
......@@ -102,13 +102,6 @@ public:
itkGetMacro(UseSurrogates, bool);
itkSetMacro(UseSurrogates, bool);
/** Train the machine learning model */
virtual void Train();
/** Predict values using the model */
virtual TargetSampleType Predict(const InputSampleType & input) const;
/** Save the model to file */
virtual void Save(const std::string & filename, const std::string & name="");
......@@ -134,6 +127,11 @@ protected:
/** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const;
/** Train the machine learning model */
virtual void TrainClassification();
/** Predict values using the model */
virtual TargetSampleType PredictClassification(const InputSampleType& input) const;
private:
GradientBoostedTreeMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
......
......@@ -53,7 +53,7 @@ GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue>
template <class TInputValue, class TOutputValue>
void
GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue>
::Train()
::TrainClassification()
{
//convert listsample to opencv matrix
cv::Mat samples;
......@@ -80,7 +80,7 @@ template <class TInputValue, class TOutputValue>
typename GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue>
::Predict(const InputSampleType & input) const
::PredictClassification(const InputSampleType & input) const
{
//convert listsample to Mat
cv::Mat sample;
......
......@@ -67,12 +67,6 @@ public:
itkGetMacro(IsRegression, bool);
itkSetMacro(IsRegression, bool);
/** Train the machine learning model */
virtual void Train();
/** Predict values using the model */
virtual TargetSampleType Predict(const InputSampleType & input) const;
/** Save the model to file */
virtual void Save(const std::string & filename, const std::string & name="");
......@@ -98,6 +92,11 @@ protected:
/** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const;
/** Train the machine learning model */
virtual void TrainClassification();
/** Predict values using the model */
virtual TargetSampleType PredictClassification(const InputSampleType& input) const;
private:
KNearestNeighborsMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
......
......@@ -49,7 +49,7 @@ KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue>
template <class TInputValue, class TTargetValue>
void
KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue>
::Train()
::TrainClassification()
{
//convert listsample to opencv matrix
cv::Mat samples;
......@@ -66,7 +66,7 @@ template <class TInputValue, class TTargetValue>
typename KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue>
::TargetSampleType
KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue>
::Predict(const InputSampleType & input) const
::PredictClassification(const InputSampleType & input) const
{
//convert listsample to Mat
cv::Mat sample;
......@@ -181,7 +181,7 @@ KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue>
this->SetInputListSample(samples);
this->SetTargetListSample(labels);
Train();
this->Train();
}
template <class TInputValue, class TTargetValue>
......
......@@ -61,12 +61,6 @@ public:
itkNewMacro(Self);
itkTypeMacro(SVMMachineLearningModel, itk::MachineLearningModel);
/** Train the machine learning model */
virtual void Train();
/** Predict values using the model */
virtual TargetSampleType Predict(const InputSampleType & input) const;
/** Save the model to file */
virtual void Save(const std::string &filename, const std::string & name="");
......@@ -111,6 +105,11 @@ protected:
/** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const;
/** Train the machine learning model */
virtual void TrainClassification();
/** Predict values using the model */
virtual TargetSampleType PredictClassification(const InputSampleType& input) const;
private:
LibSVMMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
......
......@@ -55,7 +55,7 @@ LibSVMMachineLearningModel<TInputValue,TOutputValue>
template <class TInputValue, class TOutputValue>
void
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::Train()
::TrainClassification()
{
// Set up SVM's parameters
// CvSVMParams params;
......@@ -80,7 +80,7 @@ template <class TInputValue, class TOutputValue>
typename LibSVMMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::Predict(const InputSampleType & input) const
::PredictClassification(const InputSampleType & input) const
{
TargetSampleType target;
......@@ -114,28 +114,28 @@ bool
LibSVMMachineLearningModel<TInputValue,TOutputValue>
::CanReadFile(const std::string & file)
{
//TODO: Rework.
std::ifstream ifs;
ifs.open(file.c_str());
if(!ifs)
{
std::cerr<<"Could not read file "<<file<<std::endl;
return false;
}
//Read only the first line.
std::string line;
std::getline(ifs, line);
//if (line.find(m_SVMModel->getName()) != std::string::npos)
if (line.find("svm_type") != std::string::npos)
{
//std::cout<<"Reading a libSVM model"<<std::endl;
return true;
}
ifs.close();
return false;
//TODO: Rework.
std::ifstream ifs;
ifs.open(file.c_str());
if(!ifs)
{
std::cerr<<"Could not read file "<<file<<std::endl;
return false;
}
//Read only the first line.
std::string line;
std::getline(ifs, line);
//if (line.find(m_SVMModel->getName()) != std::string::npos)
if (line.find("svm_type") != std::string::npos)
{
//std::cout<<"Reading a libSVM model"<<std::endl;
return true;
}
ifs.close();
return false;
}
template <class TInputValue, class TOutputValue>
......
......@@ -153,12 +153,6 @@ public:
itkGetMacro(Epsilon, double);
itkSetMacro(Epsilon, double);
/** Train the machine learning model */
virtual void Train();
/** Predict values using the model */
virtual TargetSampleType Predict(const InputSampleType & input) const;
/** Save the model to file */
virtual void Save(const std::string & filename, const std::string & name="");
......@@ -186,6 +180,11 @@ protected:
/** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const;
/** Train the machine learning model */
virtual void TrainClassification();
/** Predict values using the model */
virtual TargetSampleType PredictClassification(const InputSampleType& input) const;
private:
NeuralNetworkMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
......
......@@ -131,7 +131,7 @@ void NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::LabelsToMat(c
/** Train the machine learning model */
template<class TInputValue, class TOutputValue>
void NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::Train()
void NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::TrainClassification()
{
//Create the neural network
const unsigned int nbLayers = m_LayerSizes.size();
......@@ -169,7 +169,7 @@ void NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::Train()
template<class TInputValue, class TOutputValue>
typename NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::TargetSampleType NeuralNetworkMachineLearningModel<
TInputValue, TOutputValue>::Predict(const InputSampleType & input) const
TInputValue, TOutputValue>::PredictClassification(const InputSampleType & input) const
{
//convert listsample to Mat
cv::Mat sample;
......
......@@ -54,12 +54,6 @@ public:
itkNewMacro(Self);
itkTypeMacro(NormalBayesMachineLearningModel, itk::MachineLearningModel);
/** Train the machine learning model */
virtual void Train();
/** Predict values using the model */
virtual TargetSampleType Predict(const InputSampleType & input) const;
/** Save the model to file */
virtual void Save(const std::string & filename, const std::string & name="");
......@@ -85,6 +79,11 @@ protected:
/** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const;
/** Train the machine learning model */
virtual void TrainClassification();
/** Predict values using the model */
virtual TargetSampleType PredictClassification(const InputSampleType& input) const;
private:
NormalBayesMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
......
......@@ -45,7 +45,7 @@ NormalBayesMachineLearningModel<TInputValue,TOutputValue>
template <class TInputValue, class TOutputValue>
void
NormalBayesMachineLearningModel<TInputValue,TOutputValue>
::Train()
::TrainClassification()
{
//convert listsample to opencv matrix
cv::Mat samples;
......@@ -61,7 +61,7 @@ template <class TInputValue, class TOutputValue>
typename NormalBayesMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
NormalBayesMachineLearningModel<TInputValue,TOutputValue>
::Predict(const InputSampleType & input) const
::PredictClassification(const InputSampleType & input) const
{
//convert listsample to Mat
cv::Mat sample;
......
......@@ -61,12 +61,6 @@ public:
itkNewMacro(Self);
itkTypeMacro(RandomForestsMachineLearningModel, itk::MachineLearningModel);
/** Train the machine learning model */
virtual void Train();
/** Predict values using the model */
virtual TargetSampleType Predict(const InputSampleType & input) const;
/** Save the model to file */
virtual void Save(const std::string & filename, const std::string & name="");
......@@ -147,6 +141,11 @@ protected:
/* /\** Target list sample *\/ */
/* typename TargetListSampleType::Pointer m_TargetListSample; */
/** Train the machine learning model */
virtual void TrainClassification();
/** Predict values using the model */
virtual TargetSampleType PredictClassification(const InputSampleType& input) const;
private:
RandomForestsMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
......
......@@ -64,7 +64,7 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue>
template <class TInputValue, class TOutputValue>
void
RandomForestsMachineLearningModel<TInputValue,TOutputValue>
::Train()
::TrainClassification()
{
//convert listsample to opencv matrix
cv::Mat samples;
......@@ -109,7 +109,7 @@ template <class TInputValue, class TOutputValue>
typename RandomForestsMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
RandomForestsMachineLearningModel<TInputValue,TOutputValue>
::Predict(const InputSampleType & value) const
::PredictClassification(const InputSampleType & value) const
{
//convert listsample to Mat
cv::Mat sample;
......
......@@ -53,12 +53,6 @@ public:
itkNewMacro(Self);
itkTypeMacro(SVMMachineLearningModel, MachineLearningModel);
/** Train the machine learning model */
virtual void Train();
/** Predict values using the model */
virtual TargetSampleType Predict(const InputSampleType & input) const;
/** Save the model to file */
virtual void Save(const std::string & filename, const std::string & name="");
......@@ -134,6 +128,11 @@ protected:
/** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const;
/** Train the machine learning model */
virtual void TrainClassification();
/** Predict values using the model */
virtual TargetSampleType PredictClassification(const InputSampleType& input) const;
private:
SVMMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
......
......@@ -63,7 +63,7 @@ SVMMachineLearningModel<TInputValue,TOutputValue>
template <class TInputValue, class TOutputValue>
void
SVMMachineLearningModel<TInputValue,TOutputValue>
::Train()
::TrainClassification()
{
//convert listsample to opencv matrix
cv::Mat samples;
......@@ -108,7 +108,7 @@ template <class TInputValue, class TOutputValue>
typename SVMMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
SVMMachineLearningModel<TInputValue,TOutputValue>
::Predict(const InputSampleType & input) const
::PredictClassification(const InputSampleType & input) const
{
//convert listsample to Mat
cv::Mat sample;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment