otbAutoencoderModel.h 6.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
 *
 * This file is part of Orfeo Toolbox
 *
 *     https://www.orfeo-toolbox.org/
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
20 21 22
#ifndef AutoencoderModel_h
#define AutoencoderModel_h

23 24
#include "otbMachineLearningModelTraits.h"
#include "otbMachineLearningModel.h"
25
#include <fstream>
26

27 28 29 30 31 32 33 34
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wshadow"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Woverloaded-virtual"
#endif
#include "otb_shark.h"
#include <shark/Algorithms/StoppingCriteria/AbstractStoppingCriterion.h>
35 36
#include <shark/Models/FFNet.h>
#include <shark/Models/Autoencoder.h>
37 38 39 40
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif

41 42
namespace otb
{
43
template <class TInputValue, class NeuronType>
44
class ITK_EXPORT AutoencoderModel: public  MachineLearningModel<itk::VariableLengthVector< TInputValue> , itk::VariableLengthVector< TInputValue>>   
45 46 47 48
{

public:
	
49 50
  typedef AutoencoderModel Self;
  typedef MachineLearningModel<itk::VariableLengthVector< TInputValue> , itk::VariableLengthVector< TInputValue>> Superclass;
51 52 53
	typedef itk::SmartPointer<Self> Pointer;
	typedef itk::SmartPointer<const Self> ConstPointer;

54 55 56 57 58 59 60 61 62
	typedef typename Superclass::InputValueType 			InputValueType;
	typedef typename Superclass::InputSampleType 			InputSampleType;
	typedef typename Superclass::InputListSampleType 		InputListSampleType;
	typedef typename InputListSampleType::Pointer 			ListSamplePointerType;
	typedef typename Superclass::TargetValueType 			TargetValueType;
	typedef typename Superclass::TargetSampleType 			TargetSampleType;
	typedef typename Superclass::TargetListSampleType 		TargetListSampleType;

	/// Confidence map related typedefs
Cédric Traizet's avatar
Cédric Traizet committed
63
	
64 65 66 67
	typedef typename Superclass::ConfidenceValueType  				ConfidenceValueType;
	typedef typename Superclass::ConfidenceSampleType 				ConfidenceSampleType;
	typedef typename Superclass::ConfidenceListSampleType      		ConfidenceListSampleType;

68
	/// Neural network related typedefs
69
	//typedef shark::Autoencoder<NeuronType,shark::LinearNeuron> OutAutoencoderType;
70 71 72 73 74
	typedef shark::Autoencoder<NeuronType,shark::LinearNeuron> OutAutoencoderType;
	typedef shark::Autoencoder<NeuronType,NeuronType> AutoencoderType;
	typedef shark::FFNet<NeuronType,shark::LinearNeuron> NetworkType;
	
	
75
	itkNewMacro(Self);
76
	itkTypeMacro(AutoencoderModel, DimensionalityReductionModel);
77

78
	//unsigned int GetDimension() {return m_NumberOfHiddenNeurons[m_net.size()-1];};  // Override the Dimensionality Reduction model method, it is used in the dimensionality reduction filter to set the output image size
79 80
	itkGetMacro(NumberOfHiddenNeurons,itk::Array<unsigned int>);
	itkSetMacro(NumberOfHiddenNeurons,itk::Array<unsigned int>);
81 82 83

	itkGetMacro(NumberOfIterations,unsigned int);
	itkSetMacro(NumberOfIterations,unsigned int);
84
	
85 86 87
	itkGetMacro(NumberOfIterationsFineTuning,unsigned int);
	itkSetMacro(NumberOfIterationsFineTuning,unsigned int);
	
88 89
	itkGetMacro(Epsilon,double);
	itkSetMacro(Epsilon,double);
90

91 92 93
	itkGetMacro(InitFactor,double);
	itkSetMacro(InitFactor,double);

94 95
	itkGetMacro(Regularization,itk::Array<double>);
	itkSetMacro(Regularization,itk::Array<double>);
96

97 98
	itkGetMacro(Noise,itk::Array<double>);
	itkSetMacro(Noise,itk::Array<double>);
99

100 101
	itkGetMacro(Rho,itk::Array<double>);
	itkSetMacro(Rho,itk::Array<double>);
102

103 104
	itkGetMacro(Beta,itk::Array<double>);
	itkSetMacro(Beta,itk::Array<double>);
105

106 107
	itkGetMacro(WriteLearningCurve,bool);
	itkSetMacro(WriteLearningCurve,bool);
108 109 110
		
	itkSetMacro(WriteWeights, bool);
	itkGetMacro(WriteWeights, bool);
111 112 113 114
	
	itkGetMacro(LearningCurveFileName,std::string);
	itkSetMacro(LearningCurveFileName,std::string);

115 116 117 118 119 120 121
	bool CanReadFile(const std::string & filename);
	bool CanWriteFile(const std::string & filename);

	void Save(const std::string & filename, const std::string & name="")  ITK_OVERRIDE;
	void Load(const std::string & filename, const std::string & name="")  ITK_OVERRIDE;

	void Train() ITK_OVERRIDE;
122
	
Cédric Traizet's avatar
Cédric Traizet committed
123 124
	template <class T, class Autoencoder>
	void TrainOneLayer(shark::AbstractStoppingCriterion<T> & criterion,Autoencoder &,unsigned int, unsigned int,double, double, shark::Data<shark::RealVector> &, std::ostream&);
125
	
126 127
	template <class T, class Autoencoder>
	void TrainOneSparseLayer(shark::AbstractStoppingCriterion<T> & criterion,Autoencoder &, unsigned int, unsigned int,double, double,double, shark::Data<shark::RealVector> &, std::ostream&);
128
	
Cédric Traizet's avatar
Cédric Traizet committed
129 130 131
	template <class T>
	void TrainNetwork(shark::AbstractStoppingCriterion<T> & criterion,double, double,double, shark::Data<shark::RealVector> &, std::ostream&);
	
132
protected:
Cédric Traizet's avatar
Cédric Traizet committed
133
	AutoencoderModel();	
134
	~AutoencoderModel() ITK_OVERRIDE;
Cédric Traizet's avatar
Cédric Traizet committed
135
 
136 137 138
	virtual TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType * quality = ITK_NULLPTR) const;

	virtual void DoPredictBatch(const InputListSampleType *, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType *, ConfidenceListSampleType * quality = ITK_NULLPTR) const;
Cédric Traizet's avatar
Cédric Traizet committed
139 140 141
  
private:
	
142
	/** Network attributes */
143 144
	//std::vector<AutoencoderType> m_net;
	NetworkType m_net;
145
	itk::Array<unsigned int> m_NumberOfHiddenNeurons;
146
	/** Training parameters */
147
	unsigned int m_NumberOfIterations; // stop the training after a fixed number of iterations
148
	unsigned int m_NumberOfIterationsFineTuning; // stop the fine tuning after a fixed number of iterations
149
	double m_Epsilon; // Stops the training when the training error seems to converge
150 151 152 153
	itk::Array<double> m_Regularization;  // L2 Regularization parameter
	itk::Array<double> m_Noise;  // probability for an input to be set to 0 (denosing autoencoder)
	itk::Array<double> m_Rho; // Sparsity parameter
	itk::Array<double> m_Beta; // Sparsity regularization parameter
154
	double m_InitFactor; // Weight initialization factor (the weights are intialized at m_initfactor/sqrt(inputDimension)  )
155 156
	
	bool m_WriteLearningCurve; // Flag for writting the learning curve into a txt file
157
	std::string m_LearningCurveFileName; // Name of the output learning curve printed after training
158
	bool m_WriteWeights;
159 160 161 162 163
};
} // end namespace otb


#ifndef OTB_MANUAL_INSTANTIATION
164
#include "otbAutoencoderModel.txx"
165 166 167 168
#endif


#endif
169