Skip to content
Snippets Groups Projects
Commit 78ef69b9 authored by Guillaume Pasero's avatar Guillaume Pasero
Browse files

ENH: adapt SOMClassification to new framework

parent c62b43de
Branches
Tags
No related merge requests found
#include "otbVectorImage.h"
#include "otbImage.h"
#include "otbImageFileReader.h"
#include "otbStreamingImageFileWriter.h"
#include "otbImageFileWriter.h"
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"
#include "otbSOMMap.h"
#include "otbSOM.h"
#include "otbSOMImageClassificationFilter.h"
#include "otbCommandLineArgumentParser.h"
#include "itkEuclideanDistance.h"
#include "itkImageRegionSplitter.h"
#include "otbStreamingTraits.h"
#include "itkImageRegionConstIterator.h"
#include "itkVariableSizeMatrix.h"
#include "itkListSample.h"
#include "itkVariableLengthVector.h"
// #include "itkListSample.h"
#include "itkImageRandomNonRepeatingConstIteratorWithIndex.h"
//#include "itkMersenneTwisterRandomVariateGenerator.h"
int main(int argc, char * argv[])
namespace otb
{
namespace Wrapper
{
// Parse command line parameters
typedef otb::CommandLineArgumentParser ParserType;
ParserType::Pointer parser = ParserType::New();
parser->SetProgramDescription("Unsupervised Self Organizing Map image classification");
parser->AddInputImage();
parser->AddOutputImage();
parser->AddOption("--ValidityMask","Validity mask","-vm", 1, true);
parser->AddOption("--MaxTrainingSetSize","Size of the training set","-ts", 1, true);
parser->AddOption("--TrainingSetProbability","Probability for a sample to be selected in the training set","-tp", 1, true);
parser->AddOption("--StreamingNumberOfLines","Number of lined for each streaming block","-sl", 1, true);
parser->AddOption("--SOMMap","Output SOM map","-sm", 1, true);
parser->AddOption("--SizeX","X size of the SOM map","-sx", 1, true);
parser->AddOption("--SizeY","Y size of the SOM map","-sy", 1, true);
parser->AddOption("--NeighborhoodInitX","X initial neighborhood of the SOM map","-nx", 1, true);
parser->AddOption("--NeighborhoodInitY","Y initial neighborhood of the SOM map","-ny", 1, true);
parser->AddOption("--NumberOfIterations","Number of iterations of the SOM learning","-ni", 1, true);
parser->AddOption("--BetaInit","Initial beta value","-bi", 1, true);
parser->AddOption("--BetaFinal","Final beta value","-be", 1, true);
parser->AddOption("--InitValue","Initial value","-iv", 1, true);
typedef otb::CommandLineArgumentParseResult ParserResultType;
ParserResultType::Pointer parseResult = ParserResultType::New();
try
{
parser->ParseCommandLine(argc, argv, parseResult);
}
catch ( itk::ExceptionObject & err )
{
std::string descriptionException = err.GetDescription();
if (descriptionException.find("ParseCommandLine(): Help Parser") != std::string::npos)
{
return EXIT_SUCCESS;
}
if (descriptionException.find("ParseCommandLine(): Version Parser") != std::string::npos)
{
return EXIT_SUCCESS;
}
return EXIT_FAILURE;
}
// initiating random number generation
srand(time(NULL));
class SOMClassification : public Application
{
public:
/** Standard class typedefs. */
typedef SOMClassification Self;
typedef Application Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
std::string infname = parseResult->GetInputImage();
std::string maskfname = parseResult->GetParameterString("--ValidityMask", 0);
std::string outfname = parseResult->GetOutputImage();
std::string somfname = parseResult->GetParameterString("--SOMMap", 0);
const unsigned int nbsamples = parseResult->GetParameterUInt("--MaxTrainingSetSize");
const double trainingProb =parseResult->GetParameterDouble("--TrainingSetProbability");
const unsigned int nbLinesForStreaming = parseResult->GetParameterUInt("--StreamingNumberOfLines");
const unsigned int sizeX = parseResult->GetParameterUInt("--SizeX");
const unsigned int sizeY = parseResult->GetParameterUInt("--SizeY");
const unsigned int neighInitX = parseResult->GetParameterUInt("--NeighborhoodInitX");
const unsigned int neighInitY= parseResult->GetParameterUInt("--NeighborhoodInitY");
const unsigned int nbIterations=parseResult->GetParameterUInt("--NumberOfIterations");
const double betaInit = parseResult->GetParameterDouble("--BetaInit");
const double betaEnd= parseResult->GetParameterDouble("--BetaFinal");
const float initValue = parseResult->GetParameterFloat("--InitValue");
/** Standard macro */
itkNewMacro(Self);
typedef float PixelType;
typedef unsigned short LabeledPixelType;
itkTypeMacro(SOMClassification, otb::Application);
typedef otb::VectorImage<PixelType, 2> ImageType;
typedef otb::Image<LabeledPixelType, 2> LabeledImageType;
typedef otb::ImageFileReader<ImageType> ImageReaderType;
typedef otb::ImageFileReader<LabeledImageType> LabeledImageReaderType;
typedef otb::StreamingImageFileWriter<LabeledImageType> WriterType;
/** Filters typedef */
typedef UInt16ImageType LabeledImageType;
typedef itk::VariableLengthVector<double> SampleType;
typedef itk::Statistics::EuclideanDistance<SampleType> DistanceType;
typedef otb::SOMMap<SampleType, DistanceType, 2> SOMMapType;
typedef itk::Statistics::ListSample<SampleType> ListSampleType;
typedef otb::SOM<ListSampleType, SOMMapType> EstimatorType;
typedef otb::ImageFileWriter<ImageType> SOMMapWriterType;
typedef otb::StreamingTraits<ImageType> StreamingTraitsType;
typedef otb::StreamingTraits<FloatVectorImageType> StreamingTraitsType;
typedef itk::ImageRegionSplitter<2> SplitterType;
typedef ImageType::RegionType RegionType;
typedef FloatVectorImageType::RegionType RegionType;
typedef itk::ImageRegionConstIterator<ImageType> IteratorType;
typedef itk::ImageRegionConstIterator<FloatVectorImageType> IteratorType;
typedef itk::ImageRegionConstIterator<LabeledImageType> LabeledIteratorType;
typedef itk::ImageRegionConstIterator<SOMMapType> SOMIteratorType;
typedef otb::SOMImageClassificationFilter<ImageType, LabeledImageType, SOMMapType> ClassificationFilterType;
ImageReaderType::Pointer reader = ImageReaderType::New();
LabeledImageReaderType::Pointer maskReader = LabeledImageReaderType::New();
reader->SetFileName(infname);
maskReader->SetFileName(maskfname);
/*******************************************/
/* Sampling data */
/*******************************************/
std::cout<<std::endl;
std::cout<<"-- SAMPLING DATA --"<<std::endl;
std::cout<<std::endl;
typedef otb::SOMImageClassificationFilter
<FloatVectorImageType, LabeledImageType, SOMMapType> ClassificationFilterType;
// Update input images information
reader->GenerateOutputInformation();
maskReader->GenerateOutputInformation();
if (reader->GetOutput()->GetLargestPossibleRegion()
!= maskReader->GetOutput()->GetLargestPossibleRegion()
)
private:
SOMClassification()
{
std::cerr<<"Mask image and input image have different sizes."<<std::endl;
return EXIT_FAILURE;
SetName("SOMClassification");
SetDescription("SOM image classification.");
// Documentation
SetDocName("SOM Classification Application");
SetDocLongDescription("Unsupervised Self Organizing Map image classification.");
SetDocLimitations("None");
SetDocAuthors("OTB-Team");
SetDocSeeAlso(" ");
AddDocTag(Tags::Segmentation);
AddDocTag(Tags::Learning);
}
RegionType largestRegion = reader->GetOutput()->GetLargestPossibleRegion();
// Setting up local streaming capabilities
SplitterType::Pointer splitter = SplitterType::New();
unsigned int numberOfStreamDivisions = StreamingTraitsType::CalculateNumberOfStreamDivisions(reader->GetOutput(),
largestRegion,
splitter,
otb::SET_BUFFER_NUMBER_OF_LINES,
0, 0, nbLinesForStreaming);
std::cout<<"The images will be streamed into "<<numberOfStreamDivisions<<" parts."<<std::endl;
// Training sample lists
ListSampleType::Pointer sampleList = ListSampleType::New();
// Sample dimension and max dimension
unsigned int sampleSize = reader->GetOutput()->GetNumberOfComponentsPerPixel();
unsigned int totalSamples = 0;
std::cout<<"The following sample size will be used: "<<sampleSize<<std::endl;
std::cout<<std::endl;
// local streaming variables
unsigned int piece = 0;
RegionType streamingRegion;
while (totalSamples<nbsamples)
virtual ~SOMClassification()
{
piece = static_cast<unsigned int>(static_cast<double>(numberOfStreamDivisions)*rand()/(RAND_MAX));
streamingRegion = splitter->GetSplit(piece, numberOfStreamDivisions, largestRegion);
std::cout<<"Processing region: "<<streamingRegion<<std::endl;
reader->GetOutput()->SetRequestedRegion(streamingRegion);
reader->GetOutput()->PropagateRequestedRegion();
reader->GetOutput()->UpdateOutputData();
maskReader->GetOutput()->SetRequestedRegion(streamingRegion);
maskReader->GetOutput()->PropagateRequestedRegion();
maskReader->GetOutput()->UpdateOutputData();
IteratorType it(reader->GetOutput(), streamingRegion);
LabeledIteratorType maskIt(maskReader->GetOutput(), streamingRegion);
it.GoToBegin();
maskIt.GoToBegin();
unsigned int localNbSamples=0;
}
// Loop on the image
while (!it.IsAtEnd()&&!maskIt.IsAtEnd()&&(totalSamples<nbsamples))
void DoCreateParameters()
{
AddParameter(ParameterType_InputImage, "in", "InputImage");
SetParameterDescription("in", "Input image.");
AddParameter(ParameterType_OutputImage, "out", "OutputImage");
SetParameterDescription("out", "Output classified image.");
AddParameter(ParameterType_InputImage, "vm", "ValidityMask");
SetParameterDescription("vm", "Validity mask");
AddParameter(ParameterType_Float, "tp", "TrainingProbability");
SetParameterDescription("tp", "Probability for a sample to be selected in the training set");
AddParameter(ParameterType_Int, "ts", "TrainingSet");
SetParameterDescription("ts", "Maximum training set size");
AddParameter(ParameterType_Int, "sl", "StreamingLines");
SetParameterDescription("sl", "Number of lines in each streaming block (used during data sampling)");
AddParameter(ParameterType_OutputImage, "som", "SOM Map");
SetParameterDescription("som","Self-Organizing Map map");
AddParameter(ParameterType_Int, "sx", "SizeX");
SetParameterDescription("sx", "X size of the SOM map");
AddParameter(ParameterType_Int, "sy", "SizeY");
SetParameterDescription("sy", "Y size of the SOM map");
AddParameter(ParameterType_Int, "nx", "NeighborhoodX");
SetParameterDescription("nx", "X initial neighborhood of the SOM map");
AddParameter(ParameterType_Int, "ny", "NeighborhoodY");
SetParameterDescription("ny", "Y initial neighborhood of the SOM map");
AddParameter(ParameterType_Int, "ni", "NumberIteration");
SetParameterDescription("ni", "Number of iterations of the SOM learning");
AddParameter(ParameterType_Float, "bi", "BetaInit");
SetParameterDescription("bi", "Initial beta value");
AddParameter(ParameterType_Float, "bf", "BetaFinal");
SetParameterDescription("bf", "Final beta value");
AddParameter(ParameterType_Float, "iv", "InitialValue");
SetParameterDescription("iv", "Initial value (max weight)");
AddParameter(ParameterType_RAM, "ram", "Available RAM");
SetDefaultParameterInt("ram", 256);
MandatoryOff("ram");
// TODO : replace StreamingLines by RAM param ?
// Default parameters
SetDefaultParameterInt("sx", 32);
SetDefaultParameterInt("sy", 32);
SetDefaultParameterInt("nx", 10);
SetDefaultParameterInt("ny", 10);
SetDefaultParameterInt("ni", 5);
SetDefaultParameterFloat("bi",1.0);
SetDefaultParameterFloat("bf",0.1);
SetDefaultParameterFloat("iv",0.0);
// Doc example parameter settings
SetDocExampleParameterValue("in", "poupees_sub.png");
SetDocExampleParameterValue("out","poupees_classif.tif");
SetDocExampleParameterValue("vm", "BASELINE/leSOMPoupeesClassified.hdr");
SetDocExampleParameterValue("tp", "1.0");
SetDocExampleParameterValue("ts","16384");
SetDocExampleParameterValue("sl", "32");
SetDocExampleParameterValue("som", "poupees_map.hdr");
SetDocExampleParameterValue("sx", "32");
SetDocExampleParameterValue("sy", "32");
SetDocExampleParameterValue("nx", "10");
SetDocExampleParameterValue("ny", "10");
SetDocExampleParameterValue("ni", "5");
SetDocExampleParameterValue("bi", "1.0");
SetDocExampleParameterValue("bf", "0.1");
SetDocExampleParameterValue("iv", "0");
}
void DoUpdateParameters()
{
// Nothing to do
}
void DoExecute()
{
// If the current pixel is labeled
if (maskIt.Get()>0)
// initiating random number generation
itk::Statistics::MersenneTwisterRandomVariateGenerator::Pointer
randomGen = itk::Statistics::MersenneTwisterRandomVariateGenerator::New();
FloatVectorImageType::Pointer input = GetParameterImage("in");
LabeledImageType::Pointer mask = GetParameterImage<LabeledImageType>("vm");
/*******************************************/
/* Sampling data */
/*******************************************/
otbAppLogINFO("-- SAMPLING DATA --");
if (input->GetLargestPossibleRegion()
!= mask->GetLargestPossibleRegion())
{
otbAppLogFATAL("Mask image and input image have different sizes.");
}
RegionType largestRegion = input->GetLargestPossibleRegion();
// Setting up local streaming capabilities
SplitterType::Pointer splitter = SplitterType::New();
unsigned int numberOfStreamDivisions = StreamingTraitsType::CalculateNumberOfStreamDivisions(input,
largestRegion,
splitter,
otb::SET_BUFFER_NUMBER_OF_LINES,
0, 0, GetParameterInt("sl"));
otbAppLogINFO("The images will be streamed into "<<numberOfStreamDivisions<<" parts.");
// Training sample lists
ListSampleType::Pointer sampleList = ListSampleType::New();
const double trainingProb = static_cast<double>(GetParameterFloat("tp"));
const unsigned int nbsamples = GetParameterInt("ts");
// Sample dimension and max dimension
unsigned int sampleSize = input->GetNumberOfComponentsPerPixel();
unsigned int totalSamples = 0;
otbAppLogINFO("The following sample size will be used: "<<sampleSize);
// local streaming variables
unsigned int piece = 0;
RegionType streamingRegion;
// create a random permutation to explore
itk::RandomPermutation randPerm(numberOfStreamDivisions);
unsigned int index = 0;
// TODO : maybe change the approach: at the moment, the sampling process is able to pick a sample twice or more
while (totalSamples < nbsamples)
{
piece = randPerm[index];
streamingRegion = splitter->GetSplit(piece, numberOfStreamDivisions, largestRegion);
//otbAppLogINFO("Processing region: "<<streamingRegion);
input->SetRequestedRegion(streamingRegion);
input->PropagateRequestedRegion();
input->UpdateOutputData();
mask->SetRequestedRegion(streamingRegion);
mask->PropagateRequestedRegion();
mask->UpdateOutputData();
IteratorType it(input, streamingRegion);
LabeledIteratorType maskIt(mask, streamingRegion);
it.GoToBegin();
maskIt.GoToBegin();
unsigned int localNbSamples=0;
// Loop on the image
while ( !it.IsAtEnd()
&& !maskIt.IsAtEnd()
&& (totalSamples<nbsamples))
{
if ((rand()<trainingProb*RAND_MAX))
// If the current pixel is labeled
if (maskIt.Get()>0)
{
SampleType newSample;
newSample.SetSize(sampleSize);
// build the sample
newSample.Fill(0);
for (unsigned int i = 0; i<sampleSize; ++i)
if (randomGen->GetVariateWithClosedRange() < trainingProb)
{
newSample[i]=it.Get()[i];
SampleType newSample;
newSample.SetSize(sampleSize);
// build the sample
newSample.Fill(0);
for (unsigned int i = 0; i<sampleSize; ++i)
{
newSample[i]=it.Get()[i];
}
// Update the sample lists
sampleList->PushBack(newSample);
++totalSamples;
++localNbSamples;
}
// Update the the sample lists
sampleList->PushBack(newSample);
++totalSamples;
++localNbSamples;
}
++it;
++maskIt;
}
++it;
++maskIt;
index++;
// we could break out of the while loop here, once the entire image has been streamed once
if (index == numberOfStreamDivisions) index = 0;
//otbAppLogINFO(""<<localNbSamples<<" samples added to the training set.");
}
std::cout<<localNbSamples<<" samples added to the training set."<<std::endl;
std::cout<<std::endl;
}
std::cout<<"The final training set contains "<<totalSamples<<" samples."<<std::endl;
std::cout<<std::endl;
std::cout<<"Data sampling completed."<<std::endl;
std::cout<<std::endl;
/*******************************************/
/* Learning */
/*******************************************/
std::cout<<"-- LEARNING --"<<std::endl;
EstimatorType::Pointer estimator = EstimatorType::New();
estimator->SetListSample(sampleList);
EstimatorType::SizeType size;
size[0]=sizeX;
size[1]=sizeY;
estimator->SetMapSize(size);
EstimatorType::SizeType radius;
radius[0] = neighInitX;
radius[1] = neighInitY;
estimator->SetNeighborhoodSizeInit(radius);
estimator->SetNumberOfIterations(nbIterations);
estimator->SetBetaInit(betaInit);
estimator->SetBetaEnd(betaEnd);
estimator->SetMaxWeight(initValue);
// estimator->SetRandomInit(true);
// estimator->SetSeed(time(NULL));
estimator->Update();
ImageType::Pointer vectormap = ImageType::New();
vectormap->SetRegions(estimator->GetOutput()->GetLargestPossibleRegion());
vectormap->SetNumberOfComponentsPerPixel(108);
vectormap->Allocate();
ImageType::PixelType black;
black.SetSize(108);
black.Fill(0);
vectormap->FillBuffer(black);
SOMIteratorType somIt(estimator->GetOutput(), estimator->GetOutput()->GetLargestPossibleRegion());
IteratorType vectorIt(vectormap, estimator->GetOutput()->GetLargestPossibleRegion());
somIt.GoToBegin();
vectorIt.GoToBegin();
while (!somIt.IsAtEnd() && !vectorIt.IsAtEnd())
{
for (unsigned int i = 0; i<somIt.Get().GetSize(); ++i)
{
vectorIt.Get()[i]=somIt.Get()[i];
otbAppLogINFO("The final training set contains "<<totalSamples<<" samples.");
/*******************************************/
/* Learning */
/*******************************************/
otbAppLogINFO("-- LEARNING --");
EstimatorType::Pointer estimator = EstimatorType::New();
estimator->SetListSample(sampleList);
EstimatorType::SizeType size;
size[0]=GetParameterInt("sx");
size[1]=GetParameterInt("sy");
estimator->SetMapSize(size);
EstimatorType::SizeType radius;
radius[0] = GetParameterInt("nx");
radius[1] = GetParameterInt("ny");
estimator->SetNeighborhoodSizeInit(radius);
estimator->SetNumberOfIterations(GetParameterInt("ni"));
estimator->SetBetaInit(GetParameterFloat("bi"));
estimator->SetBetaEnd(GetParameterFloat("bf"));
estimator->SetMaxWeight(GetParameterFloat("iv"));
AddProcess(estimator,"Learning");
estimator->Update();
m_SOMMap = estimator->GetOutput();
SetParameterOutputImage<DoubleVectorImageType>("som",m_SOMMap);
/*******************************************/
/* Classification */
/*******************************************/
otbAppLogINFO("-- CLASSIFICATION --");
m_Classifier = ClassificationFilterType::New();
m_Classifier->SetInput(input);
m_Classifier->SetInputMask(mask);
m_Classifier->SetMap(m_SOMMap);
AddProcess(m_Classifier,"Classification");
SetParameterOutputImage<LabeledImageType>("out",m_Classifier->GetOutput());
}
++somIt;
++vectorIt;
}
SOMMapWriterType::Pointer somWriter = SOMMapWriterType::New();
somWriter->SetFileName(somfname);
somWriter->SetInput(vectormap);
somWriter->Update();
std::cout<<std::endl;
std::cout<<"Learning completed."<<std::endl;
std::cout<<std::endl;
/*******************************************/
/* Classification */
/*******************************************/
SOMMapType::Pointer m_SOMMap;
ClassificationFilterType::Pointer m_Classifier;
};
std::cout<<"-- CLASSIFICATION --"<<std::endl;
std::cout<<std::endl;
ClassificationFilterType::Pointer classifier = ClassificationFilterType::New();
classifier->SetInput(reader->GetOutput());
classifier->SetInputMask(maskReader->GetOutput());
classifier->SetMap(estimator->GetOutput());
}
}
WriterType::Pointer writer = WriterType::New();
writer->SetFileName(outfname);
writer->SetInput(classifier->GetOutput());
writer->SetNumberOfDivisionsStrippedStreaming(numberOfStreamDivisions);
writer->Update();
OTB_APPLICATION_EXPORT(otb::Wrapper::SOMClassification)
std::cout<<"Classification completed."<<std::endl;
std::cout<<std::endl;
int main(int argc, char * argv[])
{
return EXIT_SUCCESS;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment