// otbRandomForestsMachineLearningModel.txx -- authored by Jordi Inglada
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbRandomForestsMachineLearningModel_txx
#define __otbRandomForestsMachineLearningModel_txx
#include <fstream>
#include "otbRandomForestsMachineLearningModel.h"
#include "otbOpenCVUtils.h"
#include <opencv2/opencv.hpp> // opencv general include file
namespace otb
{
/**
 * Default constructor.
 *
 * Allocates the underlying OpenCV CvRTrees object and gives every training
 * parameter its default value. The parameters are only stored here; they are
 * handed to CvRTParams when Train() is called.
 */
template <class TInputValue, class TOutputValue>
RandomForestsMachineLearningModel<TInputValue,TOutputValue>
::RandomForestsMachineLearningModel() :
m_RFModel (new CvRTrees), // heap-allocated; deleted in the destructor
m_MaxDepth(5), // max depth
m_MinSampleCount(10), // min sample count
m_RegressionAccuracy(0), // regression accuracy
m_ComputeSurrogateSplit(false), // compute surrogate split
m_MaxNumberOfCategories(10), // max number of categories
m_CalculateVariableImportance(false), // calculate variable importance
m_MaxNumberOfVariables(0), // number of variables randomly selected at each node
m_MaxNumberOfTrees(100), // max number of trees in the forest
m_ForestAccuracy(0.01), // forest accuracy
m_TerminationCriteria(CV_TERMCRIT_ITER | CV_TERMCRIT_EPS), // termination criteria flags passed to CvRTParams
m_RegressionMode(false) // false: Train() marks the response variable as categorical
{
}
/** Destructor: releases the CvRTrees object allocated by the constructor. */
template <class TInputValue, class TOutputValue>
RandomForestsMachineLearningModel<TInputValue,TOutputValue>
::~RandomForestsMachineLearningModel()
{
  delete this->m_RFModel;
}
/**
 * Training error of the forest.
 *
 * \return the value reported by CvRTrees::get_train_error()
 */
template <class TInputValue, class TOutputValue>
float
RandomForestsMachineLearningModel<TInputValue,TOutputValue>
::GetTrainError()
{
  const float trainError = this->m_RFModel->get_train_error();
  return trainError;
}
/**
 * Train the machine learning model.
 *
 * The input and target list samples are converted to OpenCV matrices, the
 * forest parameters are assembled from the member settings, and
 * CvRTrees::train() is run with one sample per row.
 */
template <class TInputValue, class TOutputValue>
void
RandomForestsMachineLearningModel<TInputValue,TOutputValue>
::Train()
{
  // Convert both list samples to OpenCV matrices
  cv::Mat samples;
  otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);

  cv::Mat labels;
  otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(), labels);

  // A null pointer is passed when no priors have been set
  float * priors = 0;
  if (!m_Priors.empty())
    {
    priors = &m_Priors.front();
    }

  // Random forests parameters
  CvRTParams params(m_MaxDepth,                    // max depth
                    m_MinSampleCount,              // min sample count
                    m_RegressionAccuracy,          // regression accuracy
                    m_ComputeSurrogateSplit,       // compute surrogate split
                    m_MaxNumberOfCategories,       // max number of categories (sub-optimal algorithm above this)
                    priors,                        // the array of priors
                    m_CalculateVariableImportance, // calculate variable importance
                    m_MaxNumberOfVariables,        // variables randomly selected at each node to find the best split(s)
                    m_MaxNumberOfTrees,            // max number of trees in the forest
                    m_ForestAccuracy,              // forest accuracy
                    m_TerminationCriteria);        // termination criteria

  // One type entry per input variable, followed by one entry for the response
  const int responseIndex =
    static_cast<int>(this->GetInputListSample()->GetMeasurementVectorSize());
  cv::Mat var_type(responseIndex + 1, 1, CV_8U);

  // All inputs are numerical; the response is numerical only in regression mode
  var_type.setTo(cv::Scalar(CV_VAR_NUMERICAL));
  var_type.at<uchar>(responseIndex, 0) =
    m_RegressionMode ? CV_VAR_NUMERICAL : CV_VAR_CATEGORICAL;

  // Train the RT model
  m_RFModel->train(samples, CV_ROW_SAMPLE, labels,
                   cv::Mat(), cv::Mat(), var_type, cv::Mat(), params);
}
/**
 * Predict the output for a single input sample.
 *
 * \param value the input sample
 * \return the value returned by CvRTrees::predict(), cast to TOutputValue
 */
template <class TInputValue, class TOutputValue>
typename RandomForestsMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
RandomForestsMachineLearningModel<TInputValue,TOutputValue>
::Predict(const InputSampleType & value) const
{
  // Convert the sample to an OpenCV matrix
  cv::Mat cvSample;
  otb::SampleToMat<InputSampleType>(value, cvSample);

  // Run the forest on it
  const double prediction = m_RFModel->predict(cvSample);

  TargetSampleType output;
  output[0] = static_cast<TOutputValue>(prediction);
  return output[0];
}
/**
 * Save the model to a file.
 *
 * \param filename path of the file to write
 * \param name     node name forwarded to CvRTrees::save(); when empty, a null
 *                 pointer is forwarded instead
 */
template <class TInputValue, class TOutputValue>
void
RandomForestsMachineLearningModel<TInputValue,TOutputValue>
::Save(const std::string & filename, const std::string & name)
{
  const char * nodeName = name.empty() ? 0 : name.c_str();
  m_RFModel->save(filename.c_str(), nodeName);
}
/**
 * Load the model from a file.
 *
 * \param filename path of the file to read
 * \param name     node name forwarded to CvRTrees::load(); when empty, a null
 *                 pointer is forwarded instead
 */
template <class TInputValue, class TOutputValue>
void
RandomForestsMachineLearningModel<TInputValue,TOutputValue>
::Load(const std::string & filename, const std::string & name)
{
  const char * nodeName = name.empty() ? 0 : name.c_str();
  m_RFModel->load(filename.c_str(), nodeName);
}
/**
 * Check whether a file looks like a saved random forests model.
 *
 * The file is scanned line by line for the OpenCV type tag
 * CV_TYPE_NAME_ML_RTREES.
 *
 * \param file path of the candidate model file
 * \return true as soon as a line contains the tag; false if the file cannot
 *         be opened (a message is then written to std::cerr) or if no line
 *         contains the tag
 */
template <class TInputValue, class TOutputValue>
bool
RandomForestsMachineLearningModel<TInputValue,TOutputValue>
::CanReadFile(const std::string & file)
{
  std::ifstream ifs(file.c_str());
  if(!ifs)
    {
    std::cerr<<"Could not read file "<<file<<std::endl;
    return false;
    }

  // Loop on the result of getline() instead of testing eof(): a stream that
  // fails without reaching end-of-file (e.g. a path that opens but cannot be
  // read) never sets eofbit, which made the previous loop spin forever.
  std::string line;
  while (std::getline(ifs, line))
    {
    if (line.find(CV_TYPE_NAME_ML_RTREES) != std::string::npos)
      {
      return true;
      }
    }

  // The stream is closed by its destructor on every return path
  return false;
}
/**
 * Tell whether this model can write the given file.
 *
 * The file name is not examined: this implementation always returns false.
 */
template <class TInputValue, class TOutputValue>
bool
RandomForestsMachineLearningModel<TInputValue,TOutputValue>
::CanWriteFile(const std::string & itkNotUsed(file))
{
return false;
}
/**
 * Variable importance of the trained forest.
 *
 * Copies the matrix returned by CvRTrees::getVarImportance(), read as float
 * elements, into a VariableImportanceMatrixType of the same dimensions.
 *
 * \return the variable importance matrix
 */
template <class TInputValue, class TOutputValue>
typename RandomForestsMachineLearningModel<TInputValue,TOutputValue>
::VariableImportanceMatrixType
RandomForestsMachineLearningModel<TInputValue,TOutputValue>
::GetVariableImportance()
{
  cv::Mat cvMat = m_RFModel->getVarImportance();

  // cv::Mat stores its dimensions as signed int: convert them once so that
  // the loop counters and their bounds share the same unsigned type.
  const unsigned int nbRows = static_cast<unsigned int>(cvMat.rows);
  const unsigned int nbCols = static_cast<unsigned int>(cvMat.cols);

  VariableImportanceMatrixType itkMat(nbRows, nbCols);
  for (unsigned int i = 0; i < nbRows; ++i)
    {
    for (unsigned int j = 0; j < nbCols; ++j)
      {
      itkMat(i, j) = cvMat.at<float>(i, j);
      }
    }
  return itkMat;
}
/**
 * Print the object to a stream.
 *
 * Only the superclass implementation is called; no member of this class is
 * printed here.
 *
 * \param os     output stream
 * \param indent indentation forwarded to the superclass
 */
template <class TInputValue, class TOutputValue>
void
RandomForestsMachineLearningModel<TInputValue,TOutputValue>
::PrintSelf(std::ostream& os, itk::Indent indent) const
{
// Call superclass implementation
Superclass::PrintSelf(os,indent);
}
} //end namespace otb
#endif