Skip to content
Snippets Groups Projects
Commit 9c68b6e2 authored by Arnaud Jaen's avatar Arnaud Jaen
Browse files

ENH: Add Save and Load method for KNN machine learning model.

parent 05a28a67
No related branches found
No related tags found
No related merge requests found
......@@ -18,16 +18,19 @@
#ifndef __otbKNearestNeighborsMachineLearningModel_txx
#define __otbKNearestNeighborsMachineLearningModel_txx
#include <iostream>
#include <boost/lexical_cast.hpp>
#include "otbKNearestNeighborsMachineLearningModel.h"
#include "otbOpenCVUtils.h"
#include <opencv2/opencv.hpp>
#include "itkMacro.h"
namespace otb
{
template <class TInputValue, class TOutputValue>
KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
template <class TInputValue, class TTargetValue>
KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue>
::KNearestNeighborsMachineLearningModel() :
m_K(10), m_IsRegression(false)
{
......@@ -35,17 +38,17 @@ KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
}
template <class TInputValue, class TOutputValue>
KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
template <class TInputValue, class TTargetValue>
KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue>
::~KNearestNeighborsMachineLearningModel()
{
delete m_KNearestModel;
}
/** Train the machine learning model */
template <class TInputValue, class TOutputValue>
template <class TInputValue, class TTargetValue>
void
KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue>
::Train()
{
//convert listsample to opencv matrix
......@@ -59,10 +62,10 @@ KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
m_KNearestModel->train(samples,labels,cv::Mat(),m_IsRegression, m_K,false);
}
template <class TInputValue, class TOutputValue>
typename KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
template <class TInputValue, class TTargetValue>
typename KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue>
::TargetSampleType
KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue>
::Predict(const InputSampleType & input) const
{
//convert listsample to Mat
......@@ -73,47 +76,141 @@ KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
TargetSampleType target;
target[0] = static_cast<TOutputValue>(result);
target[0] = static_cast<TTargetValue>(result);
return target;
}
template <class TInputValue, class TOutputValue>
template <class TInputValue, class TTargetValue>
void
KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue>
::Save(const std::string & filename, const std::string & name)
{
m_KNearestModel->save(filename.c_str(), name.c_str());
//there is no m_KNearestModel->save(filename.c_str(), name.c_str()).
//We need to save the K parameter and IsRegression flag used and the samples.
std::ofstream ofs(filename.c_str());
//Save K parameter and IsRegression flag.
ofs << "K="<< m_K <<"\n";
ofs << "IsRegression="<<m_IsRegression <<"\n";
//Save the samples. First column is the Label and other columns are the sample data.
typename InputListSampleType::ConstIterator sampleIt = this->GetInputListSample()->Begin();
typename TargetListSampleType::ConstIterator labelIt = this->GetTargetListSample()->Begin();
const unsigned int sampleSize = this->GetInputListSample()->GetMeasurementVectorSize();
for(; sampleIt!=this->GetInputListSample()->End(); ++sampleIt,++labelIt)
{
// Retrieve sample
typename InputListSampleType::MeasurementVectorType sample = sampleIt.GetMeasurementVector();
ofs <<labelIt.GetMeasurementVector()[0];
// Loop on sample size
for(unsigned int i = 0; i < sampleSize; ++i)
{
ofs << " " << sample[i];
}
ofs <<"\n";
}
ofs.close();
}
template <class TInputValue, class TOutputValue>
template <class TInputValue, class TTargetValue>
void
KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue>
::Load(const std::string & filename, const std::string & name)
{
m_KNearestModel->load(filename.c_str(), name.c_str());
//there is no m_KNearestModel->load(filename.c_str(), name.c_str());
std::ifstream ifs(filename.c_str());
if(!ifs)
{
itkExceptionMacro(<<"Could not read file "<<filename);
}
//first line is the K parameter of this algorithm.
std::string line;
std::getline(ifs, line);
std::string::size_type pos = line.find_first_of("=", 0);
std::string::size_type nextpos = line.find_first_of(" \n\r", pos+1);
this->SetK(boost::lexical_cast<int>(line.substr(pos+1, nextpos-pos-1)));
//second line is the IsRegression parameter
std::getline(ifs, line);
pos = line.find_first_of("=", 0);
nextpos = line.find_first_of(" \n\r", pos+1);
this->SetIsRegression(boost::lexical_cast<bool>(line.substr(pos+1, nextpos-pos-1)));
//Clear previous listSample (if any)
typename InputListSampleType::Pointer samples = InputListSampleType::New();
typename TargetListSampleType::Pointer labels = TargetListSampleType::New();
//Read a txt file. First column is the label, other columns are the sample data.
unsigned int nbFeatures = 0;
while (!ifs.eof())
{
std::getline(ifs, line);
if(nbFeatures == 0)
{
nbFeatures = std::count(line.begin(),line.end(),' ');
}
if(line.size()>1)
{
// Parse label
pos = line.find_first_of(" ", 0);
TargetSampleType label;
label[0] = static_cast<TargetValueType>(boost::lexical_cast<unsigned int>(line.substr(0, pos)));
// Parse sample features
InputSampleType sample(nbFeatures);
sample.Fill(0);
unsigned int id = 0;
nextpos = line.find_first_of(" ", pos+1);
while(nextpos != std::string::npos)
{
nextpos = line.find_first_of(" \n\r", pos+1);
std::string subline = line.substr(pos+1, nextpos-pos-1);
//sample[id] = static_cast<InputValueType>(boost::lexical_cast<float>(subline));
sample[id] = atof(subline.c_str());
pos = nextpos;
id++;
}
samples->PushBack(sample);
labels->PushBack(label);
}
}
ifs.close();
this->SetInputListSample(samples);
this->SetTargetListSample(labels);
Train();
}
template <class TInputValue, class TOutputValue>
template <class TInputValue, class TTargetValue>
bool
KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue>
::CanReadFile(const std::string & file)
{
return false;
try
{
this->Load(file);
}
catch(...)
{
return false;
}
return true;
}
template <class TInputValue, class TOutputValue>
template <class TInputValue, class TTargetValue>
bool
KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue>
::CanWriteFile(const std::string & file)
{
return false;
}
template <class TInputValue, class TOutputValue>
template <class TInputValue, class TTargetValue>
void
KNearestNeighborsMachineLearningModel<TInputValue,TOutputValue>
KNearestNeighborsMachineLearningModel<TInputValue,TTargetValue>
::PrintSelf(std::ostream& os, itk::Indent indent) const
{
// Call superclass implementation
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment