Commit 9d1ef865 authored by Arnaud Jaen's avatar Arnaud Jaen

ENH: Add KNN classification in the TrainMachineLearningImagesClassifier application.

parent 7c5adaf9
......@@ -38,7 +38,7 @@ IF(OTB_USE_OPENCV)
OTB_CREATE_APPLICATION(NAME TrainMachineLearningImagesClassifier
SOURCES otbTrainMachineLearningImagesClassifier.cxx otbTrainSVM.cxx otbTrainLibSVM.cxx otbTrainBoost.cxx
otbTrainDecisionTree.cxx otbTrainGradientBoostedTree.cxx otbTrainNeuralNetwork.cxx otbTrainNormalBayes.cxx
otbTrainRandomForests.cxx
otbTrainRandomForests.cxx otbTrainKNN.cxx
LINK_LIBRARIES OTBIO;OTBCommon;OTBBasicFilters;OTBFeatureExtraction;OTBLearning;OTBMachineLearning)
OTB_CREATE_APPLICATION(NAME ImageClassifier
......
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbTrainMachineLearningImagesClassifier.h"
namespace otb
{
namespace Wrapper
{
void TrainMachineLearningImagesClassifier::InitKNNParams()
{
AddChoice("classifier.knn", "KNN classifier");
SetParameterDescription("classifier.knn", "This group of parameters allows to set KNN classifier parameters."
"See complete documentation here http://docs.opencv.org/modules/ml/doc/k_nearest_neighbors.html");
//K parameter
AddParameter(ParameterType_Int, "classifier.knn.k", "Number of Neighbors");
SetParameterInt("classifier.knn.k", 32);
SetParameterDescription("classifier.knn.k","The number of neighbors to used.");
}
void TrainMachineLearningImagesClassifier::TrainKNN(ListSampleType::Pointer trainingListSample, LabelListSampleType::Pointer trainingLabeledListSample)
{
KNNType::Pointer knnClassifier = KNNType::New();
knnClassifier->SetInputListSample(trainingListSample);
knnClassifier->SetTargetListSample(trainingLabeledListSample);
knnClassifier->SetK(GetParameterInt("classifier.knn.k"));
knnClassifier->Train();
knnClassifier->Save(GetParameterString("io.out"));
}
} //end namespace wrapper
} //end namespace otb
......@@ -404,10 +404,10 @@ namespace Wrapper
{
TrainRandomForests(trainingListSample, trainingLabeledListSample);
}
/*else if (classifierType == "knn")
else if (classifierType == "knn")
{
TrainKNN(trainingListSample, trainingLabeledListSample);
} */
}
//--------------------------
// Performances estimation
......
......@@ -155,7 +155,7 @@ private:
void InitNeuralNetworkParams();
void InitNormalBayesParams();
void InitRandomForestsParams();
void InitKNNParams(){}
void InitKNNParams();
void TrainLibSVM(ListSampleType::Pointer trainingListSample, LabelListSampleType::Pointer trainingLabeledListSample);
void TrainBoost(ListSampleType::Pointer trainingListSample, LabelListSampleType::Pointer trainingLabeledListSample);
......@@ -165,7 +165,7 @@ private:
void TrainNeuralNetwork(ListSampleType::Pointer trainingListSample, LabelListSampleType::Pointer trainingLabeledListSample);
void TrainNormalBayes(ListSampleType::Pointer trainingListSample, LabelListSampleType::Pointer trainingLabeledListSample);
void TrainRandomForests(ListSampleType::Pointer trainingListSample, LabelListSampleType::Pointer trainingLabeledListSample);
void TrainKNN(ListSampleType::Pointer trainingListSample, LabelListSampleType::Pointer trainingLabeledListSample){}
void TrainKNN(ListSampleType::Pointer trainingListSample, LabelListSampleType::Pointer trainingLabeledListSample);
void Classify(ListSampleType::Pointer validationListSample, LabelListSampleType::Pointer predictedList);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment