From fa0bf983e6347752545a1fbed3e6b1b99703d206 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?C=C3=A9dric=20Traizet?= <traizetc@cesbio.cnes.fr>
Date: Thu, 11 May 2017 14:36:54 +0200
Subject: [PATCH] dr application working (monoband output) for autoencoders and
 tiedautoencoders

---
 app/cbDimensionalityReduction.cxx   |  2 +-
 include/AutoencoderModelFactory.h   | 31 +++++++++++++++++++++--------
 include/AutoencoderModelFactory.txx | 19 +++++++++---------
 3 files changed, 33 insertions(+), 19 deletions(-)

diff --git a/app/cbDimensionalityReduction.cxx b/app/cbDimensionalityReduction.cxx
index 08769213b9..d1e4861be1 100644
--- a/app/cbDimensionalityReduction.cxx
+++ b/app/cbDimensionalityReduction.cxx
@@ -194,7 +194,7 @@ private:
     otbAppLogINFO("Loading model");
     m_Model = MachineLearningModelFactoryType::CreateMachineLearningModel(GetParameterString("model"),
                                                                           MachineLearningModelFactoryType::ReadMode);
-    otbAppLogINFO("yo");
+
     if (m_Model.IsNull())
       {
       otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type");
diff --git a/include/AutoencoderModelFactory.h b/include/AutoencoderModelFactory.h
index 51847dfebe..0b8f538272 100644
--- a/include/AutoencoderModelFactory.h
+++ b/include/AutoencoderModelFactory.h
@@ -2,18 +2,20 @@
 #define AutoencoderModelFactory_h
 
 
+#include <shark/Models/TiedAutoencoder.h>
+#include <shark/Models/Autoencoder.h>
 #include "itkObjectFactoryBase.h"
 #include "itkImageIOBase.h"
 
 namespace otb
 {
 	
-template <class TInputValue, class TTargetValue>
-class ITK_EXPORT AutoencoderModelFactory : public itk::ObjectFactoryBase
+template <class TInputValue, class TTargetValue, class AutoencoderType>
+class ITK_EXPORT AutoencoderModelFactoryBase : public itk::ObjectFactoryBase
 {
 public:
   /** Standard class typedefs. */
-  typedef AutoencoderModelFactory             Self;
+  typedef AutoencoderModelFactoryBase   Self;
   typedef itk::ObjectFactoryBase        Superclass;
   typedef itk::SmartPointer<Self>       Pointer;
   typedef itk::SmartPointer<const Self> ConstPointer;
@@ -26,26 +28,39 @@ public:
   itkFactorylessNewMacro(Self);
 
   /** Run-time type information (and related methods). */
-  itkTypeMacro(AutoencoderModelFactory, itk::ObjectFactoryBase);
+  itkTypeMacro(AutoencoderModelFactoryBase, itk::ObjectFactoryBase);
 
   /** Register one factory of this type  */
   static void RegisterOneFactory(void)
   {
-    Pointer AEFactory = AutoencoderModelFactory::New();
+    Pointer AEFactory = AutoencoderModelFactoryBase::New();
     itk::ObjectFactoryBase::RegisterFactory(AEFactory);
   }
 
 protected:
-  AutoencoderModelFactory();
-  ~AutoencoderModelFactory() ITK_OVERRIDE;
+  AutoencoderModelFactoryBase();
+  ~AutoencoderModelFactoryBase() ITK_OVERRIDE;
 
 private:
-  AutoencoderModelFactory(const Self &); //purposely not implemented
+  AutoencoderModelFactoryBase(const Self &); //purposely not implemented
   void operator =(const Self&); //purposely not implemented
 
 };
+
+
+
+
+template <class TInputValue, class TTargetValue>
+class ITK_EXPORT AutoencoderModelFactory : public AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron>>  {};
+
+
+template <class TInputValue, class TTargetValue>
+class ITK_EXPORT TiedAutoencoderModelFactory : public AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::TiedAutoencoder< shark::TanhNeuron, shark::LinearNeuron>>  {};
+
+
 } //namespace otb
 
+
 #ifndef OTB_MANUAL_INSTANTIATION
 #include "AutoencoderModelFactory.txx"
 #endif
diff --git a/include/AutoencoderModelFactory.txx b/include/AutoencoderModelFactory.txx
index 8b35be2881..22d6c5ba9a 100644
--- a/include/AutoencoderModelFactory.txx
+++ b/include/AutoencoderModelFactory.txx
@@ -25,11 +25,10 @@
 #include "AutoencoderModel.h"
 #include "itkVersion.h"
 
-#include <shark/Models/Autoencoder.h>//normal autoencoder model
 namespace otb
 {
-template <class TInputValue, class TOutputValue>
-AutoencoderModelFactory<TInputValue,TOutputValue>::AutoencoderModelFactory()
+template <class TInputValue, class TOutputValue, class AutoencoderType>
+AutoencoderModelFactoryBase<TInputValue,TOutputValue, AutoencoderType>::AutoencoderModelFactoryBase()
 {
 
   std::string classOverride = std::string("otbMachineLearningModel");
@@ -40,22 +39,22 @@ AutoencoderModelFactory<TInputValue,TOutputValue>::AutoencoderModelFactory()
                          "Shark RF ML Model",
                          1,
                       //   itk::CreateObjectFunction<AutoencoderModel<TInputValue,TOutputValue> >::New());
-						itk::CreateObjectFunction<AutoencoderModel<TInputValue,shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron> > >::New());
+						itk::CreateObjectFunction<AutoencoderModel<TInputValue,AutoencoderType > >::New());
 }
 
-template <class TInputValue, class TOutputValue>
-AutoencoderModelFactory<TInputValue,TOutputValue>::~AutoencoderModelFactory()
+template <class TInputValue, class TOutputValue, class AutoencoderType>
+AutoencoderModelFactoryBase<TInputValue,TOutputValue, AutoencoderType>::~AutoencoderModelFactoryBase()
 {
 }
 
-template <class TInputValue, class TOutputValue>
-const char* AutoencoderModelFactory<TInputValue,TOutputValue>::GetITKSourceVersion(void) const
+template <class TInputValue, class TOutputValue, class AutoencoderType>
+const char* AutoencoderModelFactoryBase<TInputValue,TOutputValue, AutoencoderType>::GetITKSourceVersion(void) const
 {
   return ITK_SOURCE_VERSION;
 }
 
-template <class TInputValue, class TOutputValue>
-const char* AutoencoderModelFactory<TInputValue,TOutputValue>::GetDescription() const
+template <class TInputValue, class TOutputValue, class AutoencoderType>
+const char* AutoencoderModelFactoryBase<TInputValue,TOutputValue, AutoencoderType>::GetDescription() const
 {
   return "Autoencoder model factory";
 }
-- 
GitLab