Skip to content
Snippets Groups Projects
Commit a9f9b090 authored by Jordi Inglada's avatar Jordi Inglada
Browse files

ENH: implement SVM regression mode

parent b73f681c
Branches
Tags
No related merge requests found
......@@ -126,14 +126,21 @@ protected:
/** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const;
/** Train the machine learning model */
/** Train the machine learning model for classification*/
virtual void TrainClassification();
/** Predict values using the model */
/** Predict values using the model for classification*/
virtual TargetSampleType PredictClassification(const InputSampleType& input, ConfidenceValueType *quality=NULL) const;
/** Train the machine learning model for regression*/
virtual void TrainRegression();
/** Predict values using the model for regression*/
virtual TargetSampleType PredictRegression(const InputSampleType& input) const;
private:
SVMMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
virtual void InternalTrain();
virtual TargetSampleType InternalPredict(const InputSampleType& input, ConfidenceValueType *quality=NULL) const;
CvSVM * m_SVMModel;
int m_SVMType;
......
......@@ -64,7 +64,7 @@ SVMMachineLearningModel<TInputValue,TOutputValue>
template <class TInputValue, class TOutputValue>
void
SVMMachineLearningModel<TInputValue,TOutputValue>
::TrainClassification()
::InternalTrain()
{
//convert listsample to opencv matrix
cv::Mat samples;
......@@ -105,11 +105,21 @@ SVMMachineLearningModel<TInputValue,TOutputValue>
}
/** Train the machine learning model for classification*/
template <class TInputValue, class TOutputValue>
void
SVMMachineLearningModel<TInputValue,TOutputValue>
::TrainClassification()
{
m_SVMType = CvSVM::C_SVC;
this->InternalTrain();
}
template <class TInputValue, class TOutputValue>
typename SVMMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
SVMMachineLearningModel<TInputValue,TOutputValue>
::PredictClassification(const InputSampleType & input, ConfidenceValueType *quality) const
::InternalPredict(const InputSampleType & input, ConfidenceValueType *quality) const
{
//convert listsample to Mat
cv::Mat sample;
......@@ -130,15 +140,43 @@ SVMMachineLearningModel<TInputValue,TOutputValue>
return target;
}
template <class TInputValue, class TOutputValue>
typename SVMMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
SVMMachineLearningModel<TInputValue,TOutputValue>
::PredictClassification(const InputSampleType & input, ConfidenceValueType *quality) const
{
return this->InternalPredict(input, quality);
}
/** Train the machine learning model for regression*/
template <class TInputValue, class TOutputValue>
void
SVMMachineLearningModel<TInputValue,TOutputValue>
::TrainRegression()
{
m_SVMType = CvSVM::NU_SVR;
this->InternalTrain();
}
template <class TInputValue, class TOutputValue>
typename SVMMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
SVMMachineLearningModel<TInputValue,TOutputValue>
::PredictRegression(const InputSampleType & input) const
{
return this->InternalPredict(input, false);
}
template <class TInputValue, class TOutputValue>
void
SVMMachineLearningModel<TInputValue,TOutputValue>
::Save(const std::string & filename, const std::string & name)
{
if (name == "")
m_SVMModel->save(filename.c_str(), 0);
else
m_SVMModel->save(filename.c_str(), name.c_str());
if (name == "")
m_SVMModel->save(filename.c_str(), 0);
else
m_SVMModel->save(filename.c_str(), name.c_str());
}
template <class TInputValue, class TOutputValue>
......@@ -147,9 +185,9 @@ SVMMachineLearningModel<TInputValue,TOutputValue>
::Load(const std::string & filename, const std::string & name)
{
if (name == "")
m_SVMModel->load(filename.c_str(), 0);
m_SVMModel->load(filename.c_str(), 0);
else
m_SVMModel->load(filename.c_str(), name.c_str());
m_SVMModel->load(filename.c_str(), name.c_str());
}
template <class TInputValue, class TOutputValue>
......@@ -157,22 +195,22 @@ bool
SVMMachineLearningModel<TInputValue,TOutputValue>
::CanReadFile(const std::string & file)
{
std::ifstream ifs;
ifs.open(file.c_str());
std::ifstream ifs;
ifs.open(file.c_str());
if(!ifs)
{
std::cerr<<"Could not read file "<<file<<std::endl;
return false;
}
if(!ifs)
{
std::cerr<<"Could not read file "<<file<<std::endl;
return false;
}
while (!ifs.eof())
{
std::string line;
std::getline(ifs, line);
while (!ifs.eof())
{
std::string line;
std::getline(ifs, line);
//if (line.find(m_SVMModel->getName()) != std::string::npos)
if (line.find(CV_TYPE_NAME_ML_SVM) != std::string::npos)
//if (line.find(m_SVMModel->getName()) != std::string::npos)
if (line.find(CV_TYPE_NAME_ML_SVM) != std::string::npos)
{
//std::cout<<"Reading a "<<CV_TYPE_NAME_ML_SVM<<" model"<<std::endl;
return true;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment