Commit 98a40235 authored by Cédric Traizet

Merge branch 'kmean_centroids' into 'develop'

KMeans input centroids

See merge request !470
parents b37d9f59 97d8b279
Pipeline #1261 passed with stage
in 10 minutes and 36 seconds
148.1360412249 176.9065574064 79.2367424483 275.6865470422
180.3646315623 255.4157568188 138.2565634726 656.5357728603
187.5074713392 256.7055784897 121.8671939978 115.8660938389
220.0887858502 326.8933399989 229.672560688 434.3589597278
515.191687488 834.8626368509 642.6102022528 814.8945435557
......@@ -77,13 +77,22 @@ protected:
MandatoryOff("ts");
AddParameter(ParameterType_Int, "maxit", "Maximum number of iterations");
SetParameterDescription("maxit", "Maximum number of iterations for the learning step.");
SetParameterDescription("maxit",
"Maximum number of iterations for the learning step."
" If this parameter is set to 0, the KMeans algorithm will not stop until convergence");
SetDefaultParameterInt("maxit", 1000);
MandatoryOff("maxit");
AddParameter(ParameterType_OutputFilename, "outmeans", "Centroid filename");
SetParameterDescription("outmeans", "Output text file containing centroid positions");
MandatoryOff("outmeans");
AddParameter(ParameterType_Group, "centroids", "Centroids IO parameters");
SetParameterDescription("centroids", "Group of parameters for centroids IO.");
AddParameter(ParameterType_InputFilename, "centroids.in", "input centroids text file");
SetParameterDescription("centroids.in",
"Input text file containing centroid positions used to initialize the algorithm. "
"Each centroid must be described by p parameters, p being the number of bands in "
"the input image, and the number of centroids must be equal to the number of classes "
"(one centroid per line with values separated by spaces).");
MandatoryOff("centroids.in");
ShareKMSamplingParameters();
ConnectKMSamplingParams();
......@@ -99,6 +108,7 @@ protected:
{
ShareParameter("ram", "polystats.ram");
ShareParameter("sampler", "select.sampler");
ShareParameter("centroids.out", "training.classifier.sharkkm.centroids.out");
ShareParameter("vm", "polystats.mask", "Validity Mask",
"Validity mask, only non-zero pixels will be used to estimate KMeans modes.");
}
......@@ -248,6 +258,14 @@ protected:
GetParameterInt("maxit"));
GetInternalApplication("training")->SetParameterInt("classifier.sharkkm.k",
GetParameterInt("nc"));
if (IsParameterEnabled("centroids.in") && HasValue("centroids.in"))
{
GetInternalApplication("training")->SetParameterString("classifier.sharkkm.centroids.in", GetParameterString("centroids.in"));
GetInternalApplication("training")
->SetParameterString("classifier.sharkkm.centroids.stats", GetInternalApplication("imgstats")->GetParameterString("out"));
}
if( IsParameterEnabled("rand"))
GetInternalApplication("training")->SetParameterInt("rand", GetParameterInt("rand"));
......@@ -276,55 +294,6 @@ protected:
ExecuteInternal( "classif" );
}
/**
 * Write the centroid positions found in a model file to the text file named
 * by the "outmeans" parameter. Does nothing when "outmeans" is not enabled.
 *
 * The centroids are read from the first line of the model file that starts
 * with "2 ". The last nbClasses * nbBands space-separated tokens of that line
 * are taken as the centroid values and written one centroid per line.
 *
 * \param image         Image whose number of components per pixel gives the
 *                      number of values per centroid.
 * \param modelFileName Path of the model file to read.
 * \param nbClasses     Number of centroids to export.
 *
 * Raises an ITK exception if the model file cannot be opened, if it holds no
 * centroid line, if that line has too few values, or if the output file
 * cannot be opened.
 */
void CreateOutMeansFile(FloatVectorImageType *image,
                        const std::string &modelFileName,
                        unsigned int nbClasses)
{
  if (!IsParameterEnabled("outmeans"))
  {
    return;
  }

  const unsigned int nbBands = image->GetNumberOfComponentsPerPixel();
  const std::size_t nbElements = static_cast<std::size_t>(nbClasses) * nbBands;

  std::ifstream infile(modelFileName);
  if (!infile)
  {
    itkExceptionMacro(<< "File: " << modelFileName << " couldn't be opened");
  }

  // Get the line with the centroids (starts with "2 ")
  std::string line, centroidLine;
  while (std::getline(infile, line))
  {
    if (line.size() > 2 && line[0] == '2' && line[1] == ' ')
    {
      centroidLine = line;
      break;
    }
  }
  if (centroidLine.empty())
  {
    itkExceptionMacro(<< "File: " << modelFileName << " does not contain a centroid line");
  }

  std::vector<std::string> centroidElm;
  boost::split(centroidElm, centroidLine, boost::is_any_of(" "));

  // Guard against a short line: without this check the offset below would
  // underflow and the reads would run outside the vector.
  if (centroidElm.size() < nbElements)
  {
    itkExceptionMacro(<< "File: " << modelFileName << " contains " << centroidElm.size()
                      << " values on its centroid line, at least " << nbElements << " are required");
  }

  // The leading tokens are not centroid positions; skip them with an offset.
  const std::size_t beginCentroid = centroidElm.size() - nbElements;

  std::ofstream outfile(GetParameterString("outmeans"));
  if (!outfile)
  {
    itkExceptionMacro(<< "File: " << GetParameterString("outmeans") << " couldn't be opened");
  }

  for (unsigned int i = 0; i < nbClasses; i++)
  {
    for (unsigned int j = 0; j < nbBands; j++)
    {
      outfile << std::setw(8) << centroidElm[beginCentroid + i * nbBands + j] << " ";
    }
    outfile << std::endl;
  }
}
class KMeansFileNamesHandler
{
public:
......@@ -495,9 +464,6 @@ private:
// Compute a classification of the input image according to a model file
Superclass::KMeansClassif();
// Create the output text file containing centroids positions
Superclass::CreateOutMeansFile(GetParameterImage("in"), fileNames.modelFile, GetParameterInt("nc"));
// Remove all tempory files
if( GetParameterInt( "cleanup" ) )
{
......
......@@ -122,7 +122,10 @@ LearningApplicationBase<TInputValue,TOutputValue>
::InitUnsupervisedClassifierParams()
{
#ifdef OTB_USE_SHARK
InitSharkKMeansParams();
if (!m_RegressionFlag)
{
InitSharkKMeansParams(); // Regression not supported
}
#endif
}
......
......@@ -22,6 +22,7 @@
#include "otbLearningApplicationBase.h"
#include "otbSharkKMeansMachineLearningModel.h"
#include "otbStatisticsXMLFileReader.h"
namespace otb
{
......@@ -44,6 +45,30 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams()
SetParameterInt("classifier.sharkkm.k", 2);
SetParameterDescription("classifier.sharkkm.k", "The number of classes used for the kmeans algorithm. Default set to 2 class");
SetMinimumParameterIntValue("classifier.sharkkm.k", 2);
// Centroid IO
AddParameter( ParameterType_Group, "classifier.sharkkm.centroids", "Centroids IO parameters" );
SetParameterDescription( "classifier.sharkkm.centroids", "Group of parameters for centroids IO." );
// Input centroids
AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids.in", "User definied input centroids");
SetParameterDescription("classifier.sharkkm.centroids.in", "Input text file containing centroid posistions used to initialize the algorithm. "
"Each centroid must be described by p parameters, p being the number of features in "
"the input vector data, and the number of centroids must be equal to the number of classes "
"(one centroid per line with values separated by spaces).");
MandatoryOff("classifier.sharkkm.centroids");
// Centroid statistics
AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids.stats", "Statistics file");
SetParameterDescription("classifier.sharkkm.centroids.stats", "A XML file containing mean and standard deviation to center"
"and reduce the centroids before the KMeans algorithm, produced by ComputeImagesStatistics application.");
MandatoryOff("classifier.sharkkm.centroids.stats");
// Output centroids
AddParameter(ParameterType_OutputFilename, "classifier.sharkkm.centroids.out", "Output centroids text file");
SetParameterDescription("classifier.sharkkm.centroids.out", "Output text file containing centroids after the kmean algorithm.");
MandatoryOff("classifier.sharkkm.centroids.out");
}
template<class TInputValue, class TOutputValue>
......@@ -60,9 +85,48 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans(
classifier->SetInputListSample( trainingListSample );
classifier->SetTargetListSample( trainingLabeledListSample );
classifier->SetK( k );
// Initialize centroids from file
if(IsParameterEnabled("classifier.sharkkm.centroids.in") && HasValue("classifier.sharkkm.centroids.in"))
{
shark::Data<shark::RealVector> centroidData;
shark::importCSV(centroidData, GetParameterString( "classifier.sharkkm.centroids.in"), ' ');
if( HasValue( "classifier.sharkkm.centroids.stats" ) )
{
auto statisticsReader = otb::StatisticsXMLFileReader< itk::VariableLengthVector<float> >::New();
statisticsReader->SetFileName(GetParameterString( "classifier.sharkkm.centroids.stats" ));
auto meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
auto stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
// Convert itk Variable Length Vector to shark Real Vector
shark::RealVector offsetRV(meanMeasurementVector.Size());
shark::RealVector scaleRV(stddevMeasurementVector.Size());
assert(meanMeasurementVector.Size()==stddevMeasurementVector.Size());
for (unsigned int i = 0; i<meanMeasurementVector.Size(); ++i)
{
scaleRV[i] = 1/stddevMeasurementVector[i];
// Substract the normalized mean
offsetRV[i] = - meanMeasurementVector[i]/stddevMeasurementVector[i];
}
shark::Normalizer<> normalizer(scaleRV, offsetRV);
centroidData = normalizer(centroidData);
}
if (centroidData.numberOfElements() != k)
otbAppLogWARNING( "The input centroid file will not be used because it contains " << centroidData.numberOfElements() <<
" points, which is different than from the requested number of class: " << k <<".");
classifier->SetCentroidsFromData( centroidData);
}
classifier->SetMaximumNumberOfIterations( nbMaxIter );
classifier->Train();
classifier->Save( modelPath );
if( HasValue( "classifier.sharkkm.centroids.out"))
classifier->ExportCentroids( GetParameterString( "classifier.sharkkm.centroids.out" ));
}
} //end namespace wrapper
......
......@@ -673,7 +673,7 @@ if(OTB_USE_SHARK)
-sampler periodic
-rand 121212
-nodatalabel 255
-outmeans ${TEMP}/apTvClKMeansImageClassificationFilterOutMeans.txt
-centroids.out ${TEMP}/apTvClKMeansImageClassificationFilterOutMeans.txt
-out ${TEMP}/apTvClKMeansImageClassificationFilterOutput.tif uint8
-cleanup 0
VALID --compare-image ${NOTOL}
......@@ -681,6 +681,25 @@ if(OTB_USE_SHARK)
${TEMP}/apTvClKMeansImageClassificationFilterOutput.tif )
endif()
if(OTB_USE_SHARK)
# Runs the KMeansClassification application with initial centroids read from a
# text file (-centroids.in), then compares the output image with the stored
# baseline image.
# NOTE(review): ${NOTOL} presumably means a zero-tolerance comparison - confirm
# against the test macros.
otb_test_application(NAME apTvClKMeansImageClassification_inputCentroids
APP KMeansClassification
OPTIONS -in ${INPUTDATA}/qb_RoadExtract.img
-ts 30000
-nc 5
-maxit 10000
-sampler periodic
-nodatalabel 255
-rand 121212
-centroids.in ${INPUTDATA}/Classification/KMeansInputCentroids.txt
-out ${TEMP}/apTvClKMeansImageClassificationInputCentroids.tif uint8
-cleanup 0
VALID --compare-image ${NOTOL}
${OTBAPP_BASELINE}/apTvClKMeansImageClassificationInputCentroids.tif
${TEMP}/apTvClKMeansImageClassificationInputCentroids.tif )
endif()
#----------- TrainImagesClassifier TESTS ----------------
if(OTB_USE_LIBSVM)
otb_test_application(NAME apTvClTrainSVMImagesClassifierQB1_allOpt_InXML
......
......@@ -48,6 +48,7 @@
#include "shark/Models/Clustering/Centroids.h"
#include "shark/Models/Clustering/ClusteringModel.h"
#include "shark/Algorithms/KMeans.h"
#include "shark/Models/Normalizer.h"
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
......@@ -124,9 +125,14 @@ public:
/** Set the number of class for the kMeans algorithm.*/
itkSetMacro( K, unsigned );
/** If true, normalized input data sample list */
itkGetMacro( Normalized, bool );
itkSetMacro( Normalized, bool );
/**
 * Set the centroids held by this model from a set of points.
 *
 * \param data Points copied into the internal shark::Centroids object.
 *
 * The object is flagged as modified afterwards.
 */
void SetCentroidsFromData(const shark::Data<shark::RealVector>& data)
{
  this->m_Centroids.setCentroids(data);
  this->Modified();
}
void ExportCentroids(const std::string& filename);
protected:
/** Constructor */
......@@ -142,9 +148,6 @@ protected:
virtual void DoPredictBatch(const InputListSampleType *, const unsigned int &startIndex, const unsigned int &size,
TargetListSampleType *, ConfidenceListSampleType * = nullptr, ProbaListSampleType * = nullptr) const override;
template<typename DataType>
DataType NormalizeData(const DataType &data) const;
/** PrintSelf method */
void PrintSelf(std::ostream &os, itk::Indent indent) const override;
......@@ -153,16 +156,13 @@ private:
void operator=(const Self &) = delete;
// Parameters set by the user
bool m_Normalized;
unsigned int m_K;
unsigned int m_MaximumNumberOfIterations;
bool m_CanRead;
/** Centroids results form kMeans */
shark::Centroids m_Centroids;
/** shark Model could be SoftClusteringModel or HardClusteringModel */
boost::shared_ptr<ClusteringModelType> m_ClusteringModel;
......
......@@ -35,11 +35,10 @@
#include "otb_shark.h"
#include "otbSharkUtils.h"
#include "shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h" //normalize
#include "shark/Algorithms/KMeans.h" //k-means algorithm
#include "shark/Models/Clustering/HardClusteringModel.h"
#include "shark/Models/Clustering/SoftClusteringModel.h"
#include "shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h"
#include <shark/Data/Csv.h> //load the csv file
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
......@@ -52,7 +51,7 @@ namespace otb
template<class TInputValue, class TOutputValue>
SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
::SharkKMeansMachineLearningModel() :
m_Normalized( false ), m_K(2), m_MaximumNumberOfIterations( 10 )
m_K(2), m_MaximumNumberOfIterations( 10 )
{
// Default set HardClusteringModel
this->m_ConfidenceIndex = true;
......@@ -77,27 +76,11 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
otb::Shark::ListSampleToSharkVector( this->GetInputListSample(), vector_data );
shark::Data<shark::RealVector> data = shark::createDataFromRange( vector_data );
// Normalized input value if necessary
if( m_Normalized )
data = NormalizeData( data );
// Use a Hard Clustering Model for classification
shark::kMeans( data, m_K, m_Centroids, m_MaximumNumberOfIterations );
m_ClusteringModel = boost::make_shared<ClusteringModelType>( &m_Centroids );
}
/**
 * Return a normalized copy of \a data.
 *
 * A shark::Normalizer is fitted on \a data with a
 * NormalizeComponentsUnitVariance trainer (constructed with the zero-mean
 * flag set) and then applied to the same data.
 */
template<class TInputValue, class TOutputValue>
template<typename DataType>
DataType
SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
::NormalizeData(const DataType &data) const
{
  const bool removeMean = true; // zero mean
  shark::NormalizeComponentsUnitVariance<> varianceTrainer( removeMean );

  shark::Normalizer<> fittedNormalizer;
  varianceTrainer.train( fittedNormalizer, data );

  return fittedNormalizer( data );
}
template<class TInputValue, class TOutputValue>
typename SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
::TargetSampleType
......@@ -258,6 +241,14 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
return true;
}
/**
 * Write the current centroids to \a filename through shark::exportCSV,
 * using a single space as the value separator.
 */
template<class TInputValue, class TOutputValue>
void
SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
::ExportCentroids(const std::string & filename)
{
  const char separator = ' ';
  auto & centroidPoints = m_Centroids.centroids();
  shark::exportCSV(centroidPoints, filename, separator);
}
template<class TInputValue, class TOutputValue>
void
SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment