otbAutoencoderModel.h 6.76 KB
Newer Older
Guillaume Pasero's avatar
Guillaume Pasero committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
 *
 * This file is part of Orfeo Toolbox
 *
 *     https://www.orfeo-toolbox.org/
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
20 21
#ifndef otbAutoencoderModel_h
#define otbAutoencoderModel_h
22

23 24
#include "otbMachineLearningModelTraits.h"
#include "otbMachineLearningModel.h"
25
#include <string>
26

27 28 29 30 31
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wshadow"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Woverloaded-virtual"
32
#pragma GCC diagnostic ignored "-Wsign-compare"
33
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
34 35 36 37
#if defined(__clang__)
#pragma clang diagnostic ignored "-Wheader-guard"
#pragma clang diagnostic ignored "-Wdivision-by-zero"
#endif
38 39 40
#endif
#include "otb_shark.h"
#include <shark/Algorithms/StoppingCriteria/AbstractStoppingCriterion.h>
41 42 43
#include <shark/Models/LinearModel.h>
#include <shark/Models/ConcatenatedModel.h>
#include <shark/Models/NeuronLayers.h>
44 45 46 47
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif

48 49
namespace otb
{
Guillaume Pasero's avatar
Guillaume Pasero committed
50 51 52 53 54 55 56
/**
 * \class AutoencoderModel
 *
 * Autoencoder model wrapper class
 *
 * \ingroup OTBDimensionalityReductionLearning
 */
57
template <class TInputValue, class NeuronType>
Guillaume Pasero's avatar
Guillaume Pasero committed
58 59 60 61
class ITK_EXPORT AutoencoderModel
  : public  MachineLearningModel<
    itk::VariableLengthVector< TInputValue>,
    itk::VariableLengthVector< TInputValue> >
62 63
{
public:
64
  typedef AutoencoderModel Self;
Guillaume Pasero's avatar
Guillaume Pasero committed
65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84
  typedef MachineLearningModel<
    itk::VariableLengthVector< TInputValue>,
    itk::VariableLengthVector< TInputValue> > Superclass;
  typedef itk::SmartPointer<Self> Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;

  typedef typename Superclass::InputValueType       InputValueType;
  typedef typename Superclass::InputSampleType      InputSampleType;
  typedef typename Superclass::InputListSampleType    InputListSampleType;
  typedef typename InputListSampleType::Pointer       ListSamplePointerType;
  typedef typename Superclass::TargetValueType      TargetValueType;
  typedef typename Superclass::TargetSampleType       TargetSampleType;
  typedef typename Superclass::TargetListSampleType     TargetListSampleType;

  /// Confidence map related typedefs
  typedef typename Superclass::ConfidenceValueType          ConfidenceValueType;
  typedef typename Superclass::ConfidenceSampleType         ConfidenceSampleType;
  typedef typename Superclass::ConfidenceListSampleType         ConfidenceListSampleType;

  /// Neural network related typedefs
85 86 87
  typedef shark::ConcatenatedModel<shark::RealVector> ModelType;
  typedef shark::LinearModel<shark::RealVector,NeuronType> LayerType;
  typedef shark::LinearModel<shark::RealVector, shark::LinearNeuron> OutLayerType;
Guillaume Pasero's avatar
Guillaume Pasero committed
88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127

  itkNewMacro(Self);
  itkTypeMacro(AutoencoderModel, DimensionalityReductionModel);

  itkGetMacro(NumberOfHiddenNeurons,itk::Array<unsigned int>);
  itkSetMacro(NumberOfHiddenNeurons,itk::Array<unsigned int>);

  itkGetMacro(NumberOfIterations,unsigned int);
  itkSetMacro(NumberOfIterations,unsigned int);

  itkGetMacro(NumberOfIterationsFineTuning,unsigned int);
  itkSetMacro(NumberOfIterationsFineTuning,unsigned int);

  itkGetMacro(Epsilon,double);
  itkSetMacro(Epsilon,double);

  itkGetMacro(InitFactor,double);
  itkSetMacro(InitFactor,double);

  itkGetMacro(Regularization,itk::Array<double>);
  itkSetMacro(Regularization,itk::Array<double>);

  itkGetMacro(Noise,itk::Array<double>);
  itkSetMacro(Noise,itk::Array<double>);

  itkGetMacro(Rho,itk::Array<double>);
  itkSetMacro(Rho,itk::Array<double>);

  itkGetMacro(Beta,itk::Array<double>);
  itkSetMacro(Beta,itk::Array<double>);

  itkGetMacro(WriteLearningCurve,bool);
  itkSetMacro(WriteLearningCurve,bool);

  itkSetMacro(WriteWeights, bool);
  itkGetMacro(WriteWeights, bool);

  itkGetMacro(LearningCurveFileName,std::string);
  itkSetMacro(LearningCurveFileName,std::string);

128 129
  bool CanReadFile(const std::string & filename) override;
  bool CanWriteFile(const std::string & filename) override;
Guillaume Pasero's avatar
Guillaume Pasero committed
130

131 132
  void Save(const std::string & filename, const std::string & name="")  override;
  void Load(const std::string & filename, const std::string & name="")  override;
Guillaume Pasero's avatar
Guillaume Pasero committed
133

134
  void Train() override;
Guillaume Pasero's avatar
Guillaume Pasero committed
135

136
  template <class T>
Guillaume Pasero's avatar
Guillaume Pasero committed
137 138 139 140 141 142
  void TrainOneLayer(
    shark::AbstractStoppingCriterion<T> & criterion,
    unsigned int,
    shark::Data<shark::RealVector> &,
    std::ostream&);

143
  template <class T>
Guillaume Pasero's avatar
Guillaume Pasero committed
144 145 146 147 148 149 150 151 152 153 154 155
  void TrainOneSparseLayer(
    shark::AbstractStoppingCriterion<T> & criterion,
    unsigned int,
    shark::Data<shark::RealVector> &,
    std::ostream&);

  template <class T>
  void TrainNetwork(
    shark::AbstractStoppingCriterion<T> & criterion,
    shark::Data<shark::RealVector> &,
    std::ostream&);

156
protected:
Guillaume Pasero's avatar
Guillaume Pasero committed
157
  AutoencoderModel();
158
  ~AutoencoderModel() override;
Guillaume Pasero's avatar
Guillaume Pasero committed
159 160 161

  virtual TargetSampleType DoPredict(
    const InputSampleType& input,
162
    ConfidenceValueType * quality = nullptr) const override;
Guillaume Pasero's avatar
Guillaume Pasero committed
163 164 165 166 167 168

  virtual void DoPredictBatch(
    const InputListSampleType *,
    const unsigned int & startIndex,
    const unsigned int & size,
    TargetListSampleType *,
169
    ConfidenceListSampleType * quality = nullptr) const override;
170

Cédric Traizet's avatar
Cédric Traizet committed
171
private:
Guillaume Pasero's avatar
Guillaume Pasero committed
172
  /** Internal Network */
173 174 175
  ModelType m_Encoder;
  std::vector<LayerType> m_InLayers;
  OutLayerType m_OutLayer;
Guillaume Pasero's avatar
Guillaume Pasero committed
176 177 178 179 180 181 182 183 184 185 186
  itk::Array<unsigned int> m_NumberOfHiddenNeurons;
  /** Training parameters */
  unsigned int m_NumberOfIterations; // stop the training after a fixed number of iterations
  unsigned int m_NumberOfIterationsFineTuning; // stop the fine tuning after a fixed number of iterations
  double m_Epsilon; // Stops the training when the training error seems to converge
  itk::Array<double> m_Regularization;  // L2 Regularization parameter
  itk::Array<double> m_Noise;  // probability for an input to be set to 0 (denosing autoencoder)
  itk::Array<double> m_Rho; // Sparsity parameter
  itk::Array<double> m_Beta; // Sparsity regularization parameter
  double m_InitFactor; // Weight initialization factor (the weights are intialized at m_initfactor/sqrt(inputDimension)  )

187
  bool m_WriteLearningCurve; // Flag for writing the learning curve into a txt file
Guillaume Pasero's avatar
Guillaume Pasero committed
188 189
  std::string m_LearningCurveFileName; // Name of the output learning curve printed after training
  bool m_WriteWeights;
190 191 192 193
};
} // end namespace otb

#ifndef OTB_MANUAL_INSTANTIATION
194
#include "otbAutoencoderModel.hxx"
195 196 197
#endif

#endif
198