diff --git a/Applications/Classification/CMakeLists.txt b/Applications/Classification/CMakeLists.txt index 7102c5e9ad26104cf5f31a16ca4dc487ac2569a4..b03faab27a640d333e647f2539a8cd50ea817164 100644 --- a/Applications/Classification/CMakeLists.txt +++ b/Applications/Classification/CMakeLists.txt @@ -1,3 +1,7 @@ OTB_CREATE_APPLICATION(NAME EstimateImagesStatistics SOURCES otbEstimateImagesStatistics.cxx LINK_LIBRARIES OTBIO;OTBCommon;OTBBasicFilters) + +OTB_CREATE_APPLICATION(NAME ImageSVMClassifier + SOURCES otbImageSVMClassifier.cxx + LINK_LIBRARIES OTBIO;OTBCommon;OTBBasicFilters) diff --git a/Applications/Classification/otbImageSVMClassifier.cxx b/Applications/Classification/otbImageSVMClassifier.cxx index 2b5328ef5381da1b85d9160f65a37a8aac91de66..a4a7da2775f142f87dd729aa4b842b06ff8f1a39 100644 --- a/Applications/Classification/otbImageSVMClassifier.cxx +++ b/Applications/Classification/otbImageSVMClassifier.cxx @@ -15,160 +15,164 @@ PURPOSE. See the above copyright notices for more information. =========================================================================*/ -#include "otbImageSVMClassifier.h" +#include "otbWrapperApplication.h" +#include "otbWrapperApplicationFactory.h" -#include <iostream> -#include "otbCommandLineArgumentParser.h" - -// otb basic -#include "otbImage.h" -#include "otbVectorImage.h" -#include "otbImageFileReader.h" -#include "otbStreamingImageFileWriter.h" +#include "itkVariableLengthVector.h" #include "otbChangeLabelImageFilter.h" #include "otbStandardWriterWatcher.h" - -// itk -#include "itkVariableLengthVector.h" - - -// Statistic XML Reader #include "otbStatisticsXMLFileReader.h" - -// Shift Scale Vector Image Filter #include "otbShiftScaleVectorImageFilter.h" - -// Classification filter #include "otbSVMImageClassificationFilter.h" - -#include "itkTimeProbe.h" -#include "otbStandardFilterWatcher.h" +#include "otbMultiToMonoChannelExtractROI.h" +#include "otbImageToVectorImageCastFilter.h" namespace otb { - -int ImageSVMClassifier::Describe(ApplicationDescriptor* descriptor) +namespace Wrapper { - descriptor->SetName("ImageSVMClassifier"); - descriptor->SetDescription("Perform SVM classification based a previous computed svm model to an new input image."); - descriptor->AddOption("InputImage", "A new image to classify", - "in", 1, true, ApplicationDescriptor::InputImage); - descriptor->AddOption("InputImageMask", "A mask associated with the new image to classify", - "inm", 1, false, ApplicationDescriptor::InputImage); - descriptor->AddOption("ImageStatistics", "a XML file containing mean and standard deviation of input images used to train svm model.", - "is", 1, false, ApplicationDescriptor::FileName); - descriptor->AddOption("SVMmodel", "Estimated model previously computed", - "svm", 1, true, ApplicationDescriptor::FileName); - descriptor->AddOption("OutputLabeledImage", "Output labeled image", - "out", 1, true, ApplicationDescriptor::OutputImage); - descriptor->AddOption("AvailableMemory","Set the maximum of available memory for the pipeline execution in mega bytes (optional, 256 by default)","ram", 1, false, otb::ApplicationDescriptor::Integer); - return EXIT_SUCCESS; -} -int ImageSVMClassifier::Execute(otb::ApplicationOptionsResult* parseResult) +class ImageSVMClassifier : public Application { +public: + /** Standard class typedefs. */ + typedef ImageSVMClassifier Self; + typedef Application Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; - // Input Image - typedef float PixelType; - typedef unsigned char LabeledPixelType; - typedef otb::VectorImage<PixelType, 2> VectorImageType; - typedef otb::Image<LabeledPixelType, 2> LabeledImageType; - - typedef otb::ImageFileReader<VectorImageType> ReaderType; - typedef otb::ImageFileReader<LabeledImageType> LabeledReaderType; - typedef otb::StreamingImageFileWriter<LabeledImageType> WriterType; + /** Standard macro */ + itkNewMacro(Self); - typedef otb::PipelineMemoryPrintCalculator MemoryCalculatorType; + itkTypeMacro(ImageSVMClassifier, otb::Application); + /** Filters typedef */ // Statistic XML file Reader - typedef itk::VariableLengthVector<PixelType> MeasurementType; - typedef otb::StatisticsXMLFileReader<MeasurementType> StatisticsReader; - typedef otb::ShiftScaleVectorImageFilter<VectorImageType, VectorImageType> RescalerType; + typedef itk::VariableLengthVector<FloatVectorImageType::InternalPixelType> MeasurementType; + typedef otb::StatisticsXMLFileReader<MeasurementType> StatisticsReader; + typedef otb::ShiftScaleVectorImageFilter<FloatVectorImageType, FloatVectorImageType> RescalerType; /// Classification typedefs - typedef otb::SVMImageClassificationFilter<VectorImageType, LabeledImageType> ClassificationFilterType; + typedef otb::SVMImageClassificationFilter<FloatVectorImageType, UInt8ImageType> ClassificationFilterType; typedef ClassificationFilterType::Pointer ClassificationFilterPointerType; typedef ClassificationFilterType::ModelType ModelType; typedef ModelType::Pointer ModelPointerType; -// typedef otb::ChangeLabelImageFilter<LabeledImageType, VectorImageType> ChangeLabelFilterType; -// typedef ChangeLabelFilterType::Pointer ChangeLabelFilterPointerType; - - - //-------------------------- - // Load input image - ReaderType::Pointer reader = ReaderType::New(); - reader->SetFileName(parseResult->GetParameterString("InputImage")); - reader->UpdateOutputInformation(); - - // Load svm model - ModelPointerType modelSVM = ModelType::New(); - modelSVM->LoadModel(parseResult->GetParameterString("SVMmodel").c_str()); - - //-------------------------- - // Normalize input image (optional) - StatisticsReader::Pointer statisticsReader = StatisticsReader::New(); - MeasurementType meanMeasurementVector; - MeasurementType stddevMeasurementVector; - RescalerType::Pointer rescaler = RescalerType::New(); - - //-------------------------- - // Classify - ClassificationFilterPointerType classificationFilter = ClassificationFilterType::New(); - classificationFilter->SetModel(modelSVM); - - // Normalize input image - if (parseResult->IsOptionPresent("ImageStatistics")) - { - // Load input image statistics - statisticsReader->SetFileName(parseResult->GetParameterString("ImageStatistics")); - meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean"); - stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev"); - std::cout << "mean used: " << meanMeasurementVector << std::endl; - std::cout << "standard deviation used: " << stddevMeasurementVector << std::endl; - std::cout << "Shift and scale of the input image !" << std::endl; - // Rescale vector image - rescaler->SetScale(stddevMeasurementVector); - rescaler->SetShift(meanMeasurementVector); - rescaler->SetInput(reader->GetOutput()); - - classificationFilter->SetInput(rescaler->GetOutput()); - } - else - { - std::cout << "no shift and scale" << std::endl; - classificationFilter->SetInput(reader->GetOutput()); - } - - LabeledReaderType::Pointer readerMask = LabeledReaderType::New(); - //-------------------------- - // Set an input mask to exclude some areas (optional) - if (parseResult->IsOptionPresent("InputImageMask")) - { - readerMask->SetFileName(parseResult->GetParameterString("InputImageMask")); - readerMask->UpdateOutputInformation(); - classificationFilter->SetInputMask(readerMask->GetOutput()); - std::cout << "Set an input image mask!" << std::endl; - } - - //ChangeLabelFilterPointerType changeLabelFilter = ChangeLabelFilterType::New(); - //changeLabelFilter->SetInput(classificationFilter->GetOutput()); - //changeLabelFilter->SetNumberOfComponentsPerPixel(3); - - //-------------------------- - // Save labeled Image - WriterType::Pointer writer = WriterType::New(); - writer->SetInput(classificationFilter->GetOutput()); - writer->SetFileName(parseResult->GetParameterString("OutputLabeledImage")); - unsigned int ram = 256; - if (parseResult->IsOptionPresent("AvailableMemory")) - { - ram = parseResult->GetParameterUInt("AvailableMemory"); - } - writer->SetAutomaticTiledStreaming(ram); + + // Cast filter + // TODO: supress that !! + typedef MultiToMonoChannelExtractROI<FloatVectorImageType::InternalPixelType, + UInt8ImageType::PixelType> ExtractImageFilterType; + typedef ImageToVectorImageCastFilter<UInt8ImageType, FloatVectorImageType> CastImageFilterType; + +private: + ImageSVMClassifier() + { + SetName("ImageSVMClassifier"); + SetDescription("Perform SVM classification based a previous computed svm model to an new input image."); + } + + virtual ~ImageSVMClassifier() + { + } + + void DoCreateParameters() + { + AddParameter(ParameterType_InputImage, "in", "Input Image to classify"); + SetParameterDescription( "in", "Input Image to classify"); + + AddParameter(ParameterType_InputImage, "mask", "Input Mask to classify"); + SetParameterDescription( "mask", "A mask associated with the new image to classify"); + + AddParameter(ParameterType_Filename, "imstat", "Image statistics file."); + SetParameterDescription("imstat", "a XML file containing mean and standard deviation of input images used to train svm model."); + MandatoryOff("instat"); + + AddParameter(ParameterType_Filename, "svmmodel", "SVM Model."); + SetParameterDescription("svmmodel", "An estimated svm model previously computed"); - otb::StandardWriterWatcher watcher(writer,"Classification"); - writer->Update(); - return EXIT_SUCCESS; -} + AddParameter(ParameterType_OutputImage, "out", "Output Image"); + SetParameterDescription( "out", "Output labeled image"); + + } + + void DoUpdateParameters() + { + // Nothing to do here : all parameters are independent + } + void DoExecute() + { + otbAppLogDEBUG("Entering DoExecute"); + + // Load input image + FloatVectorImageType::Pointer inImage = GetParameterImage("in"); + inImage->UpdateOutputInformation(); + + // Load svm model + ModelPointerType modelSVM = ModelType::New(); + modelSVM->LoadModel(GetParameterString("svmmodel").c_str()); + + + // Normalize input image (optional) + StatisticsReader::Pointer statisticsReader = StatisticsReader::New(); + MeasurementType meanMeasurementVector; + MeasurementType stddevMeasurementVector; + RescalerType::Pointer rescaler = RescalerType::New(); + + // Classify + ClassificationFilterType::Pointer classificationFilter = ClassificationFilterType::New(); + classificationFilter->SetModel(modelSVM); + + + // Normalize input image if asked + if( HasValue("imstat") ) + { + otbAppLogDEBUG("Input image normalization activated."); + // Load input image statistics + statisticsReader->SetFileName(GetParameterString("imstat")); + meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean"); + stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev"); + otbAppLogDEBUG( "mean used: " << meanMeasurementVector ); + otbAppLogDEBUG( "standard deviation used: " << stddevMeasurementVector ); + // Rescale vector image + rescaler->SetScale(stddevMeasurementVector); + rescaler->SetShift(meanMeasurementVector); + rescaler->SetInput(inImage); + + classificationFilter->SetInput(rescaler->GetOutput()); + } + else + { + otbAppLogDEBUG("Input image normalization deactivated."); + classificationFilter->SetInput(inImage); + } + + if( HasValue("mask") ) + { + otbAppLogDEBUG("Use input mask."); + // Load mask image and cast into LabeledImageType + FloatVectorImageType::Pointer inMask = GetParameterImage("mask"); + ExtractImageFilterType::Pointer extract = ExtractImageFilterType::New(); + extract->SetInput( inMask ); + extract->SetChannel(0); + extract->UpdateOutputInformation(); + + classificationFilter->SetInputMask(extract->GetOutput()); + } + + + CastImageFilterType::Pointer finalCast = CastImageFilterType::New(); + finalCast->SetInput( classificationFilter->GetOutput() ); + + SetParameterOutputImage("out", finalCast->GetOutput()); + } + + //itk::LightObject::Pointer m_FilterRef; +}; + + + +} } + +OTB_APPLICATION_EXPORT(otb::Wrapper::ImageSVMClassifier)