From fa0bf983e6347752545a1fbed3e6b1b99703d206 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Traizet?= <traizetc@cesbio.cnes.fr> Date: Thu, 11 May 2017 14:36:54 +0200 Subject: [PATCH] dr application working (monoband output) for autoencoders and tiedautoencoders --- app/cbDimensionalityReduction.cxx | 2 +- include/AutoencoderModelFactory.h | 31 +++++++++++++++++++++-------- include/AutoencoderModelFactory.txx | 19 +++++++++--------- 3 files changed, 33 insertions(+), 19 deletions(-) diff --git a/app/cbDimensionalityReduction.cxx b/app/cbDimensionalityReduction.cxx index 08769213b9..d1e4861be1 100644 --- a/app/cbDimensionalityReduction.cxx +++ b/app/cbDimensionalityReduction.cxx @@ -194,7 +194,7 @@ private: otbAppLogINFO("Loading model"); m_Model = MachineLearningModelFactoryType::CreateMachineLearningModel(GetParameterString("model"), MachineLearningModelFactoryType::ReadMode); - otbAppLogINFO("yo"); + if (m_Model.IsNull()) { otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type"); diff --git a/include/AutoencoderModelFactory.h b/include/AutoencoderModelFactory.h index 51847dfebe..0b8f538272 100644 --- a/include/AutoencoderModelFactory.h +++ b/include/AutoencoderModelFactory.h @@ -2,18 +2,20 @@ #define AutoencoderModelFactory_h +#include <shark/Models/TiedAutoencoder.h> +#include <shark/Models/Autoencoder.h> #include "itkObjectFactoryBase.h" #include "itkImageIOBase.h" namespace otb { -template <class TInputValue, class TTargetValue> -class ITK_EXPORT AutoencoderModelFactory : public itk::ObjectFactoryBase +template <class TInputValue, class TTargetValue, class AutoencoderType> +class ITK_EXPORT AutoencoderModelFactoryBase : public itk::ObjectFactoryBase { public: /** Standard class typedefs. */ - typedef AutoencoderModelFactory Self; + typedef AutoencoderModelFactoryBase Self; typedef itk::ObjectFactoryBase Superclass; typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; @@ -26,26 +28,39 @@ public: itkFactorylessNewMacro(Self); /** Run-time type information (and related methods). */ - itkTypeMacro(AutoencoderModelFactory, itk::ObjectFactoryBase); + itkTypeMacro(AutoencoderModelFactoryBase, itk::ObjectFactoryBase); /** Register one factory of this type */ static void RegisterOneFactory(void) { - Pointer AEFactory = AutoencoderModelFactory::New(); + Pointer AEFactory = AutoencoderModelFactoryBase::New(); itk::ObjectFactoryBase::RegisterFactory(AEFactory); } protected: - AutoencoderModelFactory(); - ~AutoencoderModelFactory() ITK_OVERRIDE; + AutoencoderModelFactoryBase(); + ~AutoencoderModelFactoryBase() ITK_OVERRIDE; private: - AutoencoderModelFactory(const Self &); //purposely not implemented + AutoencoderModelFactoryBase(const Self &); //purposely not implemented void operator =(const Self&); //purposely not implemented }; + + + + +template <class TInputValue, class TTargetValue> +class ITK_EXPORT AutoencoderModelFactory : public AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron>> {}; + + +template <class TInputValue, class TTargetValue> +class ITK_EXPORT TiedAutoencoderModelFactory : public AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::TiedAutoencoder< shark::TanhNeuron, shark::LinearNeuron>> {}; + + } //namespace otb + #ifndef OTB_MANUAL_INSTANTIATION #include "AutoencoderModelFactory.txx" #endif diff --git a/include/AutoencoderModelFactory.txx b/include/AutoencoderModelFactory.txx index 8b35be2881..22d6c5ba9a 100644 --- a/include/AutoencoderModelFactory.txx +++ b/include/AutoencoderModelFactory.txx @@ -25,11 +25,10 @@ #include "AutoencoderModel.h" #include "itkVersion.h" -#include <shark/Models/Autoencoder.h>//normal autoencoder model namespace otb { -template <class TInputValue, class TOutputValue> -AutoencoderModelFactory<TInputValue,TOutputValue>::AutoencoderModelFactory() +template <class TInputValue, class TOutputValue, class AutoencoderType> +AutoencoderModelFactoryBase<TInputValue,TOutputValue, AutoencoderType>::AutoencoderModelFactoryBase() { std::string classOverride = std::string("otbMachineLearningModel"); @@ -40,22 +39,22 @@ AutoencoderModelFactory<TInputValue,TOutputValue>::AutoencoderModelFactory() "Shark RF ML Model", 1, // itk::CreateObjectFunction<AutoencoderModel<TInputValue,TOutputValue> >::New()); - itk::CreateObjectFunction<AutoencoderModel<TInputValue,shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron> > >::New()); + itk::CreateObjectFunction<AutoencoderModel<TInputValue,AutoencoderType > >::New()); } -template <class TInputValue, class TOutputValue> -AutoencoderModelFactory<TInputValue,TOutputValue>::~AutoencoderModelFactory() +template <class TInputValue, class TOutputValue, class AutoencoderType> +AutoencoderModelFactoryBase<TInputValue,TOutputValue, AutoencoderType>::~AutoencoderModelFactoryBase() { } -template <class TInputValue, class TOutputValue> -const char* AutoencoderModelFactory<TInputValue,TOutputValue>::GetITKSourceVersion(void) const +template <class TInputValue, class TOutputValue, class AutoencoderType> +const char* AutoencoderModelFactoryBase<TInputValue,TOutputValue, AutoencoderType>::GetITKSourceVersion(void) const { return ITK_SOURCE_VERSION; } -template <class TInputValue, class TOutputValue> -const char* AutoencoderModelFactory<TInputValue,TOutputValue>::GetDescription() const +template <class TInputValue, class TOutputValue, class AutoencoderType> +const char* AutoencoderModelFactoryBase<TInputValue,TOutputValue, AutoencoderType>::GetDescription() const { return "Autoencoder model factory"; } -- GitLab