Skip to content
Snippets Groups Projects
Commit 31fb8edd authored by Jordi Inglada's avatar Jordi Inglada
Browse files

ENH: implement TrainRegression for ANN

parent 858248d0
No related branches found
No related tags found
No related merge requests found
......@@ -178,11 +178,14 @@ protected:
/** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const;
/** Train the machine learning model */
/** Train the machine learning model for classidication*/
virtual void TrainClassification();
/** Predict values using the model */
/** Predict values using the model in classification mode*/
virtual TargetSampleType PredictClassification(const InputSampleType& input, ConfidenceValueType *quality=NULL) const;
/** Train the machine learning model for regression*/
virtual void TrainRegression();
private:
NeuralNetworkMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
......
......@@ -130,7 +130,7 @@ void NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::LabelsToMat(c
}
}
/** Train the machine learning model */
/** Train the machine learning model for classification*/
template<class TInputValue, class TOutputValue>
void NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::TrainClassification()
{
......@@ -213,6 +213,44 @@ typename NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::TargetSam
return target;
}
/** Train the machine learning model for regression*/
template<class TInputValue, class TOutputValue>
void NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::TrainRegression()
{
//Create the neural network
const unsigned int nbLayers = m_LayerSizes.size();
if ( nbLayers == 0 )
itkExceptionMacro(<< "Number of layers in the Neural Network must be >= 3")
cv::Mat layers = cv::Mat(nbLayers, 1, CV_32SC1);
for (unsigned int i = 0; i < nbLayers; i++)
{
layers.row(i) = m_LayerSizes[i];
}
m_ANNModel->create(layers, m_ActivateFunction, m_Alpha, m_Beta);
//convert listsample to opencv matrix
cv::Mat samples;
otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);
cv::Mat matOutputANN;
otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(), matOutputANN);
CvANN_MLP_TrainParams params;
params.train_method = m_TrainMethod;
params.bp_dw_scale = m_BackPropDWScale;
params.bp_moment_scale = m_BackPropMomentScale;
params.rp_dw0 = m_RegPropDW0;
params.rp_dw_min = m_RegPropDWMin;
CvTermCriteria term_crit = cvTermCriteria(m_TermCriteriaType, m_MaxIter, m_Epsilon);
params.term_crit = term_crit;
//train the Neural network model
m_ANNModel->train(samples, matOutputANN, cv::Mat(), cv::Mat(), params);
}
template<class TInputValue, class TOutputValue>
void NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::Save(const std::string & filename,
const std::string & name)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment