otbImageSVMClassifier.cxx 7.28 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
/*=========================================================================

  Program:   ORFEO Toolbox
  Language:  C++
  Date:      $Date$
  Version:   $Revision$


  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
  See OTBCopyright.txt for details.


     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
18 19
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"
20

21
#include "itkVariableLengthVector.h"
22
#include "otbChangeLabelImageFilter.h"
23
#include "otbStandardWriterWatcher.h"
24 25 26
#include "otbStatisticsXMLFileReader.h"
#include "otbShiftScaleVectorImageFilter.h"
#include "otbSVMImageClassificationFilter.h"
27 28
#include "otbMultiToMonoChannelExtractROI.h"
#include "otbImageToVectorImageCastFilter.h"
29 30 31

namespace otb
{
32
namespace Wrapper
33 34
{

35
class ImageSVMClassifier : public Application
36
{
37 38 39 40 41 42
public:
  /** Standard class typedefs. */
  typedef ImageSVMClassifier            Self;
  typedef Application                   Superclass;
  typedef itk::SmartPointer<Self>       Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;
43

44 45
  /** Standard macro */
  itkNewMacro(Self);
46

47
  itkTypeMacro(ImageSVMClassifier, otb::Application);
48

49 50 51 52
  /** Filters typedef */
  typedef itk::VariableLengthVector<FloatVectorImageType::InternalPixelType>            MeasurementType;
  typedef otb::StatisticsXMLFileReader<MeasurementType>                                 StatisticsReader;
  typedef otb::ShiftScaleVectorImageFilter<FloatVectorImageType, FloatVectorImageType>  RescalerType;
Julien Malik's avatar
STYLE  
Julien Malik committed
53 54 55 56
  typedef otb::SVMImageClassificationFilter<FloatVectorImageType, UInt8ImageType>       ClassificationFilterType;
  typedef ClassificationFilterType::Pointer                                             ClassificationFilterPointerType;
  typedef ClassificationFilterType::ModelType                                           ModelType;
  typedef ModelType::Pointer                                                            ModelPointerType;
57 58 59 60 61

private:
  ImageSVMClassifier()
  {
    SetName("ImageSVMClassifier");
62
    SetDescription("Performs a SVM classification of the input image according to a SVM model file.");
Cyrille Valladeau's avatar
Cyrille Valladeau committed
63 64
    
    // Documentation
65
    SetDocName("Image SVM Classification");
Julien Michel's avatar
Julien Michel committed
66
    SetDocLongDescription("This application performs a SVM image classification based on a SVM model file (*.svm extension) produced by the TrainSVMImagesClassifier application. Pixels of the output image will contain the class label decided by the SVM classifier. The input pixels can be optionnaly centered and reduced according to the statistics file produced by the ComputeImagesStatistics application. An optional input mask can be provided, in which case only input image pixels whose corresponding mask value is greater than 0 will be classified. The remaining of pixels will be given the label 0 in the output image.");
67 68

    SetDocLimitations("The input image must have the same type, order and number of bands than the images used to produce the statistics file and the SVM model file. If a statistics file was used during training by the TrainSVMImagesClassifier, it is mandatory to use the same statistics file for classification. If an input mask is used, its size must match the input image size.");
Cyrille Valladeau's avatar
Cyrille Valladeau committed
69
    SetDocAuthors("OTB-Team");
Julien Michel's avatar
Julien Michel committed
70
    SetDocSeeAlso("TrainSVMImagesClassifier, ValidateSVMImagesClassifier, ComputeImagesStatistics");
71
 
72
    AddDocTag(Tags::Learning);
73 74 75 76 77 78 79 80
  }

  virtual ~ImageSVMClassifier()
  {
  }

  void DoCreateParameters()
  {
81 82
    AddParameter(ParameterType_InputImage, "in",  "Input Image");
    SetParameterDescription( "in", "The input Image to classify.");
83

84 85
    AddParameter(ParameterType_InputImage,  "mask",   "Input Mask");
    SetParameterDescription( "mask", "The mask allows to restrict classification of the input image to the area where mask pixel values are greater than 0.");
86
    MandatoryOff("mask");
87

88 89
    AddParameter(ParameterType_Filename, "svm", "SVM Model file");
    SetParameterDescription("svm", "A SVM model file.");
90

91 92 93
    AddParameter(ParameterType_Filename, "imstat", "Statistics file");
    SetParameterDescription("imstat", "A XML file containing mean and standard deviation to center and reduce samples before classification.");
    MandatoryOff("imstat");
94
    
95
    AddParameter(ParameterType_OutputImage, "out",  "Output Image");
96
    SetParameterDescription( "out", "Output image labeled with class labels");
97
    SetParameterOutputImagePixelType( "out", ImagePixelType_uint8);
98 99 100 101

    AddParameter(ParameterType_RAM, "ram", "Available RAM");
    SetDefaultParameterInt("ram", 256);
    MandatoryOff("ram");
102 103 104 105

   // Doc example parameter settings
    SetDocExampleParameterValue("in", "QB_1_ortho.tif");
    SetDocExampleParameterValue("imstat", "clImageStatisticsQB1.xml");
106
    SetDocExampleParameterValue("svm", "clsvmModelQB1.svm");
107
    SetDocExampleParameterValue("out", "otbConcatenateImages.png uchar");
108 109 110 111 112 113
  }

  void DoUpdateParameters()
  {
    // Nothing to do here : all parameters are independent
  }
114

115 116 117 118 119 120 121
  void DoExecute()
  {
    // Load input image
    FloatVectorImageType::Pointer inImage = GetParameterImage("in");
    inImage->UpdateOutputInformation();

    // Load svm model
122
    otbAppLogINFO("Loading SVM model");
123 124
    m_ModelSVM = ModelType::New();
    m_ModelSVM->LoadModel(GetParameterString("svm").c_str());
125
    otbAppLogINFO("SVM model loaded");
126 127 128 129 130

    // Normalize input image (optional)
    StatisticsReader::Pointer  statisticsReader = StatisticsReader::New();
    MeasurementType  meanMeasurementVector;
    MeasurementType  stddevMeasurementVector;
131
    m_Rescaler = RescalerType::New();
132 133
    
    // Classify
134 135 136
    m_ClassificationFilter = ClassificationFilterType::New();
    m_ClassificationFilter->SetModel(m_ModelSVM);
  
137
    // Normalize input image if asked
138
    if(IsParameterEnabled("imstat")  )
139
      {
140
      otbAppLogINFO("Input image normalization activated.");
141 142 143 144
      // Load input image statistics
      statisticsReader->SetFileName(GetParameterString("imstat"));
      meanMeasurementVector   = statisticsReader->GetStatisticVectorByName("mean");
      stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
145 146
      otbAppLogINFO( "mean used: " << meanMeasurementVector );
      otbAppLogINFO( "standard deviation used: " << stddevMeasurementVector );
147
      // Rescale vector image
148 149 150
      m_Rescaler->SetScale(stddevMeasurementVector);
      m_Rescaler->SetShift(meanMeasurementVector);
      m_Rescaler->SetInput(inImage);
151
      
152
      m_ClassificationFilter->SetInput(m_Rescaler->GetOutput());
153 154 155
      }
    else
      {
156
      otbAppLogINFO("Input image normalization deactivated.");
157
      m_ClassificationFilter->SetInput(inImage);
158 159
      }
    
160
  
161
    if(IsParameterEnabled("mask"))
162
      {
163
      otbAppLogINFO("Using input mask");
164
      // Load mask image and cast into LabeledImageType
165
      UInt8ImageType::Pointer inMask = GetParameterUInt8Image("mask");
166
      
167
      m_ClassificationFilter->SetInputMask(inMask);
168
      }
169

Jonathan Guinet's avatar
Jonathan Guinet committed
170
    SetParameterOutputImage<UInt8ImageType>("out", m_ClassificationFilter->GetOutput());
171 172
  }

173 174 175
  ClassificationFilterType::Pointer m_ClassificationFilter;
  ModelPointerType m_ModelSVM;
  RescalerType::Pointer m_Rescaler;
176 177 178 179
};


}
180
}
181 182

OTB_APPLICATION_EXPORT(otb::Wrapper::ImageSVMClassifier)