diff --git a/Applications/Classification/otbKMeansClassification.cxx b/Applications/Classification/otbKMeansClassification.cxx index f7024472c27fd1847cb24cf3a071ebf7ab247ae2..850b84b0bbdc86ac3f4d026749179fa6d3070764 100644 --- a/Applications/Classification/otbKMeansClassification.cxx +++ b/Applications/Classification/otbKMeansClassification.cxx @@ -1,9 +1,26 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + + =========================================================================*/ + +#include "otbWrapperApplication.h" +#include "otbWrapperApplicationFactory.h" + #include "otbVectorImage.h" #include "otbImage.h" -#include "otbImageFileReader.h" -#include "otbStreamingImageFileWriter.h" -#include "otbImageFileWriter.h" -#include "otbCommandLineArgumentParser.h" #include "itkEuclideanDistance.h" #include "itkImageRegionSplitter.h" #include "otbStreamingTraits.h" @@ -13,283 +30,322 @@ #include "itkWeightedCentroidKdTreeGenerator.h" #include "itkKdTreeBasedKmeansEstimator.h" #include "itkMersenneTwisterRandomVariateGenerator.h" +#include "itkCastImageFilter.h" +#include "otbMultiToMonoChannelExtractROI.h" -int main(int argc, char * argv[]) +namespace otb +{ +namespace Wrapper { - // Parse command line parameters - typedef otb::CommandLineArgumentParser ParserType; - ParserType::Pointer parser = ParserType::New(); - - parser->SetProgramDescription("Unsupervised KMeans image classification"); - parser->AddInputImage(); - parser->AddOutputImage(); - parser->AddOption("--ValidityMask","Validity mask","-vm", 1, true); - parser->AddOption("--MaxTrainingSetSize","Size of the training set","-ts", 1, true); - parser->AddOption("--TrainingSetProbability","Probability for a sample to be selected in the training set","-tp", 1, true); - parser->AddOption("--NumberOfClasses","Number of classes","-nc", 1, true); - parser->AddOption("--InitialCentroidProbability","Probability for a pixel to be selected as an initial class centroid","-cp", 1, true); - parser->AddOption("--StreamingNumberOfLines","Number of lines for each streaming block","-sl", 1, true); - - - typedef otb::CommandLineArgumentParseResult ParserResultType; - ParserResultType::Pointer parseResult = ParserResultType::New(); - - try - { - parser->ParseCommandLine(argc, argv, parseResult); - } - catch ( itk::ExceptionObject & err ) - { - std::string descriptionException = err.GetDescription(); - if (descriptionException.find("ParseCommandLine(): Help Parser") != std::string::npos) - { - return EXIT_SUCCESS; - } - if (descriptionException.find("ParseCommandLine(): Version Parser") != std::string::npos) - { - return EXIT_SUCCESS; - } - return EXIT_FAILURE; - } - - // initiating random number generation - itk::Statistics::MersenneTwisterRandomVariateGenerator::Pointer randomGen - = itk::Statistics::MersenneTwisterRandomVariateGenerator::New(); - +typedef otb::Image<FloatVectorImageType::InternalPixelType, 2> ImageReaderType; - std::string infname = parseResult->GetInputImage(); - std::string maskfname = parseResult->GetParameterString("--ValidityMask", 0); - std::string outfname = parseResult->GetOutputImage(); - const unsigned int nbsamples = parseResult->GetParameterUInt("--MaxTrainingSetSize"); - const double trainingProb =parseResult->GetParameterDouble("--TrainingSetProbability"); - const double initProb = parseResult->GetParameterDouble("--InitialCentroidProbability"); - const unsigned int nbLinesForStreaming = parseResult->GetParameterUInt("--StreamingNumberOfLines"); - const unsigned int nb_classes = parseResult->GetParameterUInt("--NumberOfClasses"); +typedef UInt8ImageType LabeledImageType; +typedef ImageReaderType::PixelType PixelType; - typedef unsigned short PixelType; - typedef unsigned short LabeledPixelType; +typedef itk::FixedArray<PixelType, 108> SampleType; +typedef itk::Statistics::ListSample<SampleType> ListSampleType; +typedef itk::Statistics::WeightedCentroidKdTreeGenerator<ListSampleType> TreeGeneratorType; +typedef TreeGeneratorType::KdTreeType TreeType; +typedef itk::Statistics::KdTreeBasedKmeansEstimator<TreeType> EstimatorType; +typedef itk::CastImageFilter<FloatImageListType, FloatImageType> CastMaskFilterType; +typedef otb::MultiToMonoChannelExtractROI<FloatVectorImageType::InternalPixelType,LabeledImageType::InternalPixelType > ExtractorType; - typedef otb::VectorImage<PixelType, 2> ImageType; - typedef otb::Image<LabeledPixelType, 2> LabeledImageType; - typedef otb::ImageFileReader<ImageType> ImageReaderType; - typedef otb::ImageFileReader<LabeledImageType> LabeledImageReaderType; - typedef otb::StreamingImageFileWriter<LabeledImageType> WriterType; +typedef otb::StreamingTraits<FloatVectorImageType> StreamingTraitsType; +typedef itk::ImageRegionSplitter<2> SplitterType; +typedef ImageReaderType::RegionType RegionType; - typedef itk::FixedArray<PixelType, 108> SampleType; - typedef itk::Statistics::ListSample<SampleType> ListSampleType; - typedef itk::Statistics::WeightedCentroidKdTreeGenerator<ListSampleType> TreeGeneratorType; - typedef TreeGeneratorType::KdTreeType TreeType; - typedef itk::Statistics::KdTreeBasedKmeansEstimator<TreeType> EstimatorType; +typedef itk::ImageRegionConstIterator<FloatVectorImageType> IteratorType; +typedef itk::ImageRegionConstIterator<LabeledImageType> LabeledIteratorType; - typedef otb::StreamingTraits<ImageType> StreamingTraitsType; - typedef itk::ImageRegionSplitter<2> SplitterType; - typedef ImageType::RegionType RegionType; +typedef otb::KMeansImageClassificationFilter<FloatVectorImageType, LabeledImageType, 108> ClassificationFilterType; - typedef itk::ImageRegionConstIterator<ImageType> IteratorType; - typedef itk::ImageRegionConstIterator<LabeledImageType> LabeledIteratorType; +class KMeansClassification: public Application +{ +public: + /** Standard class typedefs. */ + typedef KMeansClassification Self; + typedef Application Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; - typedef otb::KMeansImageClassificationFilter<ImageType, LabeledImageType, 108> ClassificationFilterType; + /** Standard macro */ + itkNewMacro(Self); + itkTypeMacro(KMeansClassification, otb::Application); - ImageReaderType::Pointer reader = ImageReaderType::New(); - LabeledImageReaderType::Pointer maskReader = LabeledImageReaderType::New(); +private: + KMeansClassification() + { + SetName("KMeansClassification"); + SetDescription("Unsupervised KMeans image classification."); + } - reader->SetFileName(infname); - maskReader->SetFileName(maskfname); + virtual ~KMeansClassification() + { + } - /*******************************************/ - /* Sampling data */ - /*******************************************/ - std::cout<<std::endl; - std::cout<<"-- SAMPLING DATA --"<<std::endl; - std::cout<<std::endl; + void DoCreateParameters() + { - // Update input images information - reader->GenerateOutputInformation(); - maskReader->GenerateOutputInformation(); + AddParameter(ParameterType_InputImage, "in", "Input Image"); + AddParameter(ParameterType_OutputImage, "out", "Output Image"); + AddParameter(ParameterType_InputImage, "vm", "Validity Mask"); + AddParameter(ParameterType_Int, "ts", "Size of the training set"); + SetParameterInt("ts", 100); + AddParameter(ParameterType_Float, "tp", "Probability for a sample to be selected in the training set"); + SetParameterFloat("tp", 0.5); + AddParameter(ParameterType_Int, "nc", "Number of classes"); + SetParameterInt("nc", 3); + AddParameter(ParameterType_Float, "cp", "Probability for a pixel to be selected as an initial class centroid"); + SetParameterFloat("cp", 0.8); + AddParameter(ParameterType_Int, "sl", "Number of lines for each streaming block"); + SetParameterInt("sl", 1000); + } - if (reader->GetOutput()->GetLargestPossibleRegion() - != maskReader->GetOutput()->GetLargestPossibleRegion() - ) + void DoUpdateParameters() { - std::cerr<<"Mask image and input image have different sizes."<<std::endl; - return EXIT_FAILURE; + // Nothing to do here : all parameters are independent } - RegionType largestRegion = reader->GetOutput()->GetLargestPossibleRegion(); - - // Setting up local streaming capabilities - SplitterType::Pointer splitter = SplitterType::New(); - unsigned int numberOfStreamDivisions = StreamingTraitsType::CalculateNumberOfStreamDivisions(reader->GetOutput(), - largestRegion, - splitter, - otb::SET_BUFFER_NUMBER_OF_LINES, - 0, 0, nbLinesForStreaming); - - std::cout<<"The images will be streamed into "<<numberOfStreamDivisions<<" parts."<<std::endl; - - // Training sample lists - ListSampleType::Pointer sampleList = ListSampleType::New(); - EstimatorType::ParametersType initialMeans(108*nb_classes); - initialMeans.Fill(0); - unsigned int init_means_index = 0; - - // Sample dimension and max dimension - unsigned int maxDimension = SampleType::Dimension; - unsigned int sampleSize = std::min(reader->GetOutput()->GetNumberOfComponentsPerPixel(), maxDimension); - unsigned int totalSamples = 0; - std::cout<<"Sample max possible dimension: "<<maxDimension<<std::endl; - std::cout<<"The following sample size will be used: "<<sampleSize<<std::endl; - std::cout<<std::endl; - // local streaming variables - unsigned int piece = 0; - RegionType streamingRegion; - - while ((totalSamples<nbsamples)&&(init_means_index<108*nb_classes)) + void DoExecute() { - double random = randomGen->GetVariateWithClosedRange(); - piece = static_cast<unsigned int>(random * numberOfStreamDivisions); + GetLogger()->Debug("Entering DoExecute\n"); + + // initiating random number generation + itk::Statistics::MersenneTwisterRandomVariateGenerator::Pointer + randomGen = itk::Statistics::MersenneTwisterRandomVariateGenerator::New(); + m_InImage = GetParameterImage("in"); + m_Extractor = ExtractorType::New(); + m_Extractor->SetInput(GetParameterImage("vm")); + m_Extractor->SetChannel(1); + m_Extractor->UpdateOutputInformation(); + LabeledImageType::Pointer maskImage = m_Extractor->GetOutput(); + + std::ostringstream message(""); + + const unsigned int nbsamples = GetParameterInt("ts"); + const double trainingProb = GetParameterFloat("tp"); + const double initProb = GetParameterFloat("cp"); + const unsigned int nb_classes = GetParameterInt("nc"); + const unsigned int nbLinesForStreaming = GetParameterInt("sl"); + + /*******************************************/ + /* Sampling data */ + /*******************************************/ + GetLogger()->Info("-- SAMPLING DATA --"); + + // Update input images information + m_InImage->UpdateOutputInformation(); + maskImage->UpdateOutputInformation(); + + if (m_InImage->GetLargestPossibleRegion() != maskImage->GetLargestPossibleRegion()) + { + GetLogger()->Error("Mask image and input image have different sizes."); + } - streamingRegion = splitter->GetSplit(piece, numberOfStreamDivisions, largestRegion); + RegionType largestRegion = m_InImage->GetLargestPossibleRegion(); + + // Setting up local streaming capabilities + SplitterType::Pointer splitter = SplitterType::New(); + unsigned int + numberOfStreamDivisions = StreamingTraitsType::CalculateNumberOfStreamDivisions( + m_InImage, + largestRegion, + splitter, + otb::SET_BUFFER_NUMBER_OF_LINES, + 0, 0, nbLinesForStreaming); + + message.clear(); + message << "The images will be streamed into " << numberOfStreamDivisions << " parts."; + GetLogger()->Info(message.str()); + + // Training sample lists + ListSampleType::Pointer sampleList = ListSampleType::New(); + EstimatorType::ParametersType initialMeans(108 * nb_classes); + initialMeans.Fill(0); + unsigned int init_means_index = 0; + + // Sample dimension and max dimension + unsigned int maxDimension = SampleType::Dimension; + unsigned int sampleSize = std::min(m_InImage->GetNumberOfComponentsPerPixel(), maxDimension); + unsigned int totalSamples = 0; + + message.clear(); + message << "Sample max possible dimension: " << maxDimension << std::endl; + GetLogger()->Info(message.str()); + message.clear(); + message << "The following sample size will be used: " << sampleSize << std::endl; + GetLogger()->Info(message.str()); + // local streaming variables + unsigned int piece = 0; + RegionType streamingRegion; + + while ((totalSamples < nbsamples) && (init_means_index < 108 * nb_classes)) + { + double random = randomGen->GetVariateWithClosedRange(); + piece = static_cast<unsigned int> (random * numberOfStreamDivisions); - std::cout<<"Processing region: "<<streamingRegion<<std::endl; + streamingRegion = splitter->GetSplit(piece, numberOfStreamDivisions, largestRegion); - reader->GetOutput()->SetRequestedRegion(streamingRegion); - reader->GetOutput()->PropagateRequestedRegion(); - reader->GetOutput()->UpdateOutputData(); + message.clear(); + message << "Processing region: " << streamingRegion << std::endl; + GetLogger()->Info(message.str()); - maskReader->GetOutput()->SetRequestedRegion(streamingRegion); - maskReader->GetOutput()->PropagateRequestedRegion(); - maskReader->GetOutput()->UpdateOutputData(); + m_InImage->SetRequestedRegion(streamingRegion); + m_InImage->PropagateRequestedRegion(); + m_InImage->UpdateOutputData(); - IteratorType it(reader->GetOutput(), streamingRegion); - LabeledIteratorType maskIt(maskReader->GetOutput(), streamingRegion); + maskImage->SetRequestedRegion(streamingRegion); + maskImage->PropagateRequestedRegion(); + maskImage->UpdateOutputData(); - it.GoToBegin(); - maskIt.GoToBegin(); + IteratorType it(m_InImage, streamingRegion); + LabeledIteratorType m_MaskIt(maskImage, streamingRegion); - unsigned int localNbSamples=0; + it.GoToBegin(); + m_MaskIt.GoToBegin(); - // Loop on the image - while (!it.IsAtEnd()&&!maskIt.IsAtEnd()&&(totalSamples<nbsamples)&&(init_means_index<108*nb_classes)) - { - // If the current pixel is labeled - if (maskIt.Get()>0) - { - if ((rand()<trainingProb*RAND_MAX)) - { - SampleType newSample; + unsigned int localNbSamples = 0; - // build the sample - newSample.Fill(0); - for (unsigned int i = 0; i<sampleSize; ++i) + // Loop on the image + while (!it.IsAtEnd() && !m_MaskIt.IsAtEnd() && (totalSamples < nbsamples) + && (init_means_index < (108 * nb_classes))) + { + // If the current pixel is labeled + if (m_MaskIt.Get() > 0) { - newSample[i]=it.Get()[i]; + if ((rand() < trainingProb * RAND_MAX)) + { + SampleType newSample; + + // build the sample + newSample.Fill(0); + for (unsigned int i = 0; i < sampleSize; ++i) + { + newSample[i] = it.Get()[i]; + } + // Update the the sample lists + sampleList->PushBack(newSample); + ++totalSamples; + ++localNbSamples; + } + else + if ((init_means_index < 108 * nb_classes) && (rand() < initProb * RAND_MAX)) + { + for (unsigned int i = 0; i < sampleSize; ++i) + { + initialMeans[init_means_index + i] = it.Get()[i]; + } + init_means_index += 108; + } } - // Update the the sample lists - sampleList->PushBack(newSample); - ++totalSamples; - ++localNbSamples; + ++it; + ++m_MaskIt; } - else if ((init_means_index<108*nb_classes)&&(rand()<initProb*RAND_MAX)) + + message.clear(); + message << localNbSamples << " samples added to the training set." << std::endl; + GetLogger()->Info(message.str()); + + } + + message.clear(); + message << "The final training set contains " << totalSamples << " samples." << std::endl; + GetLogger()->Info(message.str()); + + message.clear(); + message << "Data sampling completed." << std::endl; + GetLogger()->Info(message.str()); + + /*******************************************/ + /* Learning */ + /*******************************************/ + message.clear(); + message << "-- LEARNING --" << std::endl; + message << "Initial centroids are: " << std::endl; + GetLogger()->Info(message.str()); + message.clear(); + for (unsigned int i = 0; i < nb_classes; ++i) + { + message << "Class " << i << ": "; + for (unsigned int j = 0; j < sampleSize; ++j) { - for (unsigned int i = 0; i<sampleSize; ++i) - { - initialMeans[init_means_index+i]=it.Get()[i]; - } - init_means_index += 108; + message << initialMeans[i * 108 + j] << "\t"; } + message << std::endl; } - ++it; - ++maskIt; - } - std::cout<<localNbSamples<<" samples added to the training set."<<std::endl; - std::cout<<std::endl; - } + message << std::endl; - std::cout<<"The final training set contains "<<totalSamples<<" samples."<<std::endl; + message.clear(); + message << "Starting optimization." << std::endl; + message << std::endl; + GetLogger()->Info(message.str()); + EstimatorType::Pointer estimator = EstimatorType::New(); - std::cout<<std::endl; - std::cout<<"Data sampling completed."<<std::endl; - std::cout<<std::endl; + TreeGeneratorType::Pointer treeGenerator = TreeGeneratorType::New(); + treeGenerator->SetSample(sampleList); + treeGenerator->SetBucketSize(100); + treeGenerator->Update(); - /*******************************************/ - /* Learning */ - /*******************************************/ + estimator->SetParameters(initialMeans); + estimator->SetKdTree(treeGenerator->GetOutput()); + estimator->SetMaximumIteration(100000000); + estimator->SetCentroidPositionChangesThreshold(0.001); + estimator->StartOptimization(); - std::cout<<"-- LEARNING --"<<std::endl; - std::cout<<std::endl; + EstimatorType::ParametersType estimatedMeans = estimator->GetParameters(); + message.clear(); + message << "Optimization completed." << std::endl; + message << std::endl; + message << "Estimated centroids are: " << std::endl; - std::cout<<"Initial centroids are: "<<std::endl; + for (unsigned int i = 0; i < nb_classes; ++i) + { + message << "Class " << i << ": "; + for (unsigned int j = 0; j < sampleSize; ++j) + { + message << estimatedMeans[i * 108 + j] << "\t"; + } + message << std::endl; + } - for (unsigned int i=0; i<nb_classes; ++i) - { - std::cout<<"Class "<<i<<": "; - for (unsigned int j = 0; j<sampleSize; ++j) - { - std::cout<<initialMeans[i*108+j]<<"\t"; - } - std::cout<<std::endl; - } - std::cout<<std::endl; + message << std::endl; + message << "Learning completed." << std::endl; + message << std::endl; + GetLogger()->Info(message.str()); - std::cout<<"Starting optimization."<<std::endl; - std::cout<<std::endl; - EstimatorType::Pointer estimator = EstimatorType::New(); + /*******************************************/ + /* Classification */ + /*******************************************/ + message.clear(); + message << "-- CLASSIFICATION --" << std::endl; + message << std::endl; + GetLogger()->Info(message.str()); - TreeGeneratorType::Pointer treeGenerator = TreeGeneratorType::New(); - treeGenerator->SetSample(sampleList); - treeGenerator->SetBucketSize(100); - treeGenerator->Update(); + m_Classifier = ClassificationFilterType::New(); - estimator->SetParameters(initialMeans); - estimator->SetKdTree(treeGenerator->GetOutput()); - estimator->SetMaximumIteration(100000000); - estimator->SetCentroidPositionChangesThreshold(0.001); - estimator->StartOptimization(); + m_Classifier->SetInput(m_InImage); + m_Classifier->SetInputMask(maskImage); - EstimatorType::ParametersType estimatedMeans = estimator->GetParameters(); + m_Classifier->SetCentroids(estimator->GetParameters()); - std::cout<<"Optimization completed."<<std::endl; - std::cout<<std::endl; - std::cout<<"Estimated centroids are: "<<std::endl; + SetParameterOutputImage<LabeledImageType> ("out", m_Classifier->GetOutput()); - for (unsigned int i=0; i<nb_classes; ++i) - { - std::cout<<"Class "<<i<<": "; - for (unsigned int j = 0; j<sampleSize; ++j) - { - std::cout<<estimatedMeans[i*108+j]<<"\t"; - } - std::cout<<std::endl; } - std::cout<<std::endl; - std::cout<<"Learning completed."<<std::endl; - std::cout<<std::endl; + ExtractorType::Pointer m_Extractor; + ClassificationFilterType::Pointer m_Classifier; + FloatVectorImageType::Pointer m_InImage; + +}; - /*******************************************/ - /* Classification */ - /*******************************************/ +} +} - std::cout<<"-- CLASSIFICATION --"<<std::endl; - std::cout<<std::endl; - ClassificationFilterType::Pointer classifier = ClassificationFilterType::New(); - classifier->SetInput(reader->GetOutput()); - classifier->SetInputMask(maskReader->GetOutput()); - classifier->SetCentroids(estimator->GetParameters()); +OTB_APPLICATION_EXPORT(otb::Wrapper::KMeansClassification) - WriterType::Pointer writer = WriterType::New(); - writer->SetFileName(outfname); - writer->SetInput(classifier->GetOutput()); - writer->SetNumberOfDivisionsStrippedStreaming(numberOfStreamDivisions); - writer->Update(); - std::cout<<"Classification completed."<<std::endl; - std::cout<<std::endl; - return EXIT_SUCCESS; -}