diff --git a/Code/OBIA/otbKMeansAttributesLabelMapFilter.h b/Code/OBIA/otbKMeansAttributesLabelMapFilter.h new file mode 100644 index 0000000000000000000000000000000000000000..4ea57c5a6283e46de97e4d4d75238738d25c9512 --- /dev/null +++ b/Code/OBIA/otbKMeansAttributesLabelMapFilter.h @@ -0,0 +1,151 @@ +/*========================================================================= + +Program: ORFEO Toolbox +Language: C++ +Date: $Date$ +Version: $Revision$ + + +Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. +See OTBCopyright.txt for details. + + +This software is distributed WITHOUT ANY WARRANTY; without even +the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR +PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef __otbKMeansAttributesLabelMapFilter_h +#define __otbKMeansAttributesLabelMapFilter_h + +#include "itkLabelMapFilter.h" +#include "itkSimpleDataObjectDecorator.h" +#include "otbLabelMapWithClassLabelToLabeledSampleListFilter.h" +#include "itkListSample.h" +#include "itkEuclideanDistance.h" +#include "itkWeightedCentroidKdTreeGenerator.h" +#include "itkKdTreeBasedKmeansEstimator.h" + +namespace otb { + +/** \class KMeansAttributesLabelMapFilter + * \brief Execute a KMeans on the attributes of a itk::LabelMap<otb::AttributesMapLabelObject> + */ +template<class TInputImage> +class ITK_EXPORT KMeansAttributesLabelMapFilter : + public itk::Object +{ +public: + /** Standard class typedefs. */ + typedef KMeansAttributesLabelMapFilter Self; + typedef itk::LabelMapFilter<TInputImage, TInputImage> Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + + /** Some convenient typedefs. */ + typedef TInputImage InputImageType; + typedef typename InputImageType::Pointer InputImagePointer; + typedef typename InputImageType::ConstPointer InputImageConstPointer; + typedef typename InputImageType::RegionType InputImageRegionType; + typedef typename InputImageType::PixelType InputImagePixelType; + typedef typename InputImageType::LabelObjectType LabelObjectType; + typedef itk::DataObject DataObjectType; + typedef DataObjectType::Pointer DataObjectPointerType; + + // LabelObject attributes + typedef typename LabelObjectType::AttributesValueType AttributesValueType; + typedef typename LabelObjectType::AttributesMapType AttributesMapType; + typedef itk::SimpleDataObjectDecorator<AttributesMapType> AttributesMapObjectType; + typedef typename InputImageType::LabelObjectContainerType LabelObjectContainerType; + typedef typename LabelObjectContainerType::const_iterator LabelObjectContainerConstIterator; + typedef typename LabelObjectType::ClassLabelType ClassLabelType; + + // LabelMapToSampleList + typedef itk::VariableLengthVector<AttributesValueType> VectorType; + typedef itk::FixedArray<ClassLabelType,1> ClassLabelVectorType; + + typedef itk::Statistics::ListSample<VectorType> ListSampleType; + typedef itk::Statistics::ListSample<ClassLabelVectorType> TrainingListSampleType; + typedef otb::LabelMapWithClassLabelToLabeledSampleListFilter< + InputImageType, + ListSampleType, + TrainingListSampleType> LabelMapToSampleListFilterType; + typedef typename LabelMapToSampleListFilterType::MeasurementFunctorType MeasurementFunctorType; + + // KMeans + typedef itk::Statistics::WeightedCentroidKdTreeGenerator<ListSampleType> TreeGeneratorType; + typedef typename TreeGeneratorType::KdTreeType TreeType; + typedef itk::Statistics::KdTreeBasedKmeansEstimator<TreeType> EstimatorType; + typedef itk::Statistics::EuclideanDistance<VectorType> DistanceType; + typedef std::vector<VectorType> CentroidsVectorType; + typedef itk::SimpleDataObjectDecorator<CentroidsVectorType> CentroidsVectorObjectType; + + /** ImageDimension constants */ + itkStaticConstMacro(InputImageDimension, unsigned int, + TInputImage::ImageDimension); + + /** Standard New method. */ + itkNewMacro(Self); + + /** Runtime information support. */ + itkTypeMacro(KMeansAttributesLabelMapFilter, + LabelMapFilter); + + /** Return the centroids resulting from the KMeans */ + CentroidsVectorType& GetCentroids() + { + return m_Centroids; + } + const CentroidsVectorType& GetCentroids() const + { + return m_Centroids; + } + + itkSetObjectMacro(InputLabelMap,InputImageType); + itkGetObjectMacro(InputLabelMap,InputImageType); + + /** Set the number of classes of the input sample list. + * It will be used to choose the number of centroids. + * In the one-class case, 10 centroids is chosen. Otherwise, + * a number of centroids equal to the number of classes */ + itkSetMacro(NumberOfClasses, unsigned int); + itkGetMacro(NumberOfClasses, unsigned int); + + MeasurementFunctorType& GetMeasurementFunctor() + { + return m_LabelMapToSampleListFilter->GetMeasurementFunctor(); + } + + void SetMeasurementFunctor(MeasurementFunctorType& functor) + { + m_LabelMapToSampleListFilter->SetMeasurementFunctor(functor); + } + + void Compute(); + +protected: + KMeansAttributesLabelMapFilter(); + ~KMeansAttributesLabelMapFilter() {}; + + +private: + KMeansAttributesLabelMapFilter(const Self&); //purposely not implemented + void operator=(const Self&); //purposely not implemented + + InputImagePointer m_InputLabelMap; + CentroidsVectorType m_Centroids; + + typename LabelMapToSampleListFilterType::Pointer m_LabelMapToSampleListFilter; + unsigned int m_NumberOfClasses; + +}; // end of class + +} // end namespace otb + +#ifndef OTB_MANUAL_INSTANTIATION +#include "otbKMeansAttributesLabelMapFilter.txx" +#endif + +#endif + + diff --git a/Code/OBIA/otbKMeansAttributesLabelMapFilter.txx b/Code/OBIA/otbKMeansAttributesLabelMapFilter.txx new file mode 100644 index 0000000000000000000000000000000000000000..b94b82dc5c6a466da8a9144079bd0d401be5e60a --- /dev/null +++ b/Code/OBIA/otbKMeansAttributesLabelMapFilter.txx @@ -0,0 +1,126 @@ +/*========================================================================= + +Program: ORFEO Toolbox +Language: C++ +Date: $Date$ +Version: $Revision$ + + +Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. +See OTBCopyright.txt for details. + + +This software is distributed WITHOUT ANY WARRANTY; without even +the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR +PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef __otbKMeansAttributesLabelMapFilter_txx +#define __otbKMeansAttributesLabelMapFilter_txx + +#include "otbKMeansAttributesLabelMapFilter.h" +#include "itkNumericTraits.h" +#include "itkMersenneTwisterRandomVariateGenerator.h" + +namespace otb { + +template <class TInputImage> +KMeansAttributesLabelMapFilter<TInputImage> +::KMeansAttributesLabelMapFilter() + : m_LabelMapToSampleListFilter(LabelMapToSampleListFilterType::New()), + m_NumberOfClasses(1) +{ +} + +template<class TInputImage> +void +KMeansAttributesLabelMapFilter<TInputImage> +::Compute() +{ + m_LabelMapToSampleListFilter->SetInputLabelMap(m_InputLabelMap); + m_LabelMapToSampleListFilter->Update(); + + typename ListSampleType::Pointer listSamples = m_LabelMapToSampleListFilter->GetOutputSampleList(); + typename TrainingListSampleType::Pointer trainingSamples = m_LabelMapToSampleListFilter->GetOutputTrainingSampleList(); + + // Build the Kd Tree + typename TreeGeneratorType::Pointer kdTreeGenerator = TreeGeneratorType::New(); + kdTreeGenerator->SetSample(listSamples); + kdTreeGenerator->SetBucketSize(100); + kdTreeGenerator->Update(); + // Randomly pick the initial means among the classes + unsigned int sampleSize = listSamples->GetMeasurementVector(0).Size(); + const unsigned int OneClassNbCentroids = 10; + unsigned int numberOfCentroids = (m_NumberOfClasses == 1 ? OneClassNbCentroids : m_NumberOfClasses); + typename EstimatorType::ParametersType initialMeans(sampleSize * m_NumberOfClasses); + initialMeans.Fill(0.); + + if (m_NumberOfClasses > 1) + { + // For each class, choose a centroid as the first sample of this class encountered + for (ClassLabelType classLabel = 0; classLabel < m_NumberOfClasses; ++classLabel) + { + typename TrainingListSampleType::ConstIterator it; + // Iterate on the label list and stop when classLabel is found + // TODO: add random initialization ? + for (it = trainingSamples->Begin(); it != trainingSamples->End(); ++it) + { + if (it.GetMeasurementVector()[0] == classLabel) + break; + } + if (it == trainingSamples->End()) + { + itkExceptionMacro(<<"Unable to find a sample with class label "<< classLabel); + } + + typename ListSampleType::InstanceIdentifier identifier = it.GetInstanceIdentifier(); + const typename ListSampleType::MeasurementVectorType& centroid = listSamples->GetMeasurementVector(identifier); + for (unsigned int i = 0; i < centroid.Size(); ++i) + { + initialMeans[classLabel * sampleSize + i] = centroid[i]; + } + } + } + else + { + typedef itk::Statistics::MersenneTwisterRandomVariateGenerator RandomGeneratorType; + RandomGeneratorType::Pointer randomGenerator = RandomGeneratorType::New(); + unsigned int nbLabelObjects = listSamples->Size(); + + // Choose arbitrarily OneClassNbCentroids centroids among all available LabelObject + for (unsigned int centroidId = 0; centroidId < numberOfCentroids; ++centroidId) + { + typename ListSampleType::InstanceIdentifier identifier = randomGenerator->GetIntegerVariate(nbLabelObjects - 1); + const typename ListSampleType::MeasurementVectorType& centroid = listSamples->GetMeasurementVector(identifier); + for (unsigned int i = 0; i < centroid.Size(); ++i) + { + initialMeans[centroidId * sampleSize + i] = centroid[i]; + } + } + } + + // Run the KMeans algorithm + // Do KMeans estimation + typename EstimatorType::Pointer estimator = EstimatorType::New(); + estimator->SetParameters(initialMeans); + estimator->SetKdTree(kdTreeGenerator->GetOutput()); + estimator->SetMaximumIteration(10000); + estimator->SetCentroidPositionChangesThreshold(0.00001); + estimator->StartOptimization(); + + // Retrieve final centroids + m_Centroids.clear(); + + for(unsigned int cId = 0; cId < numberOfCentroids; ++cId) + { + VectorType newCenter(sampleSize); + for(unsigned int i = 0; i < sampleSize; ++i) + { + newCenter[i] = estimator->GetParameters()[cId * sampleSize + i]; + } + m_Centroids.push_back(newCenter); + } +} + +}// end namespace otb +#endif diff --git a/Testing/Code/OBIA/CMakeLists.txt b/Testing/Code/OBIA/CMakeLists.txt index c4ecc21a0e9a2bd19c65ee94637ce1bd52a1fbfb..4920059a0711b5a5b3e907d21a92978ac1216f73 100644 --- a/Testing/Code/OBIA/CMakeLists.txt +++ b/Testing/Code/OBIA/CMakeLists.txt @@ -42,7 +42,7 @@ ADD_TEST(obTvImageToLabelMapWithAttributesFilter ${OBIA_TESTS1} ${INPUTDATA}/cala_labelled.tif) ADD_TEST(obTuLabelMapSourceNew ${OBIA_TESTS1} - otbLabelMapSourceNew) + otbLabelMapSourceNew) ADD_TEST(obTvLabelMapToVectorDataFilter ${OBIA_TESTS1} otbLabelMapToVectorDataFilter @@ -54,7 +54,6 @@ ADD_TEST(obTvVectorDataToLabelMapFilter ${OBIA_TESTS1} ${INPUTDATA}/vectorIOexample_gis_to_vec.shp ${TEMP}/vectordataToLabelMap.png) - ADD_TEST(obTuLabelMapToSampleListFilterNew ${OBIA_TESTS1} otbLabelMapToSampleListFilterNew) @@ -114,6 +113,14 @@ ADD_TEST(obTvNormalizeAttributesLabelMapFilter ${OBIA_TESTS1} ${INPUTDATA}/cala_labelled.tif ${TEMP}/obTvNormalizeAttributesLabelMapFilter.txt) +ADD_TEST(obTuKMeansAttributesLabelMapFilterNew ${OBIA_TESTS1} + otbKMeansAttributesLabelMapFilterNew) + +ADD_TEST(obTvKMeansAttributesLabelMapFilter ${OBIA_TESTS1} + otbKMeansAttributesLabelMapFilter + ${INPUTDATA}/calanques.tif + ${INPUTDATA}/cala_labelled.tif + ${TEMP}/obTvKMeansAttributesLabelMapFilter.txt) ADD_TEST(obTuRadiometricAttributesLabelMapFilterNew ${OBIA_TESTS1} otbRadiometricAttributesLabelMapFilterNew @@ -170,6 +177,7 @@ otbLabelObjectMapVectorizer.cxx otbLabelObjectToPolygonFunctorNew.cxx otbMinMaxAttributesLabelMapFilter.cxx otbNormalizeAttributesLabelMapFilter.cxx +otbKMeansAttributesLabelMapFilter.cxx otbRadiometricAttributesLabelMapFilterNew.cxx otbShapeAttributesLabelMapFilterNew.cxx otbStatisticsAttributesLabelMapFilterNew.cxx diff --git a/Testing/Code/OBIA/otbKMeansAttributesLabelMapFilter.cxx b/Testing/Code/OBIA/otbKMeansAttributesLabelMapFilter.cxx new file mode 100644 index 0000000000000000000000000000000000000000..b1bc8f6fcc4a21de1d102207a06ec0b8a56dd5d2 --- /dev/null +++ b/Testing/Code/OBIA/otbKMeansAttributesLabelMapFilter.cxx @@ -0,0 +1,113 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ + +#include "otbImageFileReader.h" + +#include <fstream> +#include <iostream> + +#include "otbImage.h" +#include "otbVectorImage.h" +#include "otbAttributesMapLabelObjectWithClassLabel.h" +#include "itkLabelMap.h" +#include "itkLabelImageToLabelMapFilter.h" +#include "otbImageToLabelMapWithAttributesFilter.h" +#include "otbKMeansAttributesLabelMapFilter.h" + +const unsigned int Dimension = 2; +typedef unsigned short LabelType; +typedef double PixelType; + +typedef otb::AttributesMapLabelObjectWithClassLabel<LabelType, Dimension, double, LabelType> + LabelObjectType; +typedef itk::LabelMap<LabelObjectType> LabelMapType; +typedef otb::VectorImage<PixelType, Dimension> VectorImageType; +typedef otb::Image<unsigned int,2> LabeledImageType; + +typedef otb::ImageFileReader<VectorImageType> ReaderType; +typedef otb::ImageFileReader<LabeledImageType> LabeledReaderType; +typedef itk::LabelImageToLabelMapFilter<LabeledImageType,LabelMapType> LabelMapFilterType; +typedef otb::ShapeAttributesLabelMapFilter<LabelMapType> ShapeFilterType; +typedef otb::KMeansAttributesLabelMapFilter<LabelMapType> KMeansAttributesLabelMapFilterType; + +int otbKMeansAttributesLabelMapFilterNew(int argc, char * argv[]) +{ + KMeansAttributesLabelMapFilterType::Pointer radiometricLabelMapFilter = KMeansAttributesLabelMapFilterType::New(); + return EXIT_SUCCESS; +} + +int otbKMeansAttributesLabelMapFilter(int argc, char * argv[]) +{ + const char * infname = argv[1]; + const char * lfname = argv[2]; + const char * outfname = argv[3]; + + // SmartPointer instanciation + ReaderType::Pointer reader = ReaderType::New(); + LabeledReaderType::Pointer labeledReader = LabeledReaderType::New(); + LabelMapFilterType::Pointer filter = LabelMapFilterType::New(); + ShapeFilterType::Pointer shapeFilter = ShapeFilterType::New(); + KMeansAttributesLabelMapFilterType::Pointer kmeansLabelMapFilter = KMeansAttributesLabelMapFilterType::New(); + + // Inputs + reader->SetFileName(infname); + reader->UpdateOutputInformation(); + labeledReader->SetFileName(lfname); + labeledReader->UpdateOutputInformation(); + + // Filter + filter->SetInput(labeledReader->GetOutput()); + filter->SetBackgroundValue(itk::NumericTraits<LabelType>::max()); + + shapeFilter->SetInput(filter->GetOutput()); + shapeFilter->Update(); + + // Labelize the objects + LabelMapType::Pointer labelMap = shapeFilter->GetOutput(); + + LabelMapType::LabelObjectContainerType& container = labelMap->GetLabelObjectContainer(); + LabelMapType::LabelObjectContainerType::const_iterator loIt = container.begin(); + unsigned int labelObjectID = 0; + for(loIt = container.begin(); loIt != container.end(); ++loIt ) + { + unsigned int classLabel = labelObjectID % 3; + loIt->second->SetClassLabel(classLabel); + ++labelObjectID; + } + + kmeansLabelMapFilter->SetInputLabelMap(labelMap); + std::vector<std::string> attributes = labelMap->GetLabelObject(0)->GetAvailableAttributes(); + std::vector<std::string>::const_iterator attrIt; + for (attrIt = attributes.begin(); attrIt != attributes.end(); ++attrIt) + { + kmeansLabelMapFilter->GetMeasurementFunctor().AddAttribute((*attrIt).c_str()); + } + + kmeansLabelMapFilter->SetNumberOfClasses(3); + kmeansLabelMapFilter->Compute(); + + std::ofstream outfile(outfname); + const KMeansAttributesLabelMapFilterType::CentroidsVectorType& centroids = kmeansLabelMapFilter->GetCentroids(); + for (unsigned int i = 0; i < centroids.size(); ++i) + { + outfile << "Centroid " << i << " : " << centroids[i] << std::endl; + } + outfile.close(); + + return EXIT_SUCCESS; +} diff --git a/Testing/Code/OBIA/otbOBIATests1.cxx b/Testing/Code/OBIA/otbOBIATests1.cxx index fdae20c10c43205a281fac507f849d45431d04e6..993822ec7868a47ad06b4b2f007179e8c0ccabfd 100644 --- a/Testing/Code/OBIA/otbOBIATests1.cxx +++ b/Testing/Code/OBIA/otbOBIATests1.cxx @@ -45,6 +45,8 @@ REGISTER_TEST(otbMinMaxAttributesLabelMapFilterNew); REGISTER_TEST(otbMinMaxAttributesLabelMapFilter); REGISTER_TEST(otbNormalizeAttributesLabelMapFilterNew); REGISTER_TEST(otbNormalizeAttributesLabelMapFilter); +REGISTER_TEST(otbKMeansAttributesLabelMapFilterNew); +REGISTER_TEST(otbKMeansAttributesLabelMapFilter); REGISTER_TEST(otbRadiometricAttributesLabelMapFilterNew); REGISTER_TEST(otbShapeAttributesLabelMapFilterNew); REGISTER_TEST(otbStatisticsAttributesLabelMapFilterNew);