otbPCAModel.h 3.89 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
 *
 * This file is part of Orfeo Toolbox
 *
 *     https://www.orfeo-toolbox.org/
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
20 21
#ifndef otbPCAModel_h
#define otbPCAModel_h
22

23 24 25
#include "otbMachineLearningModelTraits.h"
#include "otbMachineLearningModel.h"

26 27 28 29 30
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wshadow"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Woverloaded-virtual"
31
#pragma GCC diagnostic ignored "-Wsign-compare"
32
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
33 34 35
#if defined(__clang__)
#pragma clang diagnostic ignored "-Wheader-guard"
#endif
36 37
#endif
#include "otb_shark.h"
38
#include <shark/Algorithms/Trainers/PCA.h>
39 40 41
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif
42 43 44

namespace otb
{
45 46 47 48

/** \class PCAModel
 *
 * This class wraps a PCA model implemented by Shark, in a otb::MachineLearningModel
49 50
 *
 * \ingroup OTBDimensionalityReductionLearning
51
 */
52
template <class TInputValue>
53 54 55 56
class ITK_EXPORT PCAModel
  : public  MachineLearningModel<
    itk::VariableLengthVector< TInputValue >,
    itk::VariableLengthVector< TInputValue > >    
57 58
{
public:
59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86
  typedef PCAModel Self;
  typedef MachineLearningModel<
    itk::VariableLengthVector< TInputValue >,
    itk::VariableLengthVector< TInputValue> > Superclass;
  typedef itk::SmartPointer<Self> Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;

  typedef typename Superclass::InputValueType       InputValueType;
  typedef typename Superclass::InputSampleType      InputSampleType;
  typedef typename Superclass::InputListSampleType  InputListSampleType;
  typedef typename InputListSampleType::Pointer     ListSamplePointerType;
  typedef typename Superclass::TargetValueType      TargetValueType;
  typedef typename Superclass::TargetSampleType     TargetSampleType;
  typedef typename Superclass::TargetListSampleType TargetListSampleType;

  // Confidence map related typedefs
  typedef typename Superclass::ConfidenceValueType       ConfidenceValueType;
  typedef typename Superclass::ConfidenceSampleType      ConfidenceSampleType;
  typedef typename Superclass::ConfidenceListSampleType  ConfidenceListSampleType;

  itkNewMacro(Self);
  itkTypeMacro(PCAModel, DimensionalityReductionModel);

  itkSetMacro(DoResizeFlag,bool);

  itkSetMacro(WriteEigenvectors, bool);
  itkGetMacro(WriteEigenvectors, bool);

87 88
  bool CanReadFile(const std::string & filename) override;
  bool CanWriteFile(const std::string & filename) override;
89

90 91
  void Save(const std::string & filename, const std::string & name="")  override;
  void Load(const std::string & filename, const std::string & name="")  override;
92

93
  void Train() override;
94 95

protected:
96
  PCAModel(); 
97
  ~PCAModel() override;
98
 
99 100
  virtual TargetSampleType DoPredict(
    const InputSampleType& input,
101
    ConfidenceValueType * quality = nullptr) const override;
102 103 104 105 106 107

  virtual void DoPredictBatch(
    const InputListSampleType *,
    const unsigned int & startIndex,
    const unsigned int & size,
    TargetListSampleType *,
108
    ConfidenceListSampleType * quality = nullptr) const override;
109

110
private:
111 112 113 114 115
  shark::LinearModel<> m_Encoder;
  shark::LinearModel<> m_Decoder;
  shark::PCA m_PCA;
  bool m_DoResizeFlag;
  bool m_WriteEigenvectors;
116 117 118 119 120
};
} // end namespace otb


#ifndef OTB_MANUAL_INSTANTIATION
121
#include "otbPCAModel.hxx"
122 123 124 125 126
#endif


#endif