...
 
......@@ -22,6 +22,7 @@
#include "otbLearningApplicationBase.h"
#include "otbSharkKMeansMachineLearningModel.h"
#include "otbStatisticsXMLFileReader.h"
namespace otb
{
......@@ -45,9 +46,13 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams()
SetParameterDescription("classifier.sharkkm.k", "The number of classes used for the kmeans algorithm. Default set to 2 class");
SetMinimumParameterIntValue("classifier.sharkkm.k", 2);
AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroidstats", "Statistics file");
SetParameterDescription("classifier.sharkkm.centroidstats", "A XML file containing mean and standard deviation to center"
"and reduce the centroids before classification, produced by ComputeImagesStatistics application.");
MandatoryOff("classifier.sharkkm.centroidstats");
// Number of classes
AddParameter(ParameterType_String, "classifier.sharkkm.incentroid", "Number of classes for the kmeans algorithm");
AddParameter(ParameterType_String, "classifier.sharkkm.centroids", "Number of classes for the kmeans algorithm");
SetParameterDescription("classifier.sharkkm.incentroid", "The number of classes used for the kmeans algorithm. Default set to 2 class");
MandatoryOff("classifier.sharkkm.incentroid");
}
......@@ -66,7 +71,41 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans(
classifier->SetInputListSample( trainingListSample );
classifier->SetTargetListSample( trainingLabeledListSample );
classifier->SetK( k );
classifier->SetCentroidFilename( GetParameterString( "classifier.sharkkm.incentroid") );
// Initialize centroids from file
if(HasValue("classifier.sharkkm.centroids"))
{
shark::Data<shark::RealVector> centroidData;
shark::importCSV(centroidData, GetParameterString( "classifier.sharkkm.centroidstats"), ' ');
if( HasValue( "classifier.sharkkm.centroids" ) )
{
auto statisticsReader = otb::StatisticsXMLFileReader< itk::VariableLengthVector<float> >::New();
statisticsReader->SetFileName(GetParameterString( "classifier.sharkkm.centroidstats" ));
auto meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
auto stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
// Convert itk Variable Length Vector to shark Real Vector
shark::RealVector meanMeasurementRV(meanMeasurementVector.Size());
for (unsigned int i = 0; i<meanMeasurementVector.Size(); ++i)
{
// Substract the mean
meanMeasurementRV[i] = - meanMeasurementVector[i];
}
shark::RealVector stddevMeasurementRV(stddevMeasurementVector.Size());
for (unsigned int i = 0; i<stddevMeasurementVector.Size(); ++i)
{
stddevMeasurementRV[i] = stddevMeasurementVector[i];
}
shark::Normalizer<> normalizer(stddevMeasurementRV, meanMeasurementRV);
centroidData = normalizer(centroidData);
}
classifier->SetCentroidsFromData( centroidData);
}
classifier->SetMaximumNumberOfIterations( nbMaxIter );
classifier->Train();
classifier->Save( modelPath );
......
......@@ -48,6 +48,7 @@
#include "shark/Models/Clustering/Centroids.h"
#include "shark/Models/Clustering/ClusteringModel.h"
#include "shark/Algorithms/KMeans.h"
#include "shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h"
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
......@@ -132,6 +133,13 @@ public:
itkGetMacro( CentroidFilename, std::string );
itkSetMacro( CentroidFilename, std::string );
/** Initialize the centroids for the kmeans algorithm.
 *
 * Seeds the underlying Shark centroid model directly from an in-memory
 * data set (one sample per centroid). The caller is responsible for any
 * normalization of the data before passing it in.
 *
 * \param data Shark data set whose elements become the initial centroids.
 */
void SetCentroidsFromData(const shark::Data<shark::RealVector> & data)
{
m_Centroids.setCentroids(data);
this->Modified();  // mark the itk object as changed so the pipeline re-runs
}
protected:
/** Constructor */
SharkKMeansMachineLearningModel();
......@@ -148,6 +156,9 @@ protected:
template<typename DataType>
DataType NormalizeData(const DataType &data) const;
template<typename DataType>
shark::Normalizer<> TrainNormalizer(const DataType &data) const;
/** PrintSelf method */
void PrintSelf(std::ostream &os, itk::Indent indent) const override;
......
......@@ -39,7 +39,6 @@
#include "shark/Algorithms/KMeans.h" //k-means algorithm
#include "shark/Models/Clustering/HardClusteringModel.h"
#include "shark/Models/Clustering/SoftClusteringModel.h"
#include "shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h"
#include <shark/Data/Csv.h> //load the csv file
#if defined(__GNUC__) || defined(__clang__)
......@@ -67,18 +66,6 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
{
}
/** Seed the KMeans centroids from the CSV file named by m_CentroidFilename.
 *
 * The file is expected to be whitespace-separated: one centroid per line,
 * one column per feature dimension.
 *
 * \return true on completion (shark::importCSV reports failures by throwing).
 */
template<class TInputValue, class TOutputValue>
bool
SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
::InitializeCentroids()
{
shark::Data<shark::RealVector> data;
shark::importCSV(data, m_CentroidFilename, ' ');
m_Centroids.setCentroids(data);
// Removed leftover std::cout debug dump of the centroids: library code
// must not write to stdout unconditionally.
return true;
}
/** Train the machine learning model */
template<class TInputValue, class TOutputValue>
void
......@@ -90,12 +77,12 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
otb::Shark::ListSampleToSharkVector( this->GetInputListSample(), vector_data );
shark::Data<shark::RealVector> data = shark::createDataFromRange( vector_data );
if (!m_CentroidFilename.empty())
InitializeCentroids();
// Normalized input value if necessary
if( m_Normalized )
data = NormalizeData( data );
{
auto normalizer = TrainNormalizer(data);
data = normalizer(data);
}
// Use a Hard Clustering Model for classification
shark::kMeans( data, m_K, m_Centroids, m_MaximumNumberOfIterations );
......@@ -114,6 +101,18 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
return normalizer( data );
}
/** Train a zero-mean, unit-variance normalizer on the given samples.
 *
 * \param data Samples used to estimate the per-component mean and variance.
 * \return A shark::Normalizer that centers and scales data accordingly.
 */
template<class TInputValue, class TOutputValue>
template<typename DataType>
shark::Normalizer<>
SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
::TrainNormalizer(const DataType &data) const
{
// 'true' requests mean removal in addition to unit-variance scaling.
shark::NormalizeComponentsUnitVariance<> trainer( true );
shark::Normalizer<> result;
trainer.train( result, data );
return result;
}
template<class TInputValue, class TOutputValue>
typename SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
::TargetSampleType
......