From b4f2fd41e2021882a1c5870f2561bcb27f4b5c75 Mon Sep 17 00:00:00 2001
From: Arnaud Jaen <arnaud.jaen@c-s.fr>
Date: Wed, 3 Apr 2013 11:13:04 +0200
Subject: [PATCH] ENH: Add Neural Network machine learning model.

---
 .../OpenCV/otbMachineLearningModelFactory.txx |   2 +
 .../otbNeuralNetworkMachineLearningModel.h    | 210 ++++++++++++++++++
 .../otbNeuralNetworkMachineLearningModel.txx  | 193 ++++++++++++++++
 ...NeuralNetworkMachineLearningModelFactory.h |  72 ++++++
 ...uralNetworkMachineLearningModelFactory.txx |  64 ++++++
 Testing/Code/Learning/CMakeLists.txt          |  19 +-
 Testing/Code/Learning/otbLearningTests6.cxx   |   3 +
 .../otbMachineLearningModelCanRead.cxx        |  28 +++
 .../Learning/otbTrainMachineLearningModel.cxx |  67 ++++++
 9 files changed, 657 insertions(+), 1 deletion(-)
 create mode 100644 Code/UtilitiesAdapters/OpenCV/otbNeuralNetworkMachineLearningModel.h
 create mode 100644 Code/UtilitiesAdapters/OpenCV/otbNeuralNetworkMachineLearningModel.txx
 create mode 100644 Code/UtilitiesAdapters/OpenCV/otbNeuralNetworkMachineLearningModelFactory.h
 create mode 100644 Code/UtilitiesAdapters/OpenCV/otbNeuralNetworkMachineLearningModelFactory.txx

diff --git a/Code/UtilitiesAdapters/OpenCV/otbMachineLearningModelFactory.txx b/Code/UtilitiesAdapters/OpenCV/otbMachineLearningModelFactory.txx
index cee47bf37f..1d8bf19db3 100644
--- a/Code/UtilitiesAdapters/OpenCV/otbMachineLearningModelFactory.txx
+++ b/Code/UtilitiesAdapters/OpenCV/otbMachineLearningModelFactory.txx
@@ -24,6 +24,7 @@
 #include "otbSVMMachineLearningModelFactory.h"
 #include "otbLibSVMMachineLearningModelFactory.h"
 #include "otbBoostMachineLearningModelFactory.h"
+#include "otbNeuralNetworkMachineLearningModelFactory.h"
 
 
 namespace otb
@@ -97,6 +98,7 @@ MachineLearningModelFactory<TInputValue,TOutputValue>
       itk::ObjectFactoryBase::RegisterFactory(LibSVMMachineLearningModelFactory<TInputValue,TOutputValue>::New());
       itk::ObjectFactoryBase::RegisterFactory(SVMMachineLearningModelFactory<TInputValue,TOutputValue>::New());
       itk::ObjectFactoryBase::RegisterFactory(BoostMachineLearningModelFactory<TInputValue,TOutputValue>::New());
+      itk::ObjectFactoryBase::RegisterFactory(NeuralNetworkMachineLearningModelFactory<TInputValue,TOutputValue>::New());
 
       firstTime = false;
       }
diff --git a/Code/UtilitiesAdapters/OpenCV/otbNeuralNetworkMachineLearningModel.h b/Code/UtilitiesAdapters/OpenCV/otbNeuralNetworkMachineLearningModel.h
new file mode 100644
index 0000000000..029e7b32cc
--- /dev/null
+++ b/Code/UtilitiesAdapters/OpenCV/otbNeuralNetworkMachineLearningModel.h
@@ -0,0 +1,210 @@
+/*=========================================================================
+
+  Program:   ORFEO Toolbox
+  Language:  C++
+  Date:      $Date$
+  Version:   $Revision$
+
+
+  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
+  See OTBCopyright.txt for details.
+
+
+     This software is distributed WITHOUT ANY WARRANTY; without even
+     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+     PURPOSE.  See the above copyright notices for more information.
+
+=========================================================================*/
+#ifndef __otbNeuralNetworkMachineLearningModel_h
+#define __otbNeuralNetworkMachineLearningModel_h
+
+#include "itkLightObject.h"
+#include "itkVariableLengthVector.h"
+#include "itkFixedArray.h"
+#include "itkListSample.h"
+#include "otbMachineLearningModel.h"
+
+class CvANN_MLP;
+
+namespace otb
+{
+template <class TInputValue, class TTargetValue>
+class ITK_EXPORT NeuralNetworkMachineLearningModel
+  : public MachineLearningModel <TInputValue, TTargetValue>
+{
+public:
+  /** Standard class typedefs. */
+  typedef NeuralNetworkMachineLearningModel           Self;
+  typedef MachineLearningModel<TInputValue, TTargetValue> Superclass;
+  typedef itk::SmartPointer<Self>                         Pointer;
+  typedef itk::SmartPointer<const Self>                   ConstPointer;
+
+  // Input related typedefs
+  typedef TInputValue                                     InputValueType;
+  typedef itk::VariableLengthVector<InputValueType>       InputSampleType;
+  typedef itk::Statistics::ListSample<InputSampleType>    InputListSampleType;
+
+  // Target related typedefs
+  typedef TTargetValue                                    TargetValueType;
+  typedef itk::FixedArray<TargetValueType,1>              TargetSampleType;
+  typedef itk::Statistics::ListSample<TargetSampleType>   TargetListSampleType;
+
+  /** Run-time type information (and related methods). */
+  itkNewMacro(Self);
+  itkTypeMacro(NeuralNetworkMachineLearningModel, itk::MachineLearningModel);
+
+  /** Setters/Getters to the train method
+   *  2 methods are available:
+   *   - CvANN_MLP_TrainParams::BACKPROP The back-propagation algorithm.
+   *   - CvANN_MLP_TrainParams::RPROP The RPROP algorithm.
+   *  Default is CvANN_MLP_TrainParams::RPROP.
+   *  \see http://docs.opencv.org/modules/ml/doc/neural_networks.html
+   */
+  itkGetMacro(TrainMethod, int);
+  itkSetMacro(TrainMethod, int);
+
+  /**
+   * Set the number of neurons in each layer (including input and output layers).
+   * The number of neuron in the first layer (input layer) must be equal
+   * to the number of samples in the \c InputListSample
+   */
+  void SetLayerSizes (const std::vector<unsigned int> layers);
+
+
+  /** Setters/Getters to the neuron activation function
+   *  3 methods are available:
+   *   - CvANN_MLP::IDENTITY
+   *   - CvANN_MLP::SIGMOID_SYM
+   *   - CvANN_MLP::GAUSSIAN
+   *  Default is CvANN_MLP::SIGMOID_SYM
+   *  \see http://docs.opencv.org/modules/ml/doc/neural_networks.html
+   */
+  itkGetMacro(ActivateFunction, int);
+  itkSetMacro(ActivateFunction, int);
+
+  /** Setters/Getters to the alpha parameter of the activation function
+   *  Default is 1
+   *  \see http://docs.opencv.org/modules/ml/doc/neural_networks.html
+   */
+  itkGetMacro(Alpha, double);
+  itkSetMacro(Alpha, double);
+
+  /** Setters/Getters to the beta parameter of the activation function
+   *  Default is 1
+   *  \see http://docs.opencv.org/modules/ml/doc/neural_networks.html
+   */
+  itkGetMacro(Beta, double);
+  itkSetMacro(Beta, double);
+
+  /** Strength of the weight gradient term in the BACKPROP method.
+   *  The recommended value is about 0.1
+   *  Default is 0.1
+   *  \see http://docs.opencv.org/modules/ml/doc/neural_networks.html
+   */
+  itkGetMacro(BackPropDWScale, double);
+  itkSetMacro(BackPropDWScale, double);
+
+  /** Strength of the momentum term (the difference between weights on the 2 previous iterations).
+   *  This parameter provides some inertia to smooth the random fluctuations of the weights.
+   *  It can vary from 0 (the feature is disabled) to 1 and beyond.
+   *  The value 0.1 or so is good enough
+   *  Default is 0.1
+   *  \see http://docs.opencv.org/modules/ml/doc/neural_networks.html
+   */
+  itkGetMacro(BackPropMomentScale, double);
+  itkSetMacro(BackPropMomentScale, double);
+
+  /** Initial value \Delta_0 of update-values \Delta_{ij} in RPROP method.
+   *  Default is 0.1
+   *  \see http://docs.opencv.org/modules/ml/doc/neural_networks.html
+   */
+  itkGetMacro(RegPropDW0, double);
+  itkSetMacro(RegPropDW0, double);
+
+  /** Update-values lower limit \Delta_{min} in RPROP method.
+   * It must be positive. Default is FLT_EPSILON
+   *  \see http://docs.opencv.org/modules/ml/doc/neural_networks.html
+   */
+  itkGetMacro(RegPropDWMin, double);
+  itkSetMacro(RegPropDWMin, double);
+
+  /** Termination criteria.
+   * It can be CV_TERMCRIT_ITER or CV_TERMCRIT_EPS or CV_TERMCRIT_ITER+CV_TERMCRIT_EPS
+   * default is CV_TERMCRIT_ITER+CV_TERMCRIT_EPS.
+   * \see http://docs.opencv.org/modules/ml/doc/neural_networks.html
+   */
+  itkGetMacro(TermCriteriaType, int);
+  itkSetMacro(TermCriteriaType, int);
+
+  /** Maximum number of iteration used in the Termination criteria.
+   * default is 1000
+   * \see http://docs.opencv.org/modules/ml/doc/neural_networks.html
+   */
+  itkGetMacro(MaxIter, int);
+  itkSetMacro(MaxIter, int);
+
+  /** Epsilon value used in the Termination criteria.
+   * default is 0.01
+   * \see http://docs.opencv.org/modules/ml/doc/neural_networks.html
+   */
+  itkGetMacro(Epsilon, double);
+  itkSetMacro(Epsilon, double);
+
+  /** Train the machine learning model */
+  virtual void Train();
+
+  /** Predict values using the model */
+  virtual TargetSampleType Predict(const InputSampleType & input) const;
+
+  /** Save the model to file */
+  virtual void Save(const std::string & filename, const std::string & name="");
+
+  /** Load the model from file */
+  virtual void Load(const std::string & filename, const std::string & name="");
+
+  /** Determine the file type. Returns true if this ImageIO can read the
+   * file specified. */
+  virtual bool CanReadFile(const std::string &);
+
+  /** Determine the file type. Returns true if this ImageIO can write the
+   * file specified. */
+  virtual bool CanWriteFile(const std::string &);
+
+protected:
+  /** Constructor */
+  NeuralNetworkMachineLearningModel();
+
+  /** Destructor */
+  virtual ~NeuralNetworkMachineLearningModel();
+
+  /** PrintSelf method */
+  void PrintSelf(std::ostream& os, itk::Indent indent) const;
+
+private:
+  NeuralNetworkMachineLearningModel(const Self &); //purposely not implemented
+  void operator =(const Self&); //purposely not implemented
+
+  CvANN_MLP * m_ANNModel;
+  int m_TrainMethod;
+  int m_ActivateFunction;
+  std::vector<unsigned int> m_LayerSizes;
+  double m_Alpha;
+  double m_Beta;
+  double m_BackPropDWScale;
+  double m_BackPropMomentScale;
+  double m_RegPropDW0;
+  double m_RegPropDWMin;
+  int m_TermCriteriaType;
+  int m_MaxIter;
+  double m_Epsilon;
+
+
+
+};
+} // end namespace otb
+
+#ifndef OTB_MANUAL_INSTANTIATION
+#include "otbNeuralNetworkMachineLearningModel.txx"
+#endif
+
+#endif
diff --git a/Code/UtilitiesAdapters/OpenCV/otbNeuralNetworkMachineLearningModel.txx b/Code/UtilitiesAdapters/OpenCV/otbNeuralNetworkMachineLearningModel.txx
new file mode 100644
index 0000000000..4015d36f1a
--- /dev/null
+++ b/Code/UtilitiesAdapters/OpenCV/otbNeuralNetworkMachineLearningModel.txx
@@ -0,0 +1,193 @@
+/*=========================================================================
+
+  Program:   ORFEO Toolbox
+  Language:  C++
+  Date:      $Date$
+  Version:   $Revision$
+
+
+  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
+  See OTBCopyright.txt for details.
+
+
+     This software is distributed WITHOUT ANY WARRANTY; without even
+     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+     PURPOSE.  See the above copyright notices for more information.
+
+=========================================================================*/
+#ifndef __otbNeuralNetworkMachineLearningModel_txx
+#define __otbNeuralNetworkMachineLearningModel_txx
+
+
+#include "otbNeuralNetworkMachineLearningModel.h"
+#include "otbOpenCVUtils.h"
+#include "itkMacro.h" // itkExceptionMacro
+#include <opencv2/opencv.hpp>
+
+namespace otb
+{
+
+template <class TInputValue, class TOutputValue>
+NeuralNetworkMachineLearningModel<TInputValue,TOutputValue>
+::NeuralNetworkMachineLearningModel() :
+ m_TrainMethod(CvANN_MLP_TrainParams::RPROP), m_ActivateFunction(CvANN_MLP::SIGMOID_SYM),
+ m_Alpha(0.), m_Beta(0.), m_BackPropDWScale(0.1), m_BackPropMomentScale(0.1),
+ m_RegPropDW0(0.1), m_RegPropDWMin(FLT_EPSILON), m_TermCriteriaType(CV_TERMCRIT_ITER + CV_TERMCRIT_EPS),
+ m_MaxIter(1000), m_Epsilon(0.01)
+{
+  m_ANNModel = new CvANN_MLP;
+  m_LayerSizes.clear();
+}
+
+
+template <class TInputValue, class TOutputValue>
+NeuralNetworkMachineLearningModel<TInputValue,TOutputValue>
+::~NeuralNetworkMachineLearningModel()
+{
+  delete m_ANNModel;
+}
+
+/** Train the machine learning model */
+template <class TInputValue, class TOutputValue>
+void
+NeuralNetworkMachineLearningModel<TInputValue,TOutputValue>
+::SetLayerSizes(const std::vector<unsigned int> layers)
+{
+  const unsigned int nbLayers = layers.size();
+  if (nbLayers < 3)
+    itkExceptionMacro(<< "Number of layers in the Neural Network must be >= 3")
+
+  m_LayerSizes = layers;
+}
+
+/** Train the machine learning model */
+template <class TInputValue, class TOutputValue>
+void
+NeuralNetworkMachineLearningModel<TInputValue,TOutputValue>
+::Train()
+{
+  //Create the neural network
+  const unsigned int nbLayers = m_LayerSizes.size();
+  if (nbLayers == 0)
+    itkExceptionMacro(<< "Number of layers in the Neural Network must be >= 3")
+
+  cv::Mat layers = cv::Mat(nbLayers,1,CV_32SC1);
+  for (unsigned int i=0; i<nbLayers; i++)
+  {
+    layers.row(i) = m_LayerSizes[i];
+  }
+
+  m_ANNModel->create(layers, m_ActivateFunction, m_Alpha, m_Beta);
+
+  //convert listsample to opencv matrix
+  cv::Mat samples;
+  otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples);
+
+  cv::Mat labels;
+  otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(),labels);
+
+  CvANN_MLP_TrainParams params;
+  params.train_method = m_TrainMethod;
+  params.bp_dw_scale = m_BackPropDWScale;
+  params.bp_moment_scale = m_BackPropMomentScale;
+  params.rp_dw0 = m_RegPropDW0;
+  params.rp_dw_min = m_RegPropDWMin;
+  CvTermCriteria term_crit   = cvTermCriteria(m_TermCriteriaType, m_MaxIter, m_Epsilon);
+  params.term_crit = term_crit;
+
+  //train the Neural network model
+  m_ANNModel->train(samples,labels,cv::Mat(),cv::Mat(),params);
+}
+
+template <class TInputValue, class TOutputValue>
+typename NeuralNetworkMachineLearningModel<TInputValue,TOutputValue>
+::TargetSampleType
+NeuralNetworkMachineLearningModel<TInputValue,TOutputValue>
+::Predict(const InputSampleType & input) const
+{
+  //convert listsample to Mat
+  cv::Mat sample;
+
+  otb::SampleToMat<InputSampleType>(input,sample);
+
+  cv::Mat response(1, 1, CV_32FC1);
+  m_ANNModel->predict(sample,response);
+
+  TargetSampleType target;
+
+  target[0] = static_cast<TOutputValue>(response.at<float>(0,0));
+
+  return target;
+}
+
+template <class TInputValue, class TOutputValue>
+void
+NeuralNetworkMachineLearningModel<TInputValue,TOutputValue>
+::Save(const std::string & filename, const std::string & name)
+{
+  if (name == "")
+    m_ANNModel->save(filename.c_str(), 0);
+  else
+    m_ANNModel->save(filename.c_str(), name.c_str());
+}
+
+template <class TInputValue, class TOutputValue>
+void
+NeuralNetworkMachineLearningModel<TInputValue,TOutputValue>
+::Load(const std::string & filename, const std::string & name)
+{
+  if (name == "")
+    m_ANNModel->load(filename.c_str(), 0);
+  else
+    m_ANNModel->load(filename.c_str(), name.c_str());
+}
+
+template <class TInputValue, class TOutputValue>
+bool
+NeuralNetworkMachineLearningModel<TInputValue,TOutputValue>
+::CanReadFile(const std::string & file)
+{
+  std::ifstream ifs;
+  ifs.open(file.c_str());
+
+  if(!ifs)
+  {
+    std::cerr<<"Could not read file "<<file<<std::endl;
+    return false;
+  }
+
+  while (!ifs.eof())
+  {
+    std::string line;
+    std::getline(ifs, line);
+
+    if (line.find(CV_TYPE_NAME_ML_ANN_MLP) != std::string::npos)
+    {
+       std::cout<<"Reading a "<<CV_TYPE_NAME_ML_ANN_MLP<<" model !!!"<<std::endl;
+       return true;
+    }
+  }
+  ifs.close();
+  return false;
+}
+
+template <class TInputValue, class TOutputValue>
+bool
+NeuralNetworkMachineLearningModel<TInputValue,TOutputValue>
+::CanWriteFile(const std::string & file)
+{
+  return false;
+}
+
+template <class TInputValue, class TOutputValue>
+void
+NeuralNetworkMachineLearningModel<TInputValue,TOutputValue>
+::PrintSelf(std::ostream& os, itk::Indent indent) const
+{
+  // Call superclass implementation
+  Superclass::PrintSelf(os,indent);
+}
+
+} //end namespace otb
+
+#endif
diff --git a/Code/UtilitiesAdapters/OpenCV/otbNeuralNetworkMachineLearningModelFactory.h b/Code/UtilitiesAdapters/OpenCV/otbNeuralNetworkMachineLearningModelFactory.h
new file mode 100644
index 0000000000..494149756c
--- /dev/null
+++ b/Code/UtilitiesAdapters/OpenCV/otbNeuralNetworkMachineLearningModelFactory.h
@@ -0,0 +1,72 @@
+/*=========================================================================
+
+  Program:   ORFEO Toolbox
+  Language:  C++
+  Date:      $Date$
+  Version:   $Revision$
+
+
+  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
+  See OTBCopyright.txt for details.
+
+
+     This software is distributed WITHOUT ANY WARRANTY; without even
+     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+     PURPOSE.  See the above copyright notices for more information.
+
+=========================================================================*/
+#ifndef __otbNeuralNetworkMachineLearningModelFactory_h
+#define __otbNeuralNetworkMachineLearningModelFactory_h
+
+#include "itkObjectFactoryBase.h"
+#include "itkImageIOBase.h"
+
+namespace otb
+{
+/** \class NeuralNetworkMachineLearningModelFactory
+ * \brief Creation d'un instance d'un objet SVMMachineLearningModel utilisant les object factory.
+ */
+template <class TInputValue, class TTargetValue>
+class ITK_EXPORT NeuralNetworkMachineLearningModelFactory : public itk::ObjectFactoryBase
+{
+public:
+  /** Standard class typedefs. */
+  typedef NeuralNetworkMachineLearningModelFactory             Self;
+  typedef itk::ObjectFactoryBase        Superclass;
+  typedef itk::SmartPointer<Self>       Pointer;
+  typedef itk::SmartPointer<const Self> ConstPointer;
+
+  /** Class methods used to interface with the registered factories. */
+  virtual const char* GetITKSourceVersion(void) const;
+  virtual const char* GetDescription(void) const;
+
+  /** Method for class instantiation. */
+  itkFactorylessNewMacro(Self);
+
+  /** Run-time type information (and related methods). */
+  itkTypeMacro(NeuralNetworkMachineLearningModelFactory, itk::ObjectFactoryBase);
+
+  /** Register one factory of this type  */
+  static void RegisterOneFactory(void)
+  {
+    NeuralNetworkMachineLearningModelFactory::Pointer Factory = NeuralNetworkMachineLearningModelFactory::New();
+    itk::ObjectFactoryBase::RegisterFactory(Factory);
+  }
+
+protected:
+  NeuralNetworkMachineLearningModelFactory();
+  virtual ~NeuralNetworkMachineLearningModelFactory();
+
+private:
+  NeuralNetworkMachineLearningModelFactory(const Self &); //purposely not implemented
+  void operator =(const Self&); //purposely not implemented
+
+};
+
+} // end namespace otb
+
+#ifndef OTB_MANUAL_INSTANTIATION
+#include "otbNeuralNetworkMachineLearningModelFactory.txx"
+#endif
+
+#endif
diff --git a/Code/UtilitiesAdapters/OpenCV/otbNeuralNetworkMachineLearningModelFactory.txx b/Code/UtilitiesAdapters/OpenCV/otbNeuralNetworkMachineLearningModelFactory.txx
new file mode 100644
index 0000000000..25a387d97d
--- /dev/null
+++ b/Code/UtilitiesAdapters/OpenCV/otbNeuralNetworkMachineLearningModelFactory.txx
@@ -0,0 +1,64 @@
+/*=========================================================================
+
+  Program:   ORFEO Toolbox
+  Language:  C++
+  Date:      $Date$
+  Version:   $Revision$
+
+
+  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
+  See OTBCopyright.txt for details.
+
+
+     This software is distributed WITHOUT ANY WARRANTY; without even
+     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+     PURPOSE.  See the above copyright notices for more information.
+
+=========================================================================*/
+#include "otbNeuralNetworkMachineLearningModelFactory.h"
+
+#include "itkCreateObjectFunction.h"
+#include "otbNeuralNetworkMachineLearningModel.h"
+#include "itkVersion.h"
+
+namespace otb
+{
+
+template <class TInputValue, class TOutputValue>
+NeuralNetworkMachineLearningModelFactory<TInputValue,TOutputValue>
+::NeuralNetworkMachineLearningModelFactory()
+{
+
+  static std::string classOverride = std::string("otbMachineLearningModel");
+  static std::string subclass = std::string("otbNeuralNetworkMachineLearningModel");
+
+  this->RegisterOverride(classOverride.c_str(),
+                         subclass.c_str(),
+                         "Artificial Neural Network ML Model",
+                         1,
+                         itk::CreateObjectFunction<NeuralNetworkMachineLearningModel<TInputValue,TOutputValue> >::New());
+}
+
+template <class TInputValue, class TOutputValue>
+NeuralNetworkMachineLearningModelFactory<TInputValue,TOutputValue>
+::~NeuralNetworkMachineLearningModelFactory()
+{
+}
+
+template <class TInputValue, class TOutputValue>
+const char*
+NeuralNetworkMachineLearningModelFactory<TInputValue,TOutputValue>
+::GetITKSourceVersion(void) const
+{
+  return ITK_SOURCE_VERSION;
+}
+
+template <class TInputValue, class TOutputValue>
+const char*
+NeuralNetworkMachineLearningModelFactory<TInputValue,TOutputValue>
+::GetDescription() const
+{
+  return "Artificial Neural Network machine learning model factory";
+}
+
+} // end namespace otb
diff --git a/Testing/Code/Learning/CMakeLists.txt b/Testing/Code/Learning/CMakeLists.txt
index 80e295b52e..686221a969 100644
--- a/Testing/Code/Learning/CMakeLists.txt
+++ b/Testing/Code/Learning/CMakeLists.txt
@@ -728,7 +728,6 @@ IF(OTB_USE_OPENCV)
      ADD_TEST(leTuBoostMachineLearningModelNew ${LEARNING_TESTS6}
            otbBoostMachineLearningModelNew)
 
-
 	ADD_TEST(leTvBoostMachineLearningModel ${LEARNING_TESTS6}
          #--compare-ascii ${NOTOL}
          #${BASELINE_FILES}/BoostLearningModel.txt
@@ -737,6 +736,19 @@ IF(OTB_USE_OPENCV)
          ${INPUTDATA}/letter.scale
          ${TEMP}/BoostMachineLearningModel.txt
          )
+         
+    ADD_TEST(leTuANNMachineLearningModelNew ${LEARNING_TESTS6}
+         otbANNMachineLearningModelNew)
+
+
+	ADD_TEST(leTvANNMachineLearningModel ${LEARNING_TESTS6}
+         #--compare-ascii ${NOTOL}
+         #${BASELINE_FILES}/ANNMachineLearningModel.txt
+         #${TEMP}/ANNMachineLearningModel.txt
+         otbANNMachineLearningModel
+         ${INPUTDATA}/letter.scale
+         ${TEMP}/ANNMachineLearningModel.txt
+         )
      
      ADD_TEST(leTuImageClassificationFilterNew ${LEARNING_TESTS6}
        	 otbImageClassificationFilterNew)
@@ -782,6 +794,11 @@ IF(OTB_USE_OPENCV)
        	 otbBoostMachineLearningModelCanRead
        	 ${INPUTDATA}/boost_model.txt
        	 )
+       	 
+     ADD_TEST(leTuANNMachineLearningModelCanRead ${LEARNING_TESTS6}
+       	 otbNeuralNetworkMachineLearningModelCanRead
+       	 ${INPUTDATA}/NeuralNetworkMachineLearningModel.txt
+       	 )
 
      
 ENDIF(OTB_USE_OPENCV)
diff --git a/Testing/Code/Learning/otbLearningTests6.cxx b/Testing/Code/Learning/otbLearningTests6.cxx
index 9d597f25ca..749102badb 100644
--- a/Testing/Code/Learning/otbLearningTests6.cxx
+++ b/Testing/Code/Learning/otbLearningTests6.cxx
@@ -34,10 +34,13 @@ void RegisterTests()
   REGISTER_TEST(otbRandomForestsMachineLearningModel);
   REGISTER_TEST(otbBoostMachineLearningModelNew);
   REGISTER_TEST(otbBoostMachineLearningModel);
+  REGISTER_TEST(otbANNMachineLearningModelNew);
+  REGISTER_TEST(otbANNMachineLearningModel);
   REGISTER_TEST(otbImageClassificationFilterNew);
   REGISTER_TEST(otbImageClassificationFilter);
   REGISTER_TEST(otbLibSVMMachineLearningModelCanRead);
   REGISTER_TEST(otbSVMMachineLearningModelCanRead);
   REGISTER_TEST(otbRandomForestsMachineLearningModelCanRead);
   REGISTER_TEST(otbBoostMachineLearningModelCanRead);
+  REGISTER_TEST(otbNeuralNetworkMachineLearningModelCanRead);
 }
diff --git a/Testing/Code/Learning/otbMachineLearningModelCanRead.cxx b/Testing/Code/Learning/otbMachineLearningModelCanRead.cxx
index 88b6537a34..5cb62adee1 100644
--- a/Testing/Code/Learning/otbMachineLearningModelCanRead.cxx
+++ b/Testing/Code/Learning/otbMachineLearningModelCanRead.cxx
@@ -21,6 +21,7 @@
 #include "otbSVMMachineLearningModel.h"
 #include "otbRandomForestsMachineLearningModel.h"
 #include "otbBoostMachineLearningModel.h"
+#include "otbNeuralNetworkMachineLearningModel.h"
 #include <iostream>
 
 typedef otb::MachineLearningModel<float,short>         MachineLearningModelType;
@@ -135,3 +136,30 @@ int otbBoostMachineLearningModelCanRead(int argc, char* argv[])
   return EXIT_SUCCESS;
 }
 
+int otbNeuralNetworkMachineLearningModelCanRead(int argc, char* argv[])
+{
+  if (argc != 2)
+    {
+    std::cerr << "Usage: " << argv[0]
+              << "<model>" << std::endl;
+    std::cerr << "Called here with " << argc << " arguments\n";
+    for (int i = 1; i < argc; ++i)
+      {
+      std::cerr << " - " << argv[i] << "\n";
+      }
+    return EXIT_FAILURE;
+    }
+  std::string filename(argv[1]);
+  typedef otb::NeuralNetworkMachineLearningModel<InputValueType, TargetValueType> ANNType;
+  ANNType::Pointer classifier = ANNType::New();
+  bool lCanRead = classifier->CanReadFile(filename);
+  if (lCanRead == false)
+    {
+    std::cerr << "Erreur otb::NeuralNetworkMachineLearningModel : impossible to open the file " << filename << "." << std::endl;
+    return EXIT_FAILURE;
+    }
+
+  return EXIT_SUCCESS;
+}
+
+
diff --git a/Testing/Code/Learning/otbTrainMachineLearningModel.cxx b/Testing/Code/Learning/otbTrainMachineLearningModel.cxx
index e19d0e6e8b..ba2e3f2968 100644
--- a/Testing/Code/Learning/otbTrainMachineLearningModel.cxx
+++ b/Testing/Code/Learning/otbTrainMachineLearningModel.cxx
@@ -26,6 +26,7 @@
 #include "otbKNearestNeighborsMachineLearningModel.h"
 #include "otbRandomForestsMachineLearningModel.h"
 #include "otbBoostMachineLearningModel.h"
+#include "otbNeuralNetworkMachineLearningModel.h"
 
 #include "otbConfusionMatrixCalculator.h"
 
@@ -412,5 +413,71 @@ int otbBoostMachineLearningModel(int argc, char * argv[])
   return EXIT_SUCCESS;
 }
 
+int otbANNMachineLearningModelNew(int argc, char * argv[])
+{
+  typedef otb::NeuralNetworkMachineLearningModel<InputValueType, TargetValueType> ANNType;
+  ANNType::Pointer classifier = ANNType::New();
+  return EXIT_SUCCESS;
+}
+
+int otbANNMachineLearningModel(int argc, char * argv[])
+{
+  if (argc != 3)
+    {
+      std::cout<<"Wrong number of arguments "<<std::endl;
+      std::cout<<"Usage : sample file, output file "<<std::endl;
+      return EXIT_FAILURE;
+    }
+
+
+  typedef otb::NeuralNetworkMachineLearningModel<InputValueType, TargetValueType> ANNType;
+  InputListSampleType::Pointer samples = InputListSampleType::New();
+  TargetListSampleType::Pointer labels = TargetListSampleType::New();
+  TargetListSampleType::Pointer predicted = TargetListSampleType::New();
+
+  if (!ReadDataFile(argv[1], samples, labels))
+    {
+    std::cout << "Failed to read samples file " << argv[1] << std::endl;
+    return EXIT_FAILURE;
+    }
+
+  std::cout<<"number of samples = "<<samples->Size()<<std::endl;
+
+  std::vector<unsigned int> layerSizes;
+  layerSizes.push_back(16);
+  layerSizes.push_back(25);
+  layerSizes.push_back(35);
+  layerSizes.push_back(45);
+  layerSizes.push_back(1);
+
+  ANNType::Pointer classifier = ANNType::New();
+  classifier->SetInputListSample(samples);
+  classifier->SetTargetListSample(labels);
+  classifier->SetLayerSizes(layerSizes);
+  classifier->SetTrainMethod(CvANN_MLP_TrainParams::BACKPROP);
+  classifier->SetBackPropDWScale(0.005);
+  classifier->SetBackPropMomentScale(0.005);
+  classifier->Train();
+
+  classifier->SetTargetListSample(predicted);
+  classifier->PredictAll();
+
+  ConfusionMatrixCalculatorType::Pointer cmCalculator = ConfusionMatrixCalculatorType::New();
+
+  cmCalculator->SetProducedLabels(predicted);
+  cmCalculator->SetReferenceLabels(labels);
+  cmCalculator->Compute();
+
+  std::cout << "Confusion matrix: " << std::endl;
+  std::cout << cmCalculator->GetConfusionMatrix() << std::endl;
+  std::cout << "Kappa: " << cmCalculator->GetKappaIndex() << std::endl;
+  std::cout << "Overall Accuracy: " << cmCalculator->GetOverallAccuracy() << std::endl;
+
+  classifier->Save(argv[2]);
+
+  return EXIT_SUCCESS;
+}
+
+
 
 
-- 
GitLab