Skip to content
Snippets Groups Projects
Commit 1fec8ea2 authored by Jordi Inglada's avatar Jordi Inglada
Browse files

ENH: factorise duplicate code for classification and regression

parent 178c955a
No related branches found
No related tags found
No related merge requests found
......@@ -193,6 +193,10 @@ private:
NeuralNetworkMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
void CreateNetwork();
CvANN_MLP_TrainParams SetNetworkParameters();
void SetupNetworkAndTrain(cv::Mat& labels);
CvANN_MLP * m_ANNModel;
int m_TrainMethod;
int m_ActivateFunction;
......
......@@ -130,9 +130,8 @@ void NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::LabelsToMat(c
}
}
/** Train the machine learning model for classification*/
template<class TInputValue, class TOutputValue>
void NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::TrainClassification()
void NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::CreateNetwork()
{
//Create the neural network
const unsigned int nbLayers = m_LayerSizes.size();
......@@ -147,14 +146,11 @@ void NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::TrainClassifi
}
m_ANNModel->create(layers, m_ActivateFunction, m_Alpha, m_Beta);
}
//convert listsample to opencv matrix
cv::Mat samples;
otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);
cv::Mat matOutputANN;
LabelsToMat(this->GetTargetListSample(), matOutputANN);
template<class TInputValue, class TOutputValue>
CvANN_MLP_TrainParams NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::SetNetworkParameters()
{
CvANN_MLP_TrainParams params;
params.train_method = m_TrainMethod;
params.bp_dw_scale = m_BackPropDWScale;
......@@ -163,9 +159,29 @@ void NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::TrainClassifi
params.rp_dw_min = m_RegPropDWMin;
CvTermCriteria term_crit = cvTermCriteria(m_TermCriteriaType, m_MaxIter, m_Epsilon);
params.term_crit = term_crit;
return params;
}
template<class TInputValue, class TOutputValue>
void NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::SetupNetworkAndTrain(cv::Mat& labels)
{
//convert listsample to opencv matrix
cv::Mat samples;
otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);
this->CreateNetwork();
CvANN_MLP_TrainParams params = this->SetNetworkParameters();
//train the Neural network model
m_ANNModel->train(samples, matOutputANN, cv::Mat(), cv::Mat(), params);
m_ANNModel->train(samples, labels, cv::Mat(), cv::Mat(), params);
}
/** Train the machine learning model for classification*/
template<class TInputValue, class TOutputValue>
void NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::TrainClassification()
{
//Transform the targets into a matrix of labels
cv::Mat matOutputANN;
LabelsToMat(this->GetTargetListSample(), matOutputANN);
this->SetupNetworkAndTrain(matOutputANN);
}
template<class TInputValue, class TOutputValue>
......@@ -217,38 +233,10 @@ typename NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::TargetSam
template<class TInputValue, class TOutputValue>
void NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::TrainRegression()
{
//Create the neural network
const unsigned int nbLayers = m_LayerSizes.size();
if ( nbLayers == 0 )
itkExceptionMacro(<< "Number of layers in the Neural Network must be >= 3")
cv::Mat layers = cv::Mat(nbLayers, 1, CV_32SC1);
for (unsigned int i = 0; i < nbLayers; i++)
{
layers.row(i) = m_LayerSizes[i];
}
m_ANNModel->create(layers, m_ActivateFunction, m_Alpha, m_Beta);
//convert listsample to opencv matrix
cv::Mat samples;
otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);
//Transform the targets into an OpenCV matrix
cv::Mat matOutputANN;
otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(), matOutputANN);
CvANN_MLP_TrainParams params;
params.train_method = m_TrainMethod;
params.bp_dw_scale = m_BackPropDWScale;
params.bp_moment_scale = m_BackPropMomentScale;
params.rp_dw0 = m_RegPropDW0;
params.rp_dw_min = m_RegPropDWMin;
CvTermCriteria term_crit = cvTermCriteria(m_TermCriteriaType, m_MaxIter, m_Epsilon);
params.term_crit = term_crit;
//train the Neural network model
m_ANNModel->train(samples, matOutputANN, cv::Mat(), cv::Mat(), params);
this->SetupNetworkAndTrain(matOutputANN);
}
template<class TInputValue, class TOutputValue>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment