Skip to content
Snippets Groups Projects
Commit 3167cee4 authored by Julien Malik's avatar Julien Malik
Browse files

ADD: KMeansAttributesLabelMapFilter

parent d33cc2c2
No related branches found
No related tags found
No related merge requests found
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbKMeansAttributesLabelMapFilter_h
#define __otbKMeansAttributesLabelMapFilter_h
#include "itkLabelMapFilter.h"
#include "itkSimpleDataObjectDecorator.h"
#include "otbLabelMapWithClassLabelToLabeledSampleListFilter.h"
#include "itkListSample.h"
#include "itkEuclideanDistance.h"
#include "itkWeightedCentroidKdTreeGenerator.h"
#include "itkKdTreeBasedKmeansEstimator.h"
namespace otb {
/** \class KMeansAttributesLabelMapFilter
* \brief Execute a KMeans on the attributes of a itk::LabelMap<otb::AttributesMapLabelObject>
*/
template<class TInputImage>
class ITK_EXPORT KMeansAttributesLabelMapFilter :
public itk::Object
{
public:
/** Standard class typedefs. */
typedef KMeansAttributesLabelMapFilter Self;
typedef itk::LabelMapFilter<TInputImage, TInputImage> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Some convenient typedefs. */
typedef TInputImage InputImageType;
typedef typename InputImageType::Pointer InputImagePointer;
typedef typename InputImageType::ConstPointer InputImageConstPointer;
typedef typename InputImageType::RegionType InputImageRegionType;
typedef typename InputImageType::PixelType InputImagePixelType;
typedef typename InputImageType::LabelObjectType LabelObjectType;
typedef itk::DataObject DataObjectType;
typedef DataObjectType::Pointer DataObjectPointerType;
// LabelObject attributes
typedef typename LabelObjectType::AttributesValueType AttributesValueType;
typedef typename LabelObjectType::AttributesMapType AttributesMapType;
typedef itk::SimpleDataObjectDecorator<AttributesMapType> AttributesMapObjectType;
typedef typename InputImageType::LabelObjectContainerType LabelObjectContainerType;
typedef typename LabelObjectContainerType::const_iterator LabelObjectContainerConstIterator;
typedef typename LabelObjectType::ClassLabelType ClassLabelType;
// LabelMapToSampleList
typedef itk::VariableLengthVector<AttributesValueType> VectorType;
typedef itk::FixedArray<ClassLabelType,1> ClassLabelVectorType;
typedef itk::Statistics::ListSample<VectorType> ListSampleType;
typedef itk::Statistics::ListSample<ClassLabelVectorType> TrainingListSampleType;
typedef otb::LabelMapWithClassLabelToLabeledSampleListFilter<
InputImageType,
ListSampleType,
TrainingListSampleType> LabelMapToSampleListFilterType;
typedef typename LabelMapToSampleListFilterType::MeasurementFunctorType MeasurementFunctorType;
// KMeans
typedef itk::Statistics::WeightedCentroidKdTreeGenerator<ListSampleType> TreeGeneratorType;
typedef typename TreeGeneratorType::KdTreeType TreeType;
typedef itk::Statistics::KdTreeBasedKmeansEstimator<TreeType> EstimatorType;
typedef itk::Statistics::EuclideanDistance<VectorType> DistanceType;
typedef std::vector<VectorType> CentroidsVectorType;
typedef itk::SimpleDataObjectDecorator<CentroidsVectorType> CentroidsVectorObjectType;
/** ImageDimension constants */
itkStaticConstMacro(InputImageDimension, unsigned int,
TInputImage::ImageDimension);
/** Standard New method. */
itkNewMacro(Self);
/** Runtime information support. */
itkTypeMacro(KMeansAttributesLabelMapFilter,
LabelMapFilter);
/** Return the centroids resulting from the KMeans */
CentroidsVectorType& GetCentroids()
{
return m_Centroids;
}
const CentroidsVectorType& GetCentroids() const
{
return m_Centroids;
}
itkSetObjectMacro(InputLabelMap,InputImageType);
itkGetObjectMacro(InputLabelMap,InputImageType);
/** Set the number of classes of the input sample list.
* It will be used to choose the number of centroids.
* In the one-class case, 10 centroids is chosen. Otherwise,
* a number of centroids equal to the number of classes */
itkSetMacro(NumberOfClasses, unsigned int);
itkGetMacro(NumberOfClasses, unsigned int);
MeasurementFunctorType& GetMeasurementFunctor()
{
return m_LabelMapToSampleListFilter->GetMeasurementFunctor();
}
void SetMeasurementFunctor(MeasurementFunctorType& functor)
{
m_LabelMapToSampleListFilter->SetMeasurementFunctor(functor);
}
void Compute();
protected:
KMeansAttributesLabelMapFilter();
~KMeansAttributesLabelMapFilter() {};
private:
KMeansAttributesLabelMapFilter(const Self&); //purposely not implemented
void operator=(const Self&); //purposely not implemented
InputImagePointer m_InputLabelMap;
CentroidsVectorType m_Centroids;
typename LabelMapToSampleListFilterType::Pointer m_LabelMapToSampleListFilter;
unsigned int m_NumberOfClasses;
}; // end of class
} // end namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbKMeansAttributesLabelMapFilter.txx"
#endif
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbKMeansAttributesLabelMapFilter_txx
#define __otbKMeansAttributesLabelMapFilter_txx
#include "otbKMeansAttributesLabelMapFilter.h"
#include "itkNumericTraits.h"
#include "itkMersenneTwisterRandomVariateGenerator.h"
namespace otb {
template <class TInputImage>
KMeansAttributesLabelMapFilter<TInputImage>
::KMeansAttributesLabelMapFilter()
: m_LabelMapToSampleListFilter(LabelMapToSampleListFilterType::New()),
m_NumberOfClasses(1)
{
}
template<class TInputImage>
void
KMeansAttributesLabelMapFilter<TInputImage>
::Compute()
{
m_LabelMapToSampleListFilter->SetInputLabelMap(m_InputLabelMap);
m_LabelMapToSampleListFilter->Update();
typename ListSampleType::Pointer listSamples = m_LabelMapToSampleListFilter->GetOutputSampleList();
typename TrainingListSampleType::Pointer trainingSamples = m_LabelMapToSampleListFilter->GetOutputTrainingSampleList();
// Build the Kd Tree
typename TreeGeneratorType::Pointer kdTreeGenerator = TreeGeneratorType::New();
kdTreeGenerator->SetSample(listSamples);
kdTreeGenerator->SetBucketSize(100);
kdTreeGenerator->Update();
// Randomly pick the initial means among the classes
unsigned int sampleSize = listSamples->GetMeasurementVector(0).Size();
const unsigned int OneClassNbCentroids = 10;
unsigned int numberOfCentroids = (m_NumberOfClasses == 1 ? OneClassNbCentroids : m_NumberOfClasses);
typename EstimatorType::ParametersType initialMeans(sampleSize * m_NumberOfClasses);
initialMeans.Fill(0.);
if (m_NumberOfClasses > 1)
{
// For each class, choose a centroid as the first sample of this class encountered
for (ClassLabelType classLabel = 0; classLabel < m_NumberOfClasses; ++classLabel)
{
typename TrainingListSampleType::ConstIterator it;
// Iterate on the label list and stop when classLabel is found
// TODO: add random initialization ?
for (it = trainingSamples->Begin(); it != trainingSamples->End(); ++it)
{
if (it.GetMeasurementVector()[0] == classLabel)
break;
}
if (it == trainingSamples->End())
{
itkExceptionMacro(<<"Unable to find a sample with class label "<< classLabel);
}
typename ListSampleType::InstanceIdentifier identifier = it.GetInstanceIdentifier();
const typename ListSampleType::MeasurementVectorType& centroid = listSamples->GetMeasurementVector(identifier);
for (unsigned int i = 0; i < centroid.Size(); ++i)
{
initialMeans[classLabel * sampleSize + i] = centroid[i];
}
}
}
else
{
typedef itk::Statistics::MersenneTwisterRandomVariateGenerator RandomGeneratorType;
RandomGeneratorType::Pointer randomGenerator = RandomGeneratorType::New();
unsigned int nbLabelObjects = listSamples->Size();
// Choose arbitrarily OneClassNbCentroids centroids among all available LabelObject
for (unsigned int centroidId = 0; centroidId < numberOfCentroids; ++centroidId)
{
typename ListSampleType::InstanceIdentifier identifier = randomGenerator->GetIntegerVariate(nbLabelObjects - 1);
const typename ListSampleType::MeasurementVectorType& centroid = listSamples->GetMeasurementVector(identifier);
for (unsigned int i = 0; i < centroid.Size(); ++i)
{
initialMeans[centroidId * sampleSize + i] = centroid[i];
}
}
}
// Run the KMeans algorithm
// Do KMeans estimation
typename EstimatorType::Pointer estimator = EstimatorType::New();
estimator->SetParameters(initialMeans);
estimator->SetKdTree(kdTreeGenerator->GetOutput());
estimator->SetMaximumIteration(10000);
estimator->SetCentroidPositionChangesThreshold(0.00001);
estimator->StartOptimization();
// Retrieve final centroids
m_Centroids.clear();
for(unsigned int cId = 0; cId < numberOfCentroids; ++cId)
{
VectorType newCenter(sampleSize);
for(unsigned int i = 0; i < sampleSize; ++i)
{
newCenter[i] = estimator->GetParameters()[cId * sampleSize + i];
}
m_Centroids.push_back(newCenter);
}
}
}// end namespace otb
#endif
......@@ -42,7 +42,7 @@ ADD_TEST(obTvImageToLabelMapWithAttributesFilter ${OBIA_TESTS1}
${INPUTDATA}/cala_labelled.tif)
ADD_TEST(obTuLabelMapSourceNew ${OBIA_TESTS1}
otbLabelMapSourceNew)
otbLabelMapSourceNew)
ADD_TEST(obTvLabelMapToVectorDataFilter ${OBIA_TESTS1}
otbLabelMapToVectorDataFilter
......@@ -54,7 +54,6 @@ ADD_TEST(obTvVectorDataToLabelMapFilter ${OBIA_TESTS1}
${INPUTDATA}/vectorIOexample_gis_to_vec.shp
${TEMP}/vectordataToLabelMap.png)
ADD_TEST(obTuLabelMapToSampleListFilterNew ${OBIA_TESTS1}
otbLabelMapToSampleListFilterNew)
......@@ -114,6 +113,14 @@ ADD_TEST(obTvNormalizeAttributesLabelMapFilter ${OBIA_TESTS1}
${INPUTDATA}/cala_labelled.tif
${TEMP}/obTvNormalizeAttributesLabelMapFilter.txt)
ADD_TEST(obTuKMeansAttributesLabelMapFilterNew ${OBIA_TESTS1}
otbKMeansAttributesLabelMapFilterNew)
ADD_TEST(obTvKMeansAttributesLabelMapFilter ${OBIA_TESTS1}
otbKMeansAttributesLabelMapFilter
${INPUTDATA}/calanques.tif
${INPUTDATA}/cala_labelled.tif
${TEMP}/obTvKMeansAttributesLabelMapFilter.txt)
ADD_TEST(obTuRadiometricAttributesLabelMapFilterNew ${OBIA_TESTS1}
otbRadiometricAttributesLabelMapFilterNew
......@@ -170,6 +177,7 @@ otbLabelObjectMapVectorizer.cxx
otbLabelObjectToPolygonFunctorNew.cxx
otbMinMaxAttributesLabelMapFilter.cxx
otbNormalizeAttributesLabelMapFilter.cxx
otbKMeansAttributesLabelMapFilter.cxx
otbRadiometricAttributesLabelMapFilterNew.cxx
otbShapeAttributesLabelMapFilterNew.cxx
otbStatisticsAttributesLabelMapFilterNew.cxx
......
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbImageFileReader.h"
#include <fstream>
#include <iostream>
#include "otbImage.h"
#include "otbVectorImage.h"
#include "otbAttributesMapLabelObjectWithClassLabel.h"
#include "itkLabelMap.h"
#include "itkLabelImageToLabelMapFilter.h"
#include "otbImageToLabelMapWithAttributesFilter.h"
#include "otbKMeansAttributesLabelMapFilter.h"
const unsigned int Dimension = 2;
typedef unsigned short LabelType;
typedef double PixelType;
typedef otb::AttributesMapLabelObjectWithClassLabel<LabelType, Dimension, double, LabelType>
LabelObjectType;
typedef itk::LabelMap<LabelObjectType> LabelMapType;
typedef otb::VectorImage<PixelType, Dimension> VectorImageType;
typedef otb::Image<unsigned int,2> LabeledImageType;
typedef otb::ImageFileReader<VectorImageType> ReaderType;
typedef otb::ImageFileReader<LabeledImageType> LabeledReaderType;
typedef itk::LabelImageToLabelMapFilter<LabeledImageType,LabelMapType> LabelMapFilterType;
typedef otb::ShapeAttributesLabelMapFilter<LabelMapType> ShapeFilterType;
typedef otb::KMeansAttributesLabelMapFilter<LabelMapType> KMeansAttributesLabelMapFilterType;
int otbKMeansAttributesLabelMapFilterNew(int argc, char * argv[])
{
KMeansAttributesLabelMapFilterType::Pointer radiometricLabelMapFilter = KMeansAttributesLabelMapFilterType::New();
return EXIT_SUCCESS;
}
int otbKMeansAttributesLabelMapFilter(int argc, char * argv[])
{
const char * infname = argv[1];
const char * lfname = argv[2];
const char * outfname = argv[3];
// SmartPointer instanciation
ReaderType::Pointer reader = ReaderType::New();
LabeledReaderType::Pointer labeledReader = LabeledReaderType::New();
LabelMapFilterType::Pointer filter = LabelMapFilterType::New();
ShapeFilterType::Pointer shapeFilter = ShapeFilterType::New();
KMeansAttributesLabelMapFilterType::Pointer kmeansLabelMapFilter = KMeansAttributesLabelMapFilterType::New();
// Inputs
reader->SetFileName(infname);
reader->UpdateOutputInformation();
labeledReader->SetFileName(lfname);
labeledReader->UpdateOutputInformation();
// Filter
filter->SetInput(labeledReader->GetOutput());
filter->SetBackgroundValue(itk::NumericTraits<LabelType>::max());
shapeFilter->SetInput(filter->GetOutput());
shapeFilter->Update();
// Labelize the objects
LabelMapType::Pointer labelMap = shapeFilter->GetOutput();
LabelMapType::LabelObjectContainerType& container = labelMap->GetLabelObjectContainer();
LabelMapType::LabelObjectContainerType::const_iterator loIt = container.begin();
unsigned int labelObjectID = 0;
for(loIt = container.begin(); loIt != container.end(); ++loIt )
{
unsigned int classLabel = labelObjectID % 3;
loIt->second->SetClassLabel(classLabel);
++labelObjectID;
}
kmeansLabelMapFilter->SetInputLabelMap(labelMap);
std::vector<std::string> attributes = labelMap->GetLabelObject(0)->GetAvailableAttributes();
std::vector<std::string>::const_iterator attrIt;
for (attrIt = attributes.begin(); attrIt != attributes.end(); ++attrIt)
{
kmeansLabelMapFilter->GetMeasurementFunctor().AddAttribute((*attrIt).c_str());
}
kmeansLabelMapFilter->SetNumberOfClasses(3);
kmeansLabelMapFilter->Compute();
std::ofstream outfile(outfname);
const KMeansAttributesLabelMapFilterType::CentroidsVectorType& centroids = kmeansLabelMapFilter->GetCentroids();
for (unsigned int i = 0; i < centroids.size(); ++i)
{
outfile << "Centroid " << i << " : " << centroids[i] << std::endl;
}
outfile.close();
return EXIT_SUCCESS;
}
......@@ -45,6 +45,8 @@ REGISTER_TEST(otbMinMaxAttributesLabelMapFilterNew);
REGISTER_TEST(otbMinMaxAttributesLabelMapFilter);
REGISTER_TEST(otbNormalizeAttributesLabelMapFilterNew);
REGISTER_TEST(otbNormalizeAttributesLabelMapFilter);
REGISTER_TEST(otbKMeansAttributesLabelMapFilterNew);
REGISTER_TEST(otbKMeansAttributesLabelMapFilter);
REGISTER_TEST(otbRadiometricAttributesLabelMapFilterNew);
REGISTER_TEST(otbShapeAttributesLabelMapFilterNew);
REGISTER_TEST(otbStatisticsAttributesLabelMapFilterNew);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment