/*=========================================================================
 Program:   ORFEO Toolbox
 Language:  C++
 Date:      $Date$
 Version:   $Revision$


 Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
 See OTBCopyright.txt for details.


 This software is distributed WITHOUT ANY WARRANTY; without even
 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
 PURPOSE.  See the above copyright notices for more information.

 =========================================================================*/
#ifndef otbTrainGradientBoostedTree_txx
#define otbTrainGradientBoostedTree_txx
#include "otbLearningApplicationBase.h"

namespace otb
{
namespace Wrapper
{

template <class TInputValue, class TOutputValue>
void
LearningApplicationBase<TInputValue,TOutputValue>
::InitGradientBoostedTreeParams()
{
  AddChoice("classifier.gbt", "Gradient Boosted Tree classifier");
  SetParameterDescription(
      "classifier.gbt",
      "This group of parameters allows setting Gradient Boosted Tree classifier parameters. "
      "See complete documentation here \\url{http://docs.opencv.org/modules/ml/doc/gradient_boosted_trees.html}.");

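  // The loss function choice is only exposed in regression mode; in
  // classification mode, deviance loss is forced in TrainGradientBoostedTree().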
  if (m_RegressionFlag)
    {
    AddParameter(ParameterType_Choice, "classifier.gbt.t", "Loss Function Type");
    SetParameterDescription("classifier.gbt.t","Type of loss functionused for training.");
    AddChoice("classifier.gbt.t.sqr","Squared Loss");
    AddChoice("classifier.gbt.t.abs","Absolute Loss");
    AddChoice("classifier.gbt.t.hub","Huber Loss");
    }

  //WeakCount
  AddParameter(ParameterType_Int, "classifier.gbt.w", "Number of boosting algorithm iterations");
  SetParameterInt("classifier.gbt.w",200, false);
  SetParameterDescription(
      "classifier.gbt.w",
      "Number \"w\" of boosting algorithm iterations, with w*K being the total number of trees in "
      "the GBT model, where K is the output number of classes.");

  //Shrinkage
  AddParameter(ParameterType_Float, "classifier.gbt.s", "Regularization parameter");
  SetParameterFloat("classifier.gbt.s",0.01, false);
  SetParameterDescription("classifier.gbt.s", "Regularization parameter.");

  //SubSamplePortion
  AddParameter(ParameterType_Float, "classifier.gbt.p",
               "Portion of the whole training set used for each algorithm iteration");
  SetParameterFloat("classifier.gbt.p",0.8, false);
  SetParameterDescription(
      "classifier.gbt.p",
      "Portion of the whole training set used for each algorithm iteration. The subset is generated randomly.");

  //MaxDepth
  AddParameter(ParameterType_Int, "classifier.gbt.max", "Maximum depth of the tree");
  SetParameterInt("classifier.gbt.max",3, false);
  SetParameterDescription(
        "classifier.gbt.max", "The training algorithm attempts to split each node while its depth is smaller than the maximum "
        "possible depth of the tree. The actual depth may be smaller if the other termination criteria are met, and/or "
        "if the tree is pruned.");

  //UseSurrogates : don't need to be exposed !
  //AddParameter(ParameterType_Empty, "classifier.gbt.sur", "Surrogate splits will be built");
  //SetParameterDescription("classifier.gbt.sur","These splits allow working with missing data and compute variable importance correctly.");

}

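// Hypothetical usage sketch: these parameters are normally set through an OTB
// application built on LearningApplicationBase (e.g. TrainImagesClassifier).
// The io.* option names below are assumptions for illustration only:
//
//   otbcli_TrainImagesClassifier -io.il input.tif -io.vd samples.shp \
//     -classifier gbt -classifier.gbt.w 200 -classifier.gbt.s 0.01 \
//     -classifier.gbt.p 0.8 -classifier.gbt.max 3 -io.out model.gbt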
template <class TInputValue, class TOutputValue>
void
LearningApplicationBase<TInputValue,TOutputValue>
::TrainGradientBoostedTree(typename ListSampleType::Pointer trainingListSample,
                           typename TargetListSampleType::Pointer trainingLabeledListSample,
                           std::string modelPath)
{
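  // Instantiate the GBT model wrapper (backed by OpenCV's CvGBTrees) and
  // switch it between classification and regression.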
  typename GradientBoostedTreeType::Pointer classifier = GradientBoostedTreeType::New();
  classifier->SetRegressionMode(this->m_RegressionFlag);
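  // Forward the training samples and the user-supplied parameters to the model.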
  classifier->SetInputListSample(trainingListSample);
  classifier->SetTargetListSample(trainingLabeledListSample);
  classifier->SetWeakCount(GetParameterInt("classifier.gbt.w"));
  classifier->SetShrinkage(GetParameterFloat("classifier.gbt.s"));
  classifier->SetSubSamplePortion(GetParameterFloat("classifier.gbt.p"));
  classifier->SetMaxDepth(GetParameterInt("classifier.gbt.max"));

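  // Select the loss function: classification always uses deviance loss, while
  // in regression mode the user's choice (classifier.gbt.t) is honored.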
  if (m_RegressionFlag)
    {
    switch (GetParameterInt("classifier.gbt.t"))
      {
      case 0: // SQUARED_LOSS
        classifier->SetLossFunctionType(CvGBTrees::SQUARED_LOSS);
        break;
      case 1: // ABSOLUTE_LOSS
        classifier->SetLossFunctionType(CvGBTrees::ABSOLUTE_LOSS);
        break;
      case 2: // HUBER_LOSS
        classifier->SetLossFunctionType(CvGBTrees::HUBER_LOSS);
        break;
      default:
        classifier->SetLossFunctionType(CvGBTrees::SQUARED_LOSS);
        break;
      }
    }
  else
    {
    classifier->SetLossFunctionType(CvGBTrees::DEVIANCE_LOSS);
    }

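  // Run the OpenCV training and serialize the resulting model to disk.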
  classifier->Train();
  classifier->Save(modelPath);
}

} //end namespace Wrapper
} //end namespace otb

#endif