Skip to content
Snippets Groups Projects
Commit 24aabd70 authored by Jordi Inglada's avatar Jordi Inglada
Browse files

ENH: implement prediction for ANN regression

parent 31fb8edd
No related branches found
No related tags found
No related merge requests found
......@@ -186,6 +186,9 @@ protected:
/** Train the machine learning model for regression*/
virtual void TrainRegression();
/** Predict values using the model for regression*/
virtual TargetSampleType PredictRegression(const InputSampleType& input) const;
private:
NeuralNetworkMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
......
......@@ -223,7 +223,7 @@ void NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::TrainRegressi
if ( nbLayers == 0 )
itkExceptionMacro(<< "Number of layers in the Neural Network must be >= 3")
cv::Mat layers = cv::Mat(nbLayers, 1, CV_32SC1);
cv::Mat layers = cv::Mat(nbLayers, 1, CV_32SC1);
for (unsigned int i = 0; i < nbLayers; i++)
{
layers.row(i) = m_LayerSizes[i];
......@@ -251,6 +251,23 @@ void NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::TrainRegressi
m_ANNModel->train(samples, matOutputANN, cv::Mat(), cv::Mat(), params);
}
template<class TInputValue, class TOutputValue>
typename NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::TargetSampleType NeuralNetworkMachineLearningModel<
TInputValue, TOutputValue>::PredictRegression(const InputSampleType & input) const
{
//convert listsample to Mat
cv::Mat sample;
otb::SampleToMat<InputSampleType>(input, sample);
cv::Mat response;
m_ANNModel->predict(sample, response);
TargetSampleType target;
target[0] = response.at<float> (0, 0);
return target;
}
template<class TInputValue, class TOutputValue>
void NeuralNetworkMachineLearningModel<TInputValue, TOutputValue>::Save(const std::string & filename,
const std::string & name)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment