Skip to content
Snippets Groups Projects
Commit 7304f4b5 authored by Jonathan Guinet's avatar Jonathan Guinet
Browse files

ENH: port to new framework

parent cc10e722
No related branches found
No related tags found
No related merge requests found
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"
#include "otbVectorImage.h"
#include "otbImage.h"
#include "otbImageFileReader.h"
#include "otbStreamingImageFileWriter.h"
#include "otbImageFileWriter.h"
#include "otbCommandLineArgumentParser.h"
#include "itkEuclideanDistance.h"
#include "itkImageRegionSplitter.h"
#include "otbStreamingTraits.h"
......@@ -13,283 +30,322 @@
#include "itkWeightedCentroidKdTreeGenerator.h"
#include "itkKdTreeBasedKmeansEstimator.h"
#include "itkMersenneTwisterRandomVariateGenerator.h"
#include "itkCastImageFilter.h"
#include "otbMultiToMonoChannelExtractROI.h"
int main(int argc, char * argv[])
namespace otb
{
namespace Wrapper
{
// Parse command line parameters
typedef otb::CommandLineArgumentParser ParserType;
ParserType::Pointer parser = ParserType::New();
parser->SetProgramDescription("Unsupervised KMeans image classification");
parser->AddInputImage();
parser->AddOutputImage();
parser->AddOption("--ValidityMask","Validity mask","-vm", 1, true);
parser->AddOption("--MaxTrainingSetSize","Size of the training set","-ts", 1, true);
parser->AddOption("--TrainingSetProbability","Probability for a sample to be selected in the training set","-tp", 1, true);
parser->AddOption("--NumberOfClasses","Number of classes","-nc", 1, true);
parser->AddOption("--InitialCentroidProbability","Probability for a pixel to be selected as an initial class centroid","-cp", 1, true);
parser->AddOption("--StreamingNumberOfLines","Number of lines for each streaming block","-sl", 1, true);
typedef otb::CommandLineArgumentParseResult ParserResultType;
ParserResultType::Pointer parseResult = ParserResultType::New();
try
{
parser->ParseCommandLine(argc, argv, parseResult);
}
catch ( itk::ExceptionObject & err )
{
std::string descriptionException = err.GetDescription();
if (descriptionException.find("ParseCommandLine(): Help Parser") != std::string::npos)
{
return EXIT_SUCCESS;
}
if (descriptionException.find("ParseCommandLine(): Version Parser") != std::string::npos)
{
return EXIT_SUCCESS;
}
return EXIT_FAILURE;
}
// initiating random number generation
itk::Statistics::MersenneTwisterRandomVariateGenerator::Pointer randomGen
= itk::Statistics::MersenneTwisterRandomVariateGenerator::New();
typedef otb::Image<FloatVectorImageType::InternalPixelType, 2> ImageReaderType;
std::string infname = parseResult->GetInputImage();
std::string maskfname = parseResult->GetParameterString("--ValidityMask", 0);
std::string outfname = parseResult->GetOutputImage();
const unsigned int nbsamples = parseResult->GetParameterUInt("--MaxTrainingSetSize");
const double trainingProb =parseResult->GetParameterDouble("--TrainingSetProbability");
const double initProb = parseResult->GetParameterDouble("--InitialCentroidProbability");
const unsigned int nbLinesForStreaming = parseResult->GetParameterUInt("--StreamingNumberOfLines");
const unsigned int nb_classes = parseResult->GetParameterUInt("--NumberOfClasses");
typedef UInt8ImageType LabeledImageType;
typedef ImageReaderType::PixelType PixelType;
typedef unsigned short PixelType;
typedef unsigned short LabeledPixelType;
typedef itk::FixedArray<PixelType, 108> SampleType;
typedef itk::Statistics::ListSample<SampleType> ListSampleType;
typedef itk::Statistics::WeightedCentroidKdTreeGenerator<ListSampleType> TreeGeneratorType;
typedef TreeGeneratorType::KdTreeType TreeType;
typedef itk::Statistics::KdTreeBasedKmeansEstimator<TreeType> EstimatorType;
typedef itk::CastImageFilter<FloatImageListType, FloatImageType> CastMaskFilterType;
typedef otb::MultiToMonoChannelExtractROI<FloatVectorImageType::InternalPixelType,LabeledImageType::InternalPixelType > ExtractorType;
typedef otb::VectorImage<PixelType, 2> ImageType;
typedef otb::Image<LabeledPixelType, 2> LabeledImageType;
typedef otb::ImageFileReader<ImageType> ImageReaderType;
typedef otb::ImageFileReader<LabeledImageType> LabeledImageReaderType;
typedef otb::StreamingImageFileWriter<LabeledImageType> WriterType;
typedef otb::StreamingTraits<FloatVectorImageType> StreamingTraitsType;
typedef itk::ImageRegionSplitter<2> SplitterType;
typedef ImageReaderType::RegionType RegionType;
typedef itk::FixedArray<PixelType, 108> SampleType;
typedef itk::Statistics::ListSample<SampleType> ListSampleType;
typedef itk::Statistics::WeightedCentroidKdTreeGenerator<ListSampleType> TreeGeneratorType;
typedef TreeGeneratorType::KdTreeType TreeType;
typedef itk::Statistics::KdTreeBasedKmeansEstimator<TreeType> EstimatorType;
typedef itk::ImageRegionConstIterator<FloatVectorImageType> IteratorType;
typedef itk::ImageRegionConstIterator<LabeledImageType> LabeledIteratorType;
typedef otb::StreamingTraits<ImageType> StreamingTraitsType;
typedef itk::ImageRegionSplitter<2> SplitterType;
typedef ImageType::RegionType RegionType;
typedef otb::KMeansImageClassificationFilter<FloatVectorImageType, LabeledImageType, 108> ClassificationFilterType;
typedef itk::ImageRegionConstIterator<ImageType> IteratorType;
typedef itk::ImageRegionConstIterator<LabeledImageType> LabeledIteratorType;
class KMeansClassification: public Application
{
public:
/** Standard class typedefs. */
typedef KMeansClassification Self;
typedef Application Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
typedef otb::KMeansImageClassificationFilter<ImageType, LabeledImageType, 108> ClassificationFilterType;
/** Standard macro */
itkNewMacro(Self);
itkTypeMacro(KMeansClassification, otb::Application);
ImageReaderType::Pointer reader = ImageReaderType::New();
LabeledImageReaderType::Pointer maskReader = LabeledImageReaderType::New();
private:
KMeansClassification()
{
SetName("KMeansClassification");
SetDescription("Unsupervised KMeans image classification.");
}
reader->SetFileName(infname);
maskReader->SetFileName(maskfname);
virtual ~KMeansClassification()
{
}
/*******************************************/
/* Sampling data */
/*******************************************/
std::cout<<std::endl;
std::cout<<"-- SAMPLING DATA --"<<std::endl;
std::cout<<std::endl;
void DoCreateParameters()
{
// Update input images information
reader->GenerateOutputInformation();
maskReader->GenerateOutputInformation();
AddParameter(ParameterType_InputImage, "in", "Input Image");
AddParameter(ParameterType_OutputImage, "out", "Output Image");
AddParameter(ParameterType_InputImage, "vm", "Validity Mask");
AddParameter(ParameterType_Int, "ts", "Size of the training set");
SetParameterInt("ts", 100);
AddParameter(ParameterType_Float, "tp", "Probability for a sample to be selected in the training set");
SetParameterFloat("tp", 0.5);
AddParameter(ParameterType_Int, "nc", "Number of classes");
SetParameterInt("nc", 3);
AddParameter(ParameterType_Float, "cp", "Probability for a pixel to be selected as an initial class centroid");
SetParameterFloat("cp", 0.8);
AddParameter(ParameterType_Int, "sl", "Number of lines for each streaming block");
SetParameterInt("sl", 1000);
}
if (reader->GetOutput()->GetLargestPossibleRegion()
!= maskReader->GetOutput()->GetLargestPossibleRegion()
)
void DoUpdateParameters()
{
std::cerr<<"Mask image and input image have different sizes."<<std::endl;
return EXIT_FAILURE;
// Nothing to do here : all parameters are independent
}
RegionType largestRegion = reader->GetOutput()->GetLargestPossibleRegion();
// Setting up local streaming capabilities
SplitterType::Pointer splitter = SplitterType::New();
unsigned int numberOfStreamDivisions = StreamingTraitsType::CalculateNumberOfStreamDivisions(reader->GetOutput(),
largestRegion,
splitter,
otb::SET_BUFFER_NUMBER_OF_LINES,
0, 0, nbLinesForStreaming);
std::cout<<"The images will be streamed into "<<numberOfStreamDivisions<<" parts."<<std::endl;
// Training sample lists
ListSampleType::Pointer sampleList = ListSampleType::New();
EstimatorType::ParametersType initialMeans(108*nb_classes);
initialMeans.Fill(0);
unsigned int init_means_index = 0;
// Sample dimension and max dimension
unsigned int maxDimension = SampleType::Dimension;
unsigned int sampleSize = std::min(reader->GetOutput()->GetNumberOfComponentsPerPixel(), maxDimension);
unsigned int totalSamples = 0;
std::cout<<"Sample max possible dimension: "<<maxDimension<<std::endl;
std::cout<<"The following sample size will be used: "<<sampleSize<<std::endl;
std::cout<<std::endl;
// local streaming variables
unsigned int piece = 0;
RegionType streamingRegion;
while ((totalSamples<nbsamples)&&(init_means_index<108*nb_classes))
void DoExecute()
{
double random = randomGen->GetVariateWithClosedRange();
piece = static_cast<unsigned int>(random * numberOfStreamDivisions);
GetLogger()->Debug("Entering DoExecute\n");
// initiating random number generation
itk::Statistics::MersenneTwisterRandomVariateGenerator::Pointer
randomGen = itk::Statistics::MersenneTwisterRandomVariateGenerator::New();
m_InImage = GetParameterImage("in");
m_Extractor = ExtractorType::New();
m_Extractor->SetInput(GetParameterImage("vm"));
m_Extractor->SetChannel(1);
m_Extractor->UpdateOutputInformation();
LabeledImageType::Pointer maskImage = m_Extractor->GetOutput();
std::ostringstream message("");
const unsigned int nbsamples = GetParameterInt("ts");
const double trainingProb = GetParameterFloat("tp");
const double initProb = GetParameterFloat("cp");
const unsigned int nb_classes = GetParameterInt("nc");
const unsigned int nbLinesForStreaming = GetParameterInt("sl");
/*******************************************/
/* Sampling data */
/*******************************************/
GetLogger()->Info("-- SAMPLING DATA --");
// Update input images information
m_InImage->UpdateOutputInformation();
maskImage->UpdateOutputInformation();
if (m_InImage->GetLargestPossibleRegion() != maskImage->GetLargestPossibleRegion())
{
GetLogger()->Error("Mask image and input image have different sizes.");
}
streamingRegion = splitter->GetSplit(piece, numberOfStreamDivisions, largestRegion);
RegionType largestRegion = m_InImage->GetLargestPossibleRegion();
// Setting up local streaming capabilities
SplitterType::Pointer splitter = SplitterType::New();
unsigned int
numberOfStreamDivisions = StreamingTraitsType::CalculateNumberOfStreamDivisions(
m_InImage,
largestRegion,
splitter,
otb::SET_BUFFER_NUMBER_OF_LINES,
0, 0, nbLinesForStreaming);
message.clear();
message << "The images will be streamed into " << numberOfStreamDivisions << " parts.";
GetLogger()->Info(message.str());
// Training sample lists
ListSampleType::Pointer sampleList = ListSampleType::New();
EstimatorType::ParametersType initialMeans(108 * nb_classes);
initialMeans.Fill(0);
unsigned int init_means_index = 0;
// Sample dimension and max dimension
unsigned int maxDimension = SampleType::Dimension;
unsigned int sampleSize = std::min(m_InImage->GetNumberOfComponentsPerPixel(), maxDimension);
unsigned int totalSamples = 0;
message.clear();
message << "Sample max possible dimension: " << maxDimension << std::endl;
GetLogger()->Info(message.str());
message.clear();
message << "The following sample size will be used: " << sampleSize << std::endl;
GetLogger()->Info(message.str());
// local streaming variables
unsigned int piece = 0;
RegionType streamingRegion;
while ((totalSamples < nbsamples) && (init_means_index < 108 * nb_classes))
{
double random = randomGen->GetVariateWithClosedRange();
piece = static_cast<unsigned int> (random * numberOfStreamDivisions);
std::cout<<"Processing region: "<<streamingRegion<<std::endl;
streamingRegion = splitter->GetSplit(piece, numberOfStreamDivisions, largestRegion);
reader->GetOutput()->SetRequestedRegion(streamingRegion);
reader->GetOutput()->PropagateRequestedRegion();
reader->GetOutput()->UpdateOutputData();
message.clear();
message << "Processing region: " << streamingRegion << std::endl;
GetLogger()->Info(message.str());
maskReader->GetOutput()->SetRequestedRegion(streamingRegion);
maskReader->GetOutput()->PropagateRequestedRegion();
maskReader->GetOutput()->UpdateOutputData();
m_InImage->SetRequestedRegion(streamingRegion);
m_InImage->PropagateRequestedRegion();
m_InImage->UpdateOutputData();
IteratorType it(reader->GetOutput(), streamingRegion);
LabeledIteratorType maskIt(maskReader->GetOutput(), streamingRegion);
maskImage->SetRequestedRegion(streamingRegion);
maskImage->PropagateRequestedRegion();
maskImage->UpdateOutputData();
it.GoToBegin();
maskIt.GoToBegin();
IteratorType it(m_InImage, streamingRegion);
LabeledIteratorType m_MaskIt(maskImage, streamingRegion);
unsigned int localNbSamples=0;
it.GoToBegin();
m_MaskIt.GoToBegin();
// Loop on the image
while (!it.IsAtEnd()&&!maskIt.IsAtEnd()&&(totalSamples<nbsamples)&&(init_means_index<108*nb_classes))
{
// If the current pixel is labeled
if (maskIt.Get()>0)
{
if ((rand()<trainingProb*RAND_MAX))
{
SampleType newSample;
unsigned int localNbSamples = 0;
// build the sample
newSample.Fill(0);
for (unsigned int i = 0; i<sampleSize; ++i)
// Loop on the image
while (!it.IsAtEnd() && !m_MaskIt.IsAtEnd() && (totalSamples < nbsamples)
&& (init_means_index < (108 * nb_classes)))
{
// If the current pixel is labeled
if (m_MaskIt.Get() > 0)
{
newSample[i]=it.Get()[i];
if ((rand() < trainingProb * RAND_MAX))
{
SampleType newSample;
// build the sample
newSample.Fill(0);
for (unsigned int i = 0; i < sampleSize; ++i)
{
newSample[i] = it.Get()[i];
}
// Update the the sample lists
sampleList->PushBack(newSample);
++totalSamples;
++localNbSamples;
}
else
if ((init_means_index < 108 * nb_classes) && (rand() < initProb * RAND_MAX))
{
for (unsigned int i = 0; i < sampleSize; ++i)
{
initialMeans[init_means_index + i] = it.Get()[i];
}
init_means_index += 108;
}
}
// Update the the sample lists
sampleList->PushBack(newSample);
++totalSamples;
++localNbSamples;
++it;
++m_MaskIt;
}
else if ((init_means_index<108*nb_classes)&&(rand()<initProb*RAND_MAX))
message.clear();
message << localNbSamples << " samples added to the training set." << std::endl;
GetLogger()->Info(message.str());
}
message.clear();
message << "The final training set contains " << totalSamples << " samples." << std::endl;
GetLogger()->Info(message.str());
message.clear();
message << "Data sampling completed." << std::endl;
GetLogger()->Info(message.str());
/*******************************************/
/* Learning */
/*******************************************/
message.clear();
message << "-- LEARNING --" << std::endl;
message << "Initial centroids are: " << std::endl;
GetLogger()->Info(message.str());
message.clear();
for (unsigned int i = 0; i < nb_classes; ++i)
{
message << "Class " << i << ": ";
for (unsigned int j = 0; j < sampleSize; ++j)
{
for (unsigned int i = 0; i<sampleSize; ++i)
{
initialMeans[init_means_index+i]=it.Get()[i];
}
init_means_index += 108;
message << initialMeans[i * 108 + j] << "\t";
}
message << std::endl;
}
++it;
++maskIt;
}
std::cout<<localNbSamples<<" samples added to the training set."<<std::endl;
std::cout<<std::endl;
}
message << std::endl;
std::cout<<"The final training set contains "<<totalSamples<<" samples."<<std::endl;
message.clear();
message << "Starting optimization." << std::endl;
message << std::endl;
GetLogger()->Info(message.str());
EstimatorType::Pointer estimator = EstimatorType::New();
std::cout<<std::endl;
std::cout<<"Data sampling completed."<<std::endl;
std::cout<<std::endl;
TreeGeneratorType::Pointer treeGenerator = TreeGeneratorType::New();
treeGenerator->SetSample(sampleList);
treeGenerator->SetBucketSize(100);
treeGenerator->Update();
/*******************************************/
/* Learning */
/*******************************************/
estimator->SetParameters(initialMeans);
estimator->SetKdTree(treeGenerator->GetOutput());
estimator->SetMaximumIteration(100000000);
estimator->SetCentroidPositionChangesThreshold(0.001);
estimator->StartOptimization();
std::cout<<"-- LEARNING --"<<std::endl;
std::cout<<std::endl;
EstimatorType::ParametersType estimatedMeans = estimator->GetParameters();
message.clear();
message << "Optimization completed." << std::endl;
message << std::endl;
message << "Estimated centroids are: " << std::endl;
std::cout<<"Initial centroids are: "<<std::endl;
for (unsigned int i = 0; i < nb_classes; ++i)
{
message << "Class " << i << ": ";
for (unsigned int j = 0; j < sampleSize; ++j)
{
message << estimatedMeans[i * 108 + j] << "\t";
}
message << std::endl;
}
for (unsigned int i=0; i<nb_classes; ++i)
{
std::cout<<"Class "<<i<<": ";
for (unsigned int j = 0; j<sampleSize; ++j)
{
std::cout<<initialMeans[i*108+j]<<"\t";
}
std::cout<<std::endl;
}
std::cout<<std::endl;
message << std::endl;
message << "Learning completed." << std::endl;
message << std::endl;
GetLogger()->Info(message.str());
std::cout<<"Starting optimization."<<std::endl;
std::cout<<std::endl;
EstimatorType::Pointer estimator = EstimatorType::New();
/*******************************************/
/* Classification */
/*******************************************/
message.clear();
message << "-- CLASSIFICATION --" << std::endl;
message << std::endl;
GetLogger()->Info(message.str());
TreeGeneratorType::Pointer treeGenerator = TreeGeneratorType::New();
treeGenerator->SetSample(sampleList);
treeGenerator->SetBucketSize(100);
treeGenerator->Update();
m_Classifier = ClassificationFilterType::New();
estimator->SetParameters(initialMeans);
estimator->SetKdTree(treeGenerator->GetOutput());
estimator->SetMaximumIteration(100000000);
estimator->SetCentroidPositionChangesThreshold(0.001);
estimator->StartOptimization();
m_Classifier->SetInput(m_InImage);
m_Classifier->SetInputMask(maskImage);
EstimatorType::ParametersType estimatedMeans = estimator->GetParameters();
m_Classifier->SetCentroids(estimator->GetParameters());
std::cout<<"Optimization completed."<<std::endl;
std::cout<<std::endl;
std::cout<<"Estimated centroids are: "<<std::endl;
SetParameterOutputImage<LabeledImageType> ("out", m_Classifier->GetOutput());
for (unsigned int i=0; i<nb_classes; ++i)
{
std::cout<<"Class "<<i<<": ";
for (unsigned int j = 0; j<sampleSize; ++j)
{
std::cout<<estimatedMeans[i*108+j]<<"\t";
}
std::cout<<std::endl;
}
std::cout<<std::endl;
std::cout<<"Learning completed."<<std::endl;
std::cout<<std::endl;
ExtractorType::Pointer m_Extractor;
ClassificationFilterType::Pointer m_Classifier;
FloatVectorImageType::Pointer m_InImage;
};
/*******************************************/
/* Classification */
/*******************************************/
}
}
std::cout<<"-- CLASSIFICATION --"<<std::endl;
std::cout<<std::endl;
ClassificationFilterType::Pointer classifier = ClassificationFilterType::New();
classifier->SetInput(reader->GetOutput());
classifier->SetInputMask(maskReader->GetOutput());
classifier->SetCentroids(estimator->GetParameters());
OTB_APPLICATION_EXPORT(otb::Wrapper::KMeansClassification)
WriterType::Pointer writer = WriterType::New();
writer->SetFileName(outfname);
writer->SetInput(classifier->GetOutput());
writer->SetNumberOfDivisionsStrippedStreaming(numberOfStreamDivisions);
writer->Update();
std::cout<<"Classification completed."<<std::endl;
std::cout<<std::endl;
return EXIT_SUCCESS;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment