/*
 * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
 *
 * This file is part of Orfeo Toolbox
 *
 *     https://www.orfeo-toolbox.org/
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef otbBoostMachineLearningModel_h
#define otbBoostMachineLearningModel_h

#include "otbRequiresOpenCVCheck.h"

#include "itkLightObject.h"
#include "itkFixedArray.h"
#include "otbMachineLearningModel.h"

// The model member is a cv::Ptr<cv::ml::Boost> when OTB_OPENCV_3 is defined,
// which needs the full OpenCV declarations (presumably provided through
// otbOpenCVUtils.h). Otherwise only a CvBoost pointer is stored, so a forward
// declaration is sufficient here.
#ifdef OTB_OPENCV_3
#include "otbOpenCVUtils.h"
#else
class CvBoost;
#endif

namespace otb
{
template <class TInputValue, class TTargetValue>
39
class ITK_EXPORT BoostMachineLearningModel
Julien Michel's avatar
Julien Michel committed
40 41 42 43 44 45 46 47 48
  : public MachineLearningModel <TInputValue, TTargetValue>
{
public:
  /** Standard class typedefs. */
  typedef BoostMachineLearningModel           Self;
  typedef MachineLearningModel<TInputValue, TTargetValue> Superclass;
  typedef itk::SmartPointer<Self>                         Pointer;
  typedef itk::SmartPointer<const Self>                   ConstPointer;

49 50 51 52 53 54
  typedef typename Superclass::InputValueType             InputValueType;
  typedef typename Superclass::InputSampleType            InputSampleType;
  typedef typename Superclass::InputListSampleType        InputListSampleType;
  typedef typename Superclass::TargetValueType            TargetValueType;
  typedef typename Superclass::TargetSampleType           TargetSampleType;
  typedef typename Superclass::TargetListSampleType       TargetListSampleType;
55
  typedef typename Superclass::ConfidenceValueType        ConfidenceValueType;
Julien Michel's avatar
Julien Michel committed
56 57 58

  /** Run-time type information (and related methods). */
  itkNewMacro(Self);
59
  itkTypeMacro(BoostMachineLearningModel, MachineLearningModel);
Julien Michel's avatar
Julien Michel committed
60

61 62 63
  /** Setters/Getters to the Boost type
   *  It can be CvBoost::DISCRETE, CvBoost::REAL, CvBoost::LOGIT, CvBoost::GENTLE
   *  Default is CvBoost::REAL.
64
   *  \see http://docs.opencv.org/modules/ml/doc/boosting.html#cvboostparams-cvboostparams
65 66 67 68 69 70 71
   */
  itkGetMacro(BoostType, int);
  itkSetMacro(BoostType, int);

  /** Setters/Getters to the split criteria
   *  It can be CvBoost::DEFAULT, CvBoost::GINI, CvBoost::MISCLASS, CvBoost::SQERR
   *  Default is CvBoost::DEFAULT. It uses default value according to \c BoostType
72
   *  \see http://docs.opencv.org/modules/ml/doc/boosting.html#cvboost-predict
73 74 75 76 77 78
   */
  itkGetMacro(SplitCrit, int);
  itkSetMacro(SplitCrit, int);

  /** Setters/Getters to the number of weak classifiers.
   *  Default is 100.
79
   *  \see http://docs.opencv.org/modules/ml/doc/boosting.html#cvboostparams-cvboostparams
80 81 82 83 84 85
   */
  itkGetMacro(WeakCount, int);
  itkSetMacro(WeakCount, int);

  /** Setters/Getters to the threshold WeightTrimRate.
   *  A threshold between 0 and 1 used to save computational time.
Guillaume Pasero's avatar
Guillaume Pasero committed
86
   *  Samples with summary weight \f$ w \leq 1 - WeightTrimRate \f$ do not participate in the next iteration of training.
87 88
   *  Set this parameter to 0 to turn off this functionality.
   *  Default is 0.95
89
   *  \see http://docs.opencv.org/modules/ml/doc/boosting.html#cvboostparams-cvboostparams
90 91 92 93 94 95
   */
  itkGetMacro(WeightTrimRate, double);
  itkSetMacro(WeightTrimRate, double);

  /** Setters/Getters to the maximum depth of the tree.
   * Default is 1
96
   * \see http://docs.opencv.org/modules/ml/doc/decision_trees.html#CvDTreeParams::CvDTreeParams%28%29
97 98 99 100
   */
  itkGetMacro(MaxDepth, int);
  itkSetMacro(MaxDepth, int);

101
  /** Train the machine learning model */
102
  void Train() override;
103

Julien Michel's avatar
Julien Michel committed
104
  /** Save the model to file */
105
  void Save(const std::string & filename, const std::string & name="") override;
Julien Michel's avatar
Julien Michel committed
106 107

  /** Load the model from file */
108
  void Load(const std::string & filename, const std::string & name="") override;
Julien Michel's avatar
Julien Michel committed
109

110 111 112
  /**\name Classification model file compatibility tests */
  //@{
  /** Is the input model file readable and compatible with the corresponding classifier ? */
113
  bool CanReadFile(const std::string &) override;
114 115

  /** Is the input model file writable and compatible with the corresponding classifier ? */
116
  bool CanWriteFile(const std::string &) override;
117
  //@}
118

Julien Michel's avatar
Julien Michel committed
119 120 121 122 123
protected:
  /** Constructor */
  BoostMachineLearningModel();

  /** Destructor */
124
  ~BoostMachineLearningModel() override;
Julien Michel's avatar
Julien Michel committed
125

126
  /** Predict values using the model */
127
  TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=nullptr) const override;
128 129

  
Julien Michel's avatar
Julien Michel committed
130
  /** PrintSelf method */
131
  void PrintSelf(std::ostream& os, itk::Indent indent) const override;
Julien Michel's avatar
Julien Michel committed
132 133

private:
134 135
  BoostMachineLearningModel(const Self &) = delete;
  void operator =(const Self&) = delete;
Julien Michel's avatar
Julien Michel committed
136

137
#ifdef OTB_OPENCV_3
138
  cv::Ptr<cv::ml::Boost> m_BoostModel;
139
#else
Julien Michel's avatar
Julien Michel committed
140
  CvBoost * m_BoostModel;
141
#endif
142 143 144
  int m_BoostType;
  int m_WeakCount;
  double m_WeightTrimRate;
145
  int m_SplitCrit;
146
  int m_MaxDepth;
Julien Michel's avatar
Julien Michel committed
147 148 149 150
};
} // end namespace otb

// Pull in the template implementation unless the including translation unit
// opted out by defining OTB_MANUAL_INSTANTIATION.
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbBoostMachineLearningModel.hxx"
#endif

#endif