diff --git a/Modules/Applications/AppClassification/include/otbLearningApplicationBase.h b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.h index 6fd04d059d43c7be3450bc51fcf4a17e457e2c1b..145aecb7382e8c3a740885e73376156830034502 100644 --- a/Modules/Applications/AppClassification/include/otbLearningApplicationBase.h +++ b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.h @@ -49,6 +49,7 @@ #ifdef OTB_USE_SHARK #include "otbSharkRandomForestsMachineLearningModel.h" +#include "otbSharkKMeansMachineLearningModel.h" #endif namespace otb @@ -139,6 +140,7 @@ public: #ifdef OTB_USE_SHARK typedef otb::SharkRandomForestsMachineLearningModel<InputValueType, OutputValueType> SharkRandomForestType; + typedef otb::SharkKMeansMachineLearningModel<InputValueType, OutputValueType> SharkKMeansType; #endif protected: @@ -221,6 +223,10 @@ private: void TrainSharkRandomForests(typename ListSampleType::Pointer trainingListSample, typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath); + void InitSharkKMeansParams(); + void TrainSharkKMeans(typename ListSampleType::Pointer trainingListSample, + typename TargetListSampleType::Pointer trainingLabeledListSample, + std::string modelPath); #endif //@} }; @@ -247,6 +253,7 @@ private: #endif #ifdef OTB_USE_SHARK #include "otbTrainSharkRandomForests.txx" +#include "otbTrainSharkKMeans.txx" #endif #endif diff --git a/Modules/Applications/AppClassification/include/otbLearningApplicationBase.txx b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.txx index bf110c82738e1223469bb2e0c486e6f911e5afd5..f3731e9495abbf638a28d6d7b516935ffc043eda 100644 --- a/Modules/Applications/AppClassification/include/otbLearningApplicationBase.txx +++ b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.txx @@ -76,6 +76,7 @@ LearningApplicationBase<TInputValue,TOutputValue> #ifdef OTB_USE_SHARK InitSharkRandomForestsParams(); + InitSharkKMeansParams(); #endif } @@ -147,6 +148,15 @@ LearningApplicationBase<TInputValue,TOutputValue> otbAppLogFATAL("Module SharkLearning is not installed. You should consider turning OTB_USE_SHARK on during cmake configuration."); #endif } + if(modelName == "sharkkm") + { + #ifdef OTB_USE_SHARK + TrainSharkKMeans( trainingListSample, trainingLabeledListSample, modelPath ); + #else + otbAppLogFATAL("Module SharkLearning is not installed. You should consider turning OTB_USE_SHARK on during cmake configuration."); + #endif + } + // OpenCV SVM implementation is buggy with linear kernel // Users should use the libSVM implementation instead. diff --git a/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.txx b/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.txx new file mode 100644 index 0000000000000000000000000000000000000000..bd96fe35684b517a5b8f7318938441c3727f5478 --- /dev/null +++ b/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.txx @@ -0,0 +1,71 @@ +/*========================================================================= + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + + =========================================================================*/ +#ifndef otbTrainSharkKMeans_txx +#define otbTrainSharkKMeans_txx + +#include "otbLearningApplicationBase.h" + +namespace otb +{ +namespace Wrapper +{ +template<class TInputValue, class TOutputValue> +void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams() +{ + AddChoice( "classifier.sharkkm", "Shark kmeans classifier" ); + SetParameterDescription( "classifier.sharkkm", + "This group of parameters allows setting Shark kMeans classifier parameters. " + "See complete documentation here " + "\\url{http://image.diku.dk/shark/sphinx_pages/build/html/rest_sources/tutorials/algorithms/kmeans.html}.\n " ); + //MaxNumberOfIterations + AddParameter( ParameterType_Int, "classifier.sharkkm.nbmaxiter", + "Maximum number of iteration for the kmeans algorithm." ); + SetParameterInt( "classifier.sharkkm.nbmaxiter", 0 ); + SetMinimumParameterIntValue( "classifier.sharkkm.nbmaxiter", 0 ); + SetParameterDescription( "classifier.sharkkm.nbmaxiter", + "The maximum number of iteration for the kmeans algorithm. Default set to unlimited." ); + + //MaxNumberOfIterations + AddParameter( ParameterType_Int, "classifier.sharkkm.k", "The number of class used for the kmeans algorithm." ); + SetParameterInt( "classifier.sharkkm.k", 2 ); + SetParameterDescription( "classifier.sharkkm.k", + "The number of class used for the kmeans algorithm. Default set to 2 class" ); + SetMinimumParameterIntValue( "classifier.sharkkm.k", 2 ); +} + +template<class TInputValue, class TOutputValue> +void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans( + typename ListSampleType::Pointer trainingListSample, + typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath) +{ + unsigned int nbMaxIter = static_cast<unsigned int>(abs( GetParameterInt( "classifier.sharkkm.nbmaxiter" ) )); + unsigned int k = static_cast<unsigned int>(abs( GetParameterInt( "classifier.sharkkm.k" ) )); + + typename SharkKMeansType::Pointer classifier = SharkKMeansType::New(); + classifier->SetRegressionMode( this->m_RegressionFlag ); + classifier->SetInputListSample( trainingListSample ); + classifier->SetTargetListSample( trainingLabeledListSample ); + classifier->SetK( k ); + classifier->SetMaximumNumberOfIterations( nbMaxIter ); + classifier->Train(); + classifier->Save( modelPath ); +} + +} //end namespace wrapper +} //end namespace otb + +#endif