/*
 * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
 *
 * This file is part of Orfeo Toolbox
 *
 *     https://www.orfeo-toolbox.org/
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef otbSharkRandomForestsMachineLearningModel_h
#define otbSharkRandomForestsMachineLearningModel_h
#include "itkLightObject.h"
#include "otbMachineLearningModel.h"

// Silence the listed warnings only while the third-party Shark headers below
// are compiled; the matching "pop" restores the caller's diagnostic state.
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wshadow"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Woverloaded-virtual"
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
#pragma GCC diagnostic ignored "-Wsign-compare"
#pragma GCC diagnostic ignored "-Wcast-align"
#pragma GCC diagnostic ignored "-Wunknown-pragmas"
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
#if defined(__clang__)
#pragma clang diagnostic ignored "-Wheader-guard"
#else
// GCC-only warning group: clang does not know "-Wmaybe-uninitialized" and
// would emit its own -Wunknown-warning-option diagnostic if asked to ignore it.
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#endif
#endif
#include <shark/Models/Classifier.h>
#include "otb_shark.h"
#include "shark/Algorithms/Trainers/RFTrainer.h"
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif


/** \class SharkRandomForestsMachineLearningModel
 *  \brief Shark version of Random Forests algorithm
 *
 *  This is a specialization of MachineLearningModel class allowing to
 *  use Shark implementation of the Random Forests algorithm.
 *
 *  It is noteworthy that training step is parallel.
 * 
 *  For more information, see
 *  http://image.diku.dk/shark/doxygen_pages/html/classshark_1_1_r_f_trainer.html
Julien Michel's avatar
Julien Michel committed
60 61
 * 
 *  \ingroup OTBSupervised
62 63
 */

64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
namespace otb
{
template <class TInputValue, class TTargetValue>
class ITK_EXPORT SharkRandomForestsMachineLearningModel
  : public MachineLearningModel <TInputValue, TTargetValue>
{
public:
  /** Standard class typedefs. */
  typedef SharkRandomForestsMachineLearningModel               Self;
  typedef MachineLearningModel<TInputValue, TTargetValue> Superclass;
  typedef itk::SmartPointer<Self>                         Pointer;
  typedef itk::SmartPointer<const Self>                   ConstPointer;

  typedef typename Superclass::InputValueType             InputValueType;
  typedef typename Superclass::InputSampleType            InputSampleType;
  typedef typename Superclass::InputListSampleType        InputListSampleType;
  typedef typename Superclass::TargetValueType            TargetValueType;
  typedef typename Superclass::TargetSampleType           TargetSampleType;
  typedef typename Superclass::TargetListSampleType       TargetListSampleType;
  typedef typename Superclass::ConfidenceValueType        ConfidenceValueType;
84 85
  typedef typename Superclass::ConfidenceSampleType       ConfidenceSampleType;
  typedef typename Superclass::ConfidenceListSampleType   ConfidenceListSampleType;
86 87 88 89 90 91
  
  /** Run-time type information (and related methods). */
  itkNewMacro(Self);
  itkTypeMacro(SharkRandomForestsMachineLearningModel, MachineLearningModel);

  /** Train the machine learning model */
92
  virtual void Train() override;
93 94

  /** Save the model to file */
95
  virtual void Save(const std::string & filename, const std::string & name="") override;
96 97

  /** Load the model from file */
98
  virtual void Load(const std::string & filename, const std::string & name="") override;
99 100 101 102

  /**\name Classification model file compatibility tests */
  //@{
  /** Is the input model file readable and compatible with the corresponding classifier ? */
103
  virtual bool CanReadFile(const std::string &) override;
104 105

  /** Is the input model file writable and compatible with the corresponding classifier ? */
106
  virtual bool CanWriteFile(const std::string &) override;
107 108
  //@}

109
  /** From Shark doc: Get the number of trees to grow.*/
110
  itkGetMacro(NumberOfTrees,unsigned int);
111
  /** From Shark doc: Set the number of trees to grow.*/
112 113
  itkSetMacro(NumberOfTrees,unsigned int);

114
  /** From Shark doc: Get the number of random attributes to investigate at each node.*/
115
  itkGetMacro(MTry, unsigned int);
116
  /** From Shark doc: Set the number of random attributes to investigate at each node.*/
117 118
  itkSetMacro(MTry, unsigned int);

119 120 121
  /** From Shark doc: Controls when a node is considered pure. If set
* to 1, a node is pure when it only consists of a single node.
*/
122
  itkGetMacro(NodeSize, unsigned int);
123 124 125
    /** From Shark doc: Controls when a node is considered pure. If
* set to 1, a node is pure when it only consists of a single node.
 */
126
  itkSetMacro(NodeSize, unsigned int);
127 128 129 130

  /** From Shark doc: Get the fraction of the original training
* dataset to use as the out of bag sample. The default value is
* 0.66.*/
131
  itkGetMacro(OobRatio, float);
132 133 134 135

  /** From Shark doc: Set the fraction of the original training
* dataset to use as the out of bag sample. The default value is 0.66.
*/
136 137
  itkSetMacro(OobRatio, float);

138
  /** If true, margin confidence value will be computed */
139
  itkGetMacro(ComputeMargin, bool);
140
  /** If true, margin confidence value will be computed */
141 142
  itkSetMacro(ComputeMargin, bool);

143 144 145 146
  /** If true, class labels will be normalised in [0 ... nbClasses] */
  itkGetMacro(NormalizeClassLabels, bool);
  itkSetMacro(NormalizeClassLabels, bool);

147 148 149 150 151 152 153
protected:
  /** Constructor */
  SharkRandomForestsMachineLearningModel();

  /** Destructor */
  virtual ~SharkRandomForestsMachineLearningModel();

154
  /** Predict values using the model */
155
  virtual TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr) const override;
156 157

  
158
  virtual void DoPredictBatch(const InputListSampleType *, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType *, ConfidenceListSampleType * = nullptr) const override;
159
  
160
  /** PrintSelf method */
161
  void PrintSelf(std::ostream& os, itk::Indent indent) const override;
162 163

private:
164 165
  SharkRandomForestsMachineLearningModel(const Self &) = delete;
  void operator =(const Self&) = delete;
166

167 168
  shark::RFClassifier<unsigned int> m_RFModel;
  shark::RFTrainer<unsigned int> m_RFTrainer;
169
  std::vector<unsigned int> m_ClassDictionary;
170
  bool m_NormalizeClassLabels;
171 172 173 174 175 176

  unsigned int m_NumberOfTrees;
  unsigned int m_MTry;
  unsigned int m_NodeSize;
  float m_OobRatio;
  bool m_ComputeMargin;
177 178

  /** Confidence list sample */
179
  ConfidenceValueType ComputeConfidence(shark::RealVector & probas, 
180
                                        bool computeMargin) const;
181

182
};
183 184 185
} // end namespace otb

// Unless OTB_MANUAL_INSTANTIATION is defined, pull in the template member
// definitions from the implementation file.
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbSharkRandomForestsMachineLearningModel.hxx"
#endif

#endif