/*=========================================================================
 Program:   ORFEO Toolbox
 Language:  C++
 Date:      $Date$
 Version:   $Revision$


 Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
 See OTBCopyright.txt for details.


 This software is distributed WITHOUT ANY WARRANTY; without even
 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
 PURPOSE.  See the above copyright notices for more information.

 =========================================================================*/
#ifndef otbTrainDecisionTree_txx
#define otbTrainDecisionTree_txx
#include "otbLearningApplicationBase.h"

namespace otb
{
namespace Wrapper
{

template <class TInputValue, class TOutputValue>
void
LearningApplicationBase<TInputValue,TOutputValue>
::InitDecisionTreeParams()
{
  // Declare the "dt" classifier choice and all of its tunable parameters
  // in the application's parameter group "classifier.dt".
  AddChoice("classifier.dt", "Decision Tree classifier");
  SetParameterDescription("classifier.dt",
                          "This group of parameters allows setting Decision Tree classifier parameters. "
                          "See complete documentation here \\url{http://docs.opencv.org/modules/ml/doc/decision_trees.html}.");

  //MaxDepth
  AddParameter(ParameterType_Int, "classifier.dt.max", "Maximum depth of the tree");
  SetParameterInt("classifier.dt.max",65535, false);
  SetParameterDescription(
      "classifier.dt.max", "The training algorithm attempts to split each node while its depth is smaller than the maximum "
      "possible depth of the tree. The actual depth may be smaller if the other termination criteria are met, and/or "
      "if the tree is pruned.");

  //MinSampleCount
  AddParameter(ParameterType_Int, "classifier.dt.min", "Minimum number of samples in each node");
  SetParameterInt("classifier.dt.min",10, false);
  SetParameterDescription("classifier.dt.min", "If the number of samples in a node is smaller than this parameter, "
                          "then this node will not be split.");

  //RegressionAccuracy
  AddParameter(ParameterType_Float, "classifier.dt.ra", "Termination criteria for regression tree");
  SetParameterFloat("classifier.dt.ra",0.01, false);
  // Bug fix: this description was previously attached to "classifier.dt.min",
  // silently overwriting that parameter's help text and leaving
  // "classifier.dt.ra" without any description.
  SetParameterDescription("classifier.dt.ra", "If all absolute differences between an estimated value in a node "
                          "and the values of the train samples in this node are smaller than this regression accuracy parameter, "
                          "then the node will not be split.");

  //UseSurrogates : don't need to be exposed !
  //AddParameter(ParameterType_Empty, "classifier.dt.sur", "Surrogate splits will be built");
  //SetParameterDescription("classifier.dt.sur","These splits allow working with missing data and compute variable importance correctly.");

  //MaxCategories
  AddParameter(ParameterType_Int, "classifier.dt.cat",
               "Cluster possible values of a categorical variable into K <= cat clusters to find a suboptimal split");
  SetParameterInt("classifier.dt.cat",10, false);
  SetParameterDescription(
      "classifier.dt.cat",
      "Cluster possible values of a categorical variable into K <= cat clusters to find a suboptimal split.");

  //CVFolds
  AddParameter(ParameterType_Int, "classifier.dt.f", "K-fold cross-validations");
  SetParameterInt("classifier.dt.f",10, false);
  SetParameterDescription(
      "classifier.dt.f", "If cv_folds > 1, then it prunes a tree with K-fold cross-validation where K is equal to cv_folds.");

  //Use1seRule
  AddParameter(ParameterType_Empty, "classifier.dt.r", "Set Use1seRule flag to false");
  SetParameterDescription(
      "classifier.dt.r",
      "If true, then a pruning will be harsher. This will make a tree more compact and more resistant to the training data noise but a bit less accurate.");

  //TruncatePrunedTree
  AddParameter(ParameterType_Empty, "classifier.dt.t", "Set TruncatePrunedTree flag to false");
  SetParameterDescription("classifier.dt.t", "If true, then pruned branches are physically removed from the tree.");

  //Priors are not exposed.

}

template <class TInputValue, class TOutputValue>
void
LearningApplicationBase<TInputValue,TOutputValue>
::TrainDecisionTree(typename ListSampleType::Pointer trainingListSample,
                    typename TargetListSampleType::Pointer trainingLabeledListSample,
                    std::string modelPath)
{
  // Instantiate the decision tree machine-learning model.
  typename DecisionTreeType::Pointer dtModel = DecisionTreeType::New();

  // Classification vs. regression is driven by the application-level flag.
  dtModel->SetRegressionMode(this->m_RegressionFlag);

  // Attach the training samples and their target labels/values.
  dtModel->SetInputListSample(trainingListSample);
  dtModel->SetTargetListSample(trainingLabeledListSample);

  // Forward the user-selected hyper-parameters from the "classifier.dt.*" group.
  dtModel->SetMaxDepth(GetParameterInt("classifier.dt.max"));
  dtModel->SetMinSampleCount(GetParameterInt("classifier.dt.min"));
  dtModel->SetRegressionAccuracy(GetParameterFloat("classifier.dt.ra"));
  dtModel->SetMaxCategories(GetParameterInt("classifier.dt.cat"));
  dtModel->SetCVFolds(GetParameterInt("classifier.dt.f"));

  // The two empty (flag) parameters, when enabled, lower the corresponding
  // booleans on the model (they default to true inside the model).
  const bool dropUse1seRule        = IsParameterEnabled("classifier.dt.r");
  const bool dropTruncatePruned    = IsParameterEnabled("classifier.dt.t");
  if (dropUse1seRule)
    {
    dtModel->SetUse1seRule(false);
    }
  if (dropTruncatePruned)
    {
    dtModel->SetTruncatePrunedTree(false);
    }

  // Run training and serialize the resulting model to disk.
  dtModel->Train();
  dtModel->Save(modelPath);
}

} //end namespace wrapper
} //end namespace otb

#endif