Skip to content
Snippets Groups Projects
Commit 4cfa44ef authored by Ludovic Hussonnois's avatar Ludovic Hussonnois
Browse files

ENH: Add Shark KMeans Model into the learning application.

parent ee38c6aa
No related branches found
No related tags found
No related merge requests found
......@@ -49,6 +49,7 @@
#ifdef OTB_USE_SHARK
#include "otbSharkRandomForestsMachineLearningModel.h"
#include "otbSharkKMeansMachineLearningModel.h"
#endif
namespace otb
......@@ -139,6 +140,7 @@ public:
#ifdef OTB_USE_SHARK
typedef otb::SharkRandomForestsMachineLearningModel<InputValueType, OutputValueType> SharkRandomForestType;
typedef otb::SharkKMeansMachineLearningModel<InputValueType, OutputValueType> SharkKMeansType;
#endif
protected:
......@@ -221,6 +223,10 @@ private:
void TrainSharkRandomForests(typename ListSampleType::Pointer trainingListSample,
typename TargetListSampleType::Pointer trainingLabeledListSample,
std::string modelPath);
void InitSharkKMeansParams();
void TrainSharkKMeans(typename ListSampleType::Pointer trainingListSample,
typename TargetListSampleType::Pointer trainingLabeledListSample,
std::string modelPath);
#endif
//@}
};
......@@ -247,6 +253,7 @@ private:
#endif
#ifdef OTB_USE_SHARK
#include "otbTrainSharkRandomForests.txx"
#include "otbTrainSharkKMeans.txx"
#endif
#endif
......
......@@ -76,6 +76,7 @@ LearningApplicationBase<TInputValue,TOutputValue>
#ifdef OTB_USE_SHARK
InitSharkRandomForestsParams();
InitSharkKMeansParams();
#endif
}
......@@ -147,6 +148,15 @@ LearningApplicationBase<TInputValue,TOutputValue>
otbAppLogFATAL("Module SharkLearning is not installed. You should consider turning OTB_USE_SHARK on during cmake configuration.");
#endif
}
if(modelName == "sharkkm")
{
#ifdef OTB_USE_SHARK
TrainSharkKMeans( trainingListSample, trainingLabeledListSample, modelPath );
#else
otbAppLogFATAL("Module SharkLearning is not installed. You should consider turning OTB_USE_SHARK on during cmake configuration.");
#endif
}
// OpenCV SVM implementation is buggy with linear kernel
// Users should use the libSVM implementation instead.
......
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef otbTrainSharkKMeans_txx
#define otbTrainSharkKMeans_txx
#include "otbLearningApplicationBase.h"
namespace otb
{
namespace Wrapper
{
template<class TInputValue, class TOutputValue>
void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams()
{
AddChoice( "classifier.sharkkm", "Shark kmeans classifier" );
SetParameterDescription( "classifier.sharkkm",
"This group of parameters allows setting Shark kMeans classifier parameters. "
"See complete documentation here "
"\\url{http://image.diku.dk/shark/sphinx_pages/build/html/rest_sources/tutorials/algorithms/kmeans.html}.\n " );
//MaxNumberOfIterations
AddParameter( ParameterType_Int, "classifier.sharkkm.nbmaxiter",
"Maximum number of iteration for the kmeans algorithm." );
SetParameterInt( "classifier.sharkkm.nbmaxiter", 0 );
SetMinimumParameterIntValue( "classifier.sharkkm.nbmaxiter", 0 );
SetParameterDescription( "classifier.sharkkm.nbmaxiter",
"The maximum number of iteration for the kmeans algorithm. Default set to unlimited." );
//MaxNumberOfIterations
AddParameter( ParameterType_Int, "classifier.sharkkm.k", "The number of class used for the kmeans algorithm." );
SetParameterInt( "classifier.sharkkm.k", 2 );
SetParameterDescription( "classifier.sharkkm.k",
"The number of class used for the kmeans algorithm. Default set to 2 class" );
SetMinimumParameterIntValue( "classifier.sharkkm.k", 2 );
}
template<class TInputValue, class TOutputValue>
void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans(
typename ListSampleType::Pointer trainingListSample,
typename TargetListSampleType::Pointer trainingLabeledListSample, std::string modelPath)
{
unsigned int nbMaxIter = static_cast<unsigned int>(abs( GetParameterInt( "classifier.sharkkm.nbmaxiter" ) ));
unsigned int k = static_cast<unsigned int>(abs( GetParameterInt( "classifier.sharkkm.k" ) ));
typename SharkKMeansType::Pointer classifier = SharkKMeansType::New();
classifier->SetRegressionMode( this->m_RegressionFlag );
classifier->SetInputListSample( trainingListSample );
classifier->SetTargetListSample( trainingLabeledListSample );
classifier->SetK( k );
classifier->SetMaximumNumberOfIterations( nbMaxIter );
classifier->Train();
classifier->Save( modelPath );
}
} //end namespace wrapper
} //end namespace otb
#endif
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment