/*=========================================================================
 Program:   ORFEO Toolbox
 Language:  C++
 Date:      $Date$
 Version:   $Revision$


 Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
 See OTBCopyright.txt for details.


 This software is distributed WITHOUT ANY WARRANTY; without even
 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
 PURPOSE.  See the above copyright notices for more information.

 =========================================================================*/
#ifndef otbTrainKNN_txx
#define otbTrainKNN_txx
#include "otbLearningApplicationBase.h"

namespace otb
{
namespace Wrapper
{

  /**
   * Register the application parameters of the KNN learning model.
   *
   * Declares:
   *  - the "classifier.knn" choice, with a description pointing to the
   *    OpenCV k-nearest-neighbors documentation;
   *  - "classifier.knn.k", the number of neighbors (initialised to 32);
   *  - only when m_RegressionFlag is set, the "classifier.knn.rule" choice
   *    ("mean" or "median") that selects how neighbor values are combined
   *    into a regression output.
   */
  template <class TInputValue, class TOutputValue>
  void
  LearningApplicationBase<TInputValue,TOutputValue>
  ::InitKNNParams()
  {
    AddChoice("classifier.knn", "KNN classifier");
    SetParameterDescription("classifier.knn", "This group of parameters allows setting KNN classifier parameters. "
        "See complete documentation here \\url{http://docs.opencv.org/modules/ml/doc/k_nearest_neighbors.html}.");

    // K parameter: number of neighbors used for each prediction.
    AddParameter(ParameterType_Int, "classifier.knn.k", "Number of Neighbors");
    // NOTE(review): the trailing `false` presumably marks 32 as a default
    // rather than a user-supplied value -- confirm against
    // Application::SetParameterInt.
    SetParameterInt("classifier.knn.k",32, false);
    SetParameterDescription("classifier.knn.k","The number of neighbors to use.");

    // The decision rule only makes sense for regression; in classification
    // mode no "rule" parameter is declared at all.
    if (this->m_RegressionFlag)
      {
      // Decision rule : mean / median
      AddParameter(ParameterType_Choice, "classifier.knn.rule", "Decision rule");
      SetParameterDescription("classifier.knn.rule", "Decision rule for regression output");

      AddChoice("classifier.knn.rule.mean", "Mean of neighbors values");
      SetParameterDescription("classifier.knn.rule.mean","Returns the mean of neighbors values");

      AddChoice("classifier.knn.rule.median", "Median of neighbors values");
      SetParameterDescription("classifier.knn.rule.median","Returns the median of neighbors values");
      }
  }

  /**
   * Train a KNN model on the given samples and write it to disk.
   *
   * The model runs in regression mode when m_RegressionFlag is set.
   * K is read from "classifier.knn.k". In regression mode, the decision
   * rule is read from "classifier.knn.rule".
   *
   * @param trainingListSample        input feature samples
   * @param trainingLabeledListSample target values (labels or regression targets)
   *                                  matching trainingListSample
   * @param modelPath                 file the trained model is saved to
   */
  template <class TInputValue, class TOutputValue>
  void
  LearningApplicationBase<TInputValue,TOutputValue>
  ::TrainKNN(typename ListSampleType::Pointer trainingListSample,
             typename TargetListSampleType::Pointer trainingLabeledListSample,
             std::string modelPath)
  {
    typename KNNType::Pointer knnClassifier = KNNType::New();
    knnClassifier->SetRegressionMode(this->m_RegressionFlag);

    knnClassifier->SetInputListSample(trainingListSample);
    knnClassifier->SetTargetListSample(trainingLabeledListSample);
    knnClassifier->SetK(this->GetParameterInt("classifier.knn.k"));

    // "classifier.knn.rule" is only declared in regression mode (see
    // InitKNNParams), so it must only be queried in that mode.
    if (this->m_RegressionFlag)
      {
      const std::string decision = this->GetParameterString("classifier.knn.rule");
      if (decision == "mean")
        {
        knnClassifier->SetDecisionRule(KNNType::KNN_MEAN);
        }
      else if (decision == "median")
        {
        knnClassifier->SetDecisionRule(KNNType::KNN_MEDIAN);
        }
      // Any other value leaves the classifier's decision rule untouched.
      }

    knnClassifier->Train();
    knnClassifier->Save(modelPath);
  }
} //end namespace wrapper
} //end namespace otb

#endif