Skip to content
Snippets Groups Projects
Commit 70d4ebe2 authored by Jonathan Guinet's avatar Jonathan Guinet
Browse files

MRG

parents e50bc0fc 7304f4b5
Branches
Tags
No related merge requests found
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"
#include "otbVectorImage.h"
#include "otbImage.h"
#include "itkEuclideanDistance.h"
#include "itkImageRegionSplitter.h"
#include "otbStreamingTraits.h"
#include "otbKMeansImageClassificationFilter.h"
#include "itkImageRegionConstIterator.h"
#include "itkListSample.h"
#include "itkWeightedCentroidKdTreeGenerator.h"
#include "itkKdTreeBasedKmeansEstimator.h"
#include "itkMersenneTwisterRandomVariateGenerator.h"
#include "itkCastImageFilter.h"
#include "otbMultiToMonoChannelExtractROI.h"
namespace otb
{
namespace Wrapper
{
typedef otb::Image<FloatVectorImageType::InternalPixelType, 2> ImageReaderType;
typedef UInt8ImageType LabeledImageType;
typedef ImageReaderType::PixelType PixelType;
typedef itk::FixedArray<PixelType, 108> SampleType;
typedef itk::Statistics::ListSample<SampleType> ListSampleType;
typedef itk::Statistics::WeightedCentroidKdTreeGenerator<ListSampleType> TreeGeneratorType;
typedef TreeGeneratorType::KdTreeType TreeType;
typedef itk::Statistics::KdTreeBasedKmeansEstimator<TreeType> EstimatorType;
typedef itk::CastImageFilter<FloatImageListType, FloatImageType> CastMaskFilterType;
typedef otb::MultiToMonoChannelExtractROI<FloatVectorImageType::InternalPixelType,LabeledImageType::InternalPixelType > ExtractorType;
typedef otb::StreamingTraits<FloatVectorImageType> StreamingTraitsType;
typedef itk::ImageRegionSplitter<2> SplitterType;
typedef ImageReaderType::RegionType RegionType;
typedef itk::ImageRegionConstIterator<FloatVectorImageType> IteratorType;
typedef itk::ImageRegionConstIterator<LabeledImageType> LabeledIteratorType;
typedef otb::KMeansImageClassificationFilter<FloatVectorImageType, LabeledImageType, 108> ClassificationFilterType;
class KMeansClassification: public Application
{
public:
/** Standard class typedefs. */
typedef KMeansClassification Self;
typedef Application Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Standard macro */
itkNewMacro(Self);
itkTypeMacro(KMeansClassification, otb::Application);
private:
KMeansClassification()
{
SetName("KMeansClassification");
SetDescription("Unsupervised KMeans image classification.");
}
virtual ~KMeansClassification()
{
}
void DoCreateParameters()
{
AddParameter(ParameterType_InputImage, "in", "Input Image");
AddParameter(ParameterType_OutputImage, "out", "Output Image");
AddParameter(ParameterType_InputImage, "vm", "Validity Mask");
AddParameter(ParameterType_Int, "ts", "Size of the training set");
SetParameterInt("ts", 100);
AddParameter(ParameterType_Float, "tp", "Probability for a sample to be selected in the training set");
SetParameterFloat("tp", 0.5);
AddParameter(ParameterType_Int, "nc", "Number of classes");
SetParameterInt("nc", 3);
AddParameter(ParameterType_Float, "cp", "Probability for a pixel to be selected as an initial class centroid");
SetParameterFloat("cp", 0.8);
AddParameter(ParameterType_Int, "sl", "Number of lines for each streaming block");
SetParameterInt("sl", 1000);
}
void DoUpdateParameters()
{
// Nothing to do here : all parameters are independent
}
void DoExecute()
{
GetLogger()->Debug("Entering DoExecute\n");
// initiating random number generation
itk::Statistics::MersenneTwisterRandomVariateGenerator::Pointer
randomGen = itk::Statistics::MersenneTwisterRandomVariateGenerator::New();
m_InImage = GetParameterImage("in");
m_Extractor = ExtractorType::New();
m_Extractor->SetInput(GetParameterImage("vm"));
m_Extractor->SetChannel(1);
m_Extractor->UpdateOutputInformation();
LabeledImageType::Pointer maskImage = m_Extractor->GetOutput();
std::ostringstream message("");
const unsigned int nbsamples = GetParameterInt("ts");
const double trainingProb = GetParameterFloat("tp");
const double initProb = GetParameterFloat("cp");
const unsigned int nb_classes = GetParameterInt("nc");
const unsigned int nbLinesForStreaming = GetParameterInt("sl");
/*******************************************/
/* Sampling data */
/*******************************************/
GetLogger()->Info("-- SAMPLING DATA --");
// Update input images information
m_InImage->UpdateOutputInformation();
maskImage->UpdateOutputInformation();
if (m_InImage->GetLargestPossibleRegion() != maskImage->GetLargestPossibleRegion())
{
GetLogger()->Error("Mask image and input image have different sizes.");
}
RegionType largestRegion = m_InImage->GetLargestPossibleRegion();
// Setting up local streaming capabilities
SplitterType::Pointer splitter = SplitterType::New();
unsigned int
numberOfStreamDivisions = StreamingTraitsType::CalculateNumberOfStreamDivisions(
m_InImage,
largestRegion,
splitter,
otb::SET_BUFFER_NUMBER_OF_LINES,
0, 0, nbLinesForStreaming);
message.clear();
message << "The images will be streamed into " << numberOfStreamDivisions << " parts.";
GetLogger()->Info(message.str());
// Training sample lists
ListSampleType::Pointer sampleList = ListSampleType::New();
EstimatorType::ParametersType initialMeans(108 * nb_classes);
initialMeans.Fill(0);
unsigned int init_means_index = 0;
// Sample dimension and max dimension
unsigned int maxDimension = SampleType::Dimension;
unsigned int sampleSize = std::min(m_InImage->GetNumberOfComponentsPerPixel(), maxDimension);
unsigned int totalSamples = 0;
message.clear();
message << "Sample max possible dimension: " << maxDimension << std::endl;
GetLogger()->Info(message.str());
message.clear();
message << "The following sample size will be used: " << sampleSize << std::endl;
GetLogger()->Info(message.str());
// local streaming variables
unsigned int piece = 0;
RegionType streamingRegion;
while ((totalSamples < nbsamples) && (init_means_index < 108 * nb_classes))
{
double random = randomGen->GetVariateWithClosedRange();
piece = static_cast<unsigned int> (random * numberOfStreamDivisions);
streamingRegion = splitter->GetSplit(piece, numberOfStreamDivisions, largestRegion);
message.clear();
message << "Processing region: " << streamingRegion << std::endl;
GetLogger()->Info(message.str());
m_InImage->SetRequestedRegion(streamingRegion);
m_InImage->PropagateRequestedRegion();
m_InImage->UpdateOutputData();
maskImage->SetRequestedRegion(streamingRegion);
maskImage->PropagateRequestedRegion();
maskImage->UpdateOutputData();
IteratorType it(m_InImage, streamingRegion);
LabeledIteratorType m_MaskIt(maskImage, streamingRegion);
it.GoToBegin();
m_MaskIt.GoToBegin();
unsigned int localNbSamples = 0;
// Loop on the image
while (!it.IsAtEnd() && !m_MaskIt.IsAtEnd() && (totalSamples < nbsamples)
&& (init_means_index < (108 * nb_classes)))
{
// If the current pixel is labeled
if (m_MaskIt.Get() > 0)
{
if ((rand() < trainingProb * RAND_MAX))
{
SampleType newSample;
// build the sample
newSample.Fill(0);
for (unsigned int i = 0; i < sampleSize; ++i)
{
newSample[i] = it.Get()[i];
}
// Update the the sample lists
sampleList->PushBack(newSample);
++totalSamples;
++localNbSamples;
}
else
if ((init_means_index < 108 * nb_classes) && (rand() < initProb * RAND_MAX))
{
for (unsigned int i = 0; i < sampleSize; ++i)
{
initialMeans[init_means_index + i] = it.Get()[i];
}
init_means_index += 108;
}
}
++it;
++m_MaskIt;
}
message.clear();
message << localNbSamples << " samples added to the training set." << std::endl;
GetLogger()->Info(message.str());
}
message.clear();
message << "The final training set contains " << totalSamples << " samples." << std::endl;
GetLogger()->Info(message.str());
message.clear();
message << "Data sampling completed." << std::endl;
GetLogger()->Info(message.str());
/*******************************************/
/* Learning */
/*******************************************/
message.clear();
message << "-- LEARNING --" << std::endl;
message << "Initial centroids are: " << std::endl;
GetLogger()->Info(message.str());
message.clear();
for (unsigned int i = 0; i < nb_classes; ++i)
{
message << "Class " << i << ": ";
for (unsigned int j = 0; j < sampleSize; ++j)
{
message << initialMeans[i * 108 + j] << "\t";
}
message << std::endl;
}
message << std::endl;
message.clear();
message << "Starting optimization." << std::endl;
message << std::endl;
GetLogger()->Info(message.str());
EstimatorType::Pointer estimator = EstimatorType::New();
TreeGeneratorType::Pointer treeGenerator = TreeGeneratorType::New();
treeGenerator->SetSample(sampleList);
treeGenerator->SetBucketSize(100);
treeGenerator->Update();
estimator->SetParameters(initialMeans);
estimator->SetKdTree(treeGenerator->GetOutput());
estimator->SetMaximumIteration(100000000);
estimator->SetCentroidPositionChangesThreshold(0.001);
estimator->StartOptimization();
EstimatorType::ParametersType estimatedMeans = estimator->GetParameters();
message.clear();
message << "Optimization completed." << std::endl;
message << std::endl;
message << "Estimated centroids are: " << std::endl;
for (unsigned int i = 0; i < nb_classes; ++i)
{
message << "Class " << i << ": ";
for (unsigned int j = 0; j < sampleSize; ++j)
{
message << estimatedMeans[i * 108 + j] << "\t";
}
message << std::endl;
}
message << std::endl;
message << "Learning completed." << std::endl;
message << std::endl;
GetLogger()->Info(message.str());
/*******************************************/
/* Classification */
/*******************************************/
message.clear();
message << "-- CLASSIFICATION --" << std::endl;
message << std::endl;
GetLogger()->Info(message.str());
m_Classifier = ClassificationFilterType::New();
m_Classifier->SetInput(m_InImage);
m_Classifier->SetInputMask(maskImage);
m_Classifier->SetCentroids(estimator->GetParameters());
SetParameterOutputImage<LabeledImageType> ("out", m_Classifier->GetOutput());
}
ExtractorType::Pointer m_Extractor;
ClassificationFilterType::Pointer m_Classifier;
FloatVectorImageType::Pointer m_InImage;
};
}
}
OTB_APPLICATION_EXPORT(otb::Wrapper::KMeansClassification)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment