Commit 1c42b0f4 authored by Cyrille Valladeau's avatar Cyrille Valladeau

ENH: update the ImageSVMClassifier application to the new framework

parent 87ea3c84
OTB_CREATE_APPLICATION(NAME EstimateImagesStatistics
SOURCES otbEstimateImagesStatistics.cxx
LINK_LIBRARIES OTBIO;OTBCommon;OTBBasicFilters)
OTB_CREATE_APPLICATION(NAME ImageSVMClassifier
SOURCES otbImageSVMClassifier.cxx
LINK_LIBRARIES OTBIO;OTBCommon;OTBBasicFilters)
......@@ -15,160 +15,164 @@
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbImageSVMClassifier.h"
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"
#include <iostream>
#include "otbCommandLineArgumentParser.h"
// otb basic
#include "otbImage.h"
#include "otbVectorImage.h"
#include "otbImageFileReader.h"
#include "otbStreamingImageFileWriter.h"
#include "itkVariableLengthVector.h"
#include "otbChangeLabelImageFilter.h"
#include "otbStandardWriterWatcher.h"
// itk
#include "itkVariableLengthVector.h"
// Statistic XML Reader
#include "otbStatisticsXMLFileReader.h"
// Shift Scale Vector Image Filter
#include "otbShiftScaleVectorImageFilter.h"
// Classification filter
#include "otbSVMImageClassificationFilter.h"
#include "itkTimeProbe.h"
#include "otbStandardFilterWatcher.h"
#include "otbMultiToMonoChannelExtractROI.h"
#include "otbImageToVectorImageCastFilter.h"
namespace otb
{
int ImageSVMClassifier::Describe(ApplicationDescriptor* descriptor)
namespace Wrapper
{
descriptor->SetName("ImageSVMClassifier");
descriptor->SetDescription("Perform SVM classification based a previous computed svm model to an new input image.");
descriptor->AddOption("InputImage", "A new image to classify",
"in", 1, true, ApplicationDescriptor::InputImage);
descriptor->AddOption("InputImageMask", "A mask associated with the new image to classify",
"inm", 1, false, ApplicationDescriptor::InputImage);
descriptor->AddOption("ImageStatistics", "a XML file containing mean and standard deviation of input images used to train svm model.",
"is", 1, false, ApplicationDescriptor::FileName);
descriptor->AddOption("SVMmodel", "Estimated model previously computed",
"svm", 1, true, ApplicationDescriptor::FileName);
descriptor->AddOption("OutputLabeledImage", "Output labeled image",
"out", 1, true, ApplicationDescriptor::OutputImage);
descriptor->AddOption("AvailableMemory","Set the maximum of available memory for the pipeline execution in mega bytes (optional, 256 by default)","ram", 1, false, otb::ApplicationDescriptor::Integer);
return EXIT_SUCCESS;
}
int ImageSVMClassifier::Execute(otb::ApplicationOptionsResult* parseResult)
class ImageSVMClassifier : public Application
{
public:
/** Standard class typedefs. */
typedef ImageSVMClassifier Self;
typedef Application Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
// Input Image
typedef float PixelType;
typedef unsigned char LabeledPixelType;
typedef otb::VectorImage<PixelType, 2> VectorImageType;
typedef otb::Image<LabeledPixelType, 2> LabeledImageType;
typedef otb::ImageFileReader<VectorImageType> ReaderType;
typedef otb::ImageFileReader<LabeledImageType> LabeledReaderType;
typedef otb::StreamingImageFileWriter<LabeledImageType> WriterType;
/** Standard macro */
itkNewMacro(Self);
typedef otb::PipelineMemoryPrintCalculator MemoryCalculatorType;
itkTypeMacro(ImageSVMClassifier, otb::Application);
/** Filters typedef */
// Statistic XML file Reader
typedef itk::VariableLengthVector<PixelType> MeasurementType;
typedef otb::StatisticsXMLFileReader<MeasurementType> StatisticsReader;
typedef otb::ShiftScaleVectorImageFilter<VectorImageType, VectorImageType> RescalerType;
typedef itk::VariableLengthVector<FloatVectorImageType::InternalPixelType> MeasurementType;
typedef otb::StatisticsXMLFileReader<MeasurementType> StatisticsReader;
typedef otb::ShiftScaleVectorImageFilter<FloatVectorImageType, FloatVectorImageType> RescalerType;
/// Classification typedefs
typedef otb::SVMImageClassificationFilter<VectorImageType, LabeledImageType> ClassificationFilterType;
typedef otb::SVMImageClassificationFilter<FloatVectorImageType, UInt8ImageType> ClassificationFilterType;
typedef ClassificationFilterType::Pointer ClassificationFilterPointerType;
typedef ClassificationFilterType::ModelType ModelType;
typedef ModelType::Pointer ModelPointerType;
// typedef otb::ChangeLabelImageFilter<LabeledImageType, VectorImageType> ChangeLabelFilterType;
// typedef ChangeLabelFilterType::Pointer ChangeLabelFilterPointerType;
//--------------------------
// Load input image
ReaderType::Pointer reader = ReaderType::New();
reader->SetFileName(parseResult->GetParameterString("InputImage"));
reader->UpdateOutputInformation();
// Load svm model
ModelPointerType modelSVM = ModelType::New();
modelSVM->LoadModel(parseResult->GetParameterString("SVMmodel").c_str());
//--------------------------
// Normalize input image (optional)
StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
MeasurementType meanMeasurementVector;
MeasurementType stddevMeasurementVector;
RescalerType::Pointer rescaler = RescalerType::New();
//--------------------------
// Classify
ClassificationFilterPointerType classificationFilter = ClassificationFilterType::New();
classificationFilter->SetModel(modelSVM);
// Normalize input image
if (parseResult->IsOptionPresent("ImageStatistics"))
{
// Load input image statistics
statisticsReader->SetFileName(parseResult->GetParameterString("ImageStatistics"));
meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
std::cout << "mean used: " << meanMeasurementVector << std::endl;
std::cout << "standard deviation used: " << stddevMeasurementVector << std::endl;
std::cout << "Shift and scale of the input image !" << std::endl;
// Rescale vector image
rescaler->SetScale(stddevMeasurementVector);
rescaler->SetShift(meanMeasurementVector);
rescaler->SetInput(reader->GetOutput());
classificationFilter->SetInput(rescaler->GetOutput());
}
else
{
std::cout << "no shift and scale" << std::endl;
classificationFilter->SetInput(reader->GetOutput());
}
LabeledReaderType::Pointer readerMask = LabeledReaderType::New();
//--------------------------
// Set an input mask to exclude some areas (optional)
if (parseResult->IsOptionPresent("InputImageMask"))
{
readerMask->SetFileName(parseResult->GetParameterString("InputImageMask"));
readerMask->UpdateOutputInformation();
classificationFilter->SetInputMask(readerMask->GetOutput());
std::cout << "Set an input image mask!" << std::endl;
}
//ChangeLabelFilterPointerType changeLabelFilter = ChangeLabelFilterType::New();
//changeLabelFilter->SetInput(classificationFilter->GetOutput());
//changeLabelFilter->SetNumberOfComponentsPerPixel(3);
//--------------------------
// Save labeled Image
WriterType::Pointer writer = WriterType::New();
writer->SetInput(classificationFilter->GetOutput());
writer->SetFileName(parseResult->GetParameterString("OutputLabeledImage"));
unsigned int ram = 256;
if (parseResult->IsOptionPresent("AvailableMemory"))
{
ram = parseResult->GetParameterUInt("AvailableMemory");
}
writer->SetAutomaticTiledStreaming(ram);
// Cast filter
// TODO: supress that !!
typedef MultiToMonoChannelExtractROI<FloatVectorImageType::InternalPixelType,
UInt8ImageType::PixelType> ExtractImageFilterType;
typedef ImageToVectorImageCastFilter<UInt8ImageType, FloatVectorImageType> CastImageFilterType;
private:
ImageSVMClassifier()
{
SetName("ImageSVMClassifier");
SetDescription("Perform SVM classification based a previous computed svm model to an new input image.");
}
virtual ~ImageSVMClassifier()
{
}
void DoCreateParameters()
{
AddParameter(ParameterType_InputImage, "in", "Input Image to classify");
SetParameterDescription( "in", "Input Image to classify");
AddParameter(ParameterType_InputImage, "mask", "Input Mask to classify");
SetParameterDescription( "mask", "A mask associated with the new image to classify");
AddParameter(ParameterType_Filename, "imstat", "Image statistics file.");
SetParameterDescription("imstat", "a XML file containing mean and standard deviation of input images used to train svm model.");
MandatoryOff("instat");
AddParameter(ParameterType_Filename, "svmmodel", "SVM Model.");
SetParameterDescription("svmmodel", "An estimated svm model previously computed");
otb::StandardWriterWatcher watcher(writer,"Classification");
writer->Update();
return EXIT_SUCCESS;
}
AddParameter(ParameterType_OutputImage, "out", "Output Image");
SetParameterDescription( "out", "Output labeled image");
}
void DoUpdateParameters()
{
// Nothing to do here : all parameters are independent
}
void DoExecute()
{
otbAppLogDEBUG("Entering DoExecute");
// Load input image
FloatVectorImageType::Pointer inImage = GetParameterImage("in");
inImage->UpdateOutputInformation();
// Load svm model
ModelPointerType modelSVM = ModelType::New();
modelSVM->LoadModel(GetParameterString("svmmodel").c_str());
// Normalize input image (optional)
StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
MeasurementType meanMeasurementVector;
MeasurementType stddevMeasurementVector;
RescalerType::Pointer rescaler = RescalerType::New();
// Classify
ClassificationFilterType::Pointer classificationFilter = ClassificationFilterType::New();
classificationFilter->SetModel(modelSVM);
// Normalize input image if asked
if( HasValue("imstat") )
{
otbAppLogDEBUG("Input image normalization activated.");
// Load input image statistics
statisticsReader->SetFileName(GetParameterString("imstat"));
meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
otbAppLogDEBUG( "mean used: " << meanMeasurementVector );
otbAppLogDEBUG( "standard deviation used: " << stddevMeasurementVector );
// Rescale vector image
rescaler->SetScale(stddevMeasurementVector);
rescaler->SetShift(meanMeasurementVector);
rescaler->SetInput(inImage);
classificationFilter->SetInput(rescaler->GetOutput());
}
else
{
otbAppLogDEBUG("Input image normalization deactivated.");
classificationFilter->SetInput(inImage);
}
if( HasValue("mask") )
{
otbAppLogDEBUG("Use input mask.");
// Load mask image and cast into LabeledImageType
FloatVectorImageType::Pointer inMask = GetParameterImage("mask");
ExtractImageFilterType::Pointer extract = ExtractImageFilterType::New();
extract->SetInput( inMask );
extract->SetChannel(0);
extract->UpdateOutputInformation();
classificationFilter->SetInputMask(extract->GetOutput());
}
CastImageFilterType::Pointer finalCast = CastImageFilterType::New();
finalCast->SetInput( classificationFilter->GetOutput() );
SetParameterOutputImage("out", finalCast->GetOutput());
}
//itk::LightObject::Pointer m_FilterRef;
};
}
}
OTB_APPLICATION_EXPORT(otb::Wrapper::ImageSVMClassifier)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment