/*
 * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
 *
 * This file is part of Orfeo Toolbox
 *
 *     https://www.orfeo-toolbox.org/
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "otbTrainImagesBase.h"

namespace otb
{
namespace Wrapper
{

class TrainImagesClassifier : public TrainImagesBase
29 30
{
public:
31
  typedef TrainImagesClassifier         Self;
32
  typedef TrainImagesBase               Superclass;
33
  typedef itk::SmartPointer<Self>       Pointer;
34
  typedef itk::SmartPointer<const Self> ConstPointer;
35 36
  itkNewMacro( Self )
  itkTypeMacro( Self, Superclass )
37

38
  void DoInit() override
39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
  {
    SetName( "TrainImagesClassifier" );
    SetDescription( "Train a classifier from multiple pairs of images and training vector data." );

    // Documentation
    SetDocName( "Train a classifier from multiple images" );
    SetDocLongDescription(
            "This application performs a classifier training from multiple pairs of input images and training vector data. "
                    "Samples are composed of pixel values in each band optionally centered and reduced using an XML statistics file produced by "
                    "the ComputeImagesStatistics application.\n The training vector data must contain polygons with a positive integer field "
                    "representing the class label. The name of this field can be set using the \"Class label field\" parameter. Training and validation "
                    "sample lists are built such that each class is equally represented in both lists. One parameter allows controlling the ratio "
                    "between the number of samples in training and validation sets. Two parameters allow managing the size of the training and "
                    "validation sets per class and per image.\n Several classifier parameters can be set depending on the chosen classifier. In the "
                    "validation process, the confusion matrix is organized the following way: rows = reference labels, columns = produced labels. "
                    "In the header of the optional confusion matrix output file, the validation (reference) and predicted (produced) class labels"
55 56 57
                    " are ordered according to the rows/columns of the confusion matrix.\n This application is based on LibSVM, OpenCV Machine Learning "
                    "(2.3.1 and later), and Shark ML. The output of this application is a text model file, whose format corresponds to the "
                    "ML model type chosen. There is no image nor vector data output." );
58 59 60 61 62 63 64 65 66 67
    SetDocLimitations( "None" );
    SetDocAuthors( "OTB-Team" );
    SetDocSeeAlso( "OpenCV documentation for machine learning http://docs.opencv.org/modules/ml/doc/ml.html " );

    AddDocTag( Tags::Learning );

    // Perform initialization
    ClearApplications();
    InitIO();
    InitSampling();
68
    InitClassification();
69

70
    AddDocTag( Tags::Learning );
71 72

    // Doc example parameter settings
73 74 75 76 77 78 79 80 81 82 83 84 85
    SetDocExampleParameterValue( "io.il", "QB_1_ortho.tif" );
    SetDocExampleParameterValue( "io.vd", "VectorData_QB1.shp" );
    SetDocExampleParameterValue( "io.imstat", "EstimateImageStatisticsQB1.xml" );
    SetDocExampleParameterValue( "sample.mv", "100" );
    SetDocExampleParameterValue( "sample.mt", "100" );
    SetDocExampleParameterValue( "sample.vtr", "0.5" );
    SetDocExampleParameterValue( "sample.vfn", "Class" );
    SetDocExampleParameterValue( "classifier", "libsvm" );
    SetDocExampleParameterValue( "classifier.libsvm.k", "linear" );
    SetDocExampleParameterValue( "classifier.libsvm.c", "1" );
    SetDocExampleParameterValue( "classifier.libsvm.opt", "false" );
    SetDocExampleParameterValue( "io.out", "svmModelQB1.txt" );
    SetDocExampleParameterValue( "io.confmatout", "svmConfusionMatrixQB1.csv" );
86

87
    SetOfficialDocLink();
88 89
  }

90
  void DoUpdateParameters() override
91
  {
92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113
    if( HasValue( "io.vd" ) && IsParameterEnabled( "io.vd" ))
      {
      UpdatePolygonClassStatisticsParameters();
      }
  }

  /**
   * Select and Extract samples for validation with computed statistics and rates.
   * Validation samples could be empty if sample.vrt == 0 and if no dedicated validation are provided.
   * If no dedicated validation is provided the training is split corresponding to the sample.vtr parameter,
   * in this case if no vector data have been provided, the training rates and statistics are computed
   * on the selection and extraction training result.
   * fileNames.sampleOutputs contains training data and after an ExtractValidationData training data will
   * be split to fileNames.sampleTrainOutputs.
   * \param imageList
   * \param fileNames
   * \param validationVectorFileList
   * \param rates
   * \param HasInputVector
   */
  void ExtractValidationData(FloatVectorImageListType *imageList, TrainFileNamesHandler& fileNames,
                             std::vector<std::string> validationVectorFileList,
114
                             const SamplingRates& rates, bool itkNotUsed(HasInputVector) )
115 116 117 118 119 120
  {
    if( !validationVectorFileList.empty() ) // Compute class statistics and sampling rate of validation data if provided.
      {
      ComputePolygonStatistics( imageList, validationVectorFileList, fileNames.polyStatValidOutputs );
      ComputeSamplingRate( fileNames.polyStatValidOutputs, fileNames.rateValidOut, rates.fmv );
      SelectAndExtractValidationSamples( fileNames, imageList, validationVectorFileList );
121

122
      fileNames.sampleTrainOutputs = fileNames.sampleOutputs;
123 124 125 126 127
      }
    else if(GetParameterFloat("sample.vtr") != 0.0)// Split training data to validation
      {
      SplitTrainingToValidationSamples( fileNames, imageList );
      }
128
    else // Update sampleTrainOutputs and clear sampleValidOutputs
129 130
      {
      fileNames.sampleTrainOutputs = fileNames.sampleOutputs;
131 132 133 134

      // Corner case where no dedicated validation set is provided and split ratio is set to 0 (all samples for training)
      // In this case SampleValidOutputs should be cleared
      fileNames.sampleValidOutputs.clear();
135 136 137 138 139 140 141 142 143 144 145 146 147 148
      }
  }

  /**
   * Extract Training data depending if input vector is provided
   * \param imageList list of the image
   * \param fileNames handler that contain filenames
   * \param vectorFileList input vector file list (if provided
   * \param rates
   */
  void ExtractTrainData(FloatVectorImageListType *imageList, const TrainFileNamesHandler& fileNames,
                        std::vector<std::string> vectorFileList,
                        const SamplingRates& rates)
  {
149 150
//    if( !vectorFileList.empty() ) // Select and Extract samples for training with computed statistics and rates
//      {
151 152
      ComputePolygonStatistics( imageList, vectorFileList, fileNames.polyStatTrainOutputs );
      ComputeSamplingRate( fileNames.polyStatTrainOutputs, fileNames.rateTrainOut, rates.fmt );
153
      SelectAndExtractTrainSamples( fileNames, imageList, vectorFileList, Superclass::CLASS );
154 155 156 157 158
//      }
//    else // Select training samples base on geometric sampling if no input vector is provided
//      {
//      SelectAndExtractTrainSamples( fileNames, imageList, vectorFileList, SamplingStrategy::GEOMETRIC, "fid" );
//      }
159 160
  }

161

162
  void DoExecute() override
163 164
  {
    TrainFileNamesHandler fileNames;
165
    std::vector<std::string> vectorFileList;
166
    FloatVectorImageListType *imageList = GetParameterImageList( "io.il" );
167 168 169 170 171
    bool HasInputVector = IsParameterEnabled( "io.vd" ) && HasValue( "io.vd" );
    if(HasInputVector)
      vectorFileList = GetParameterStringList( "io.vd" );


172 173
    unsigned long nbInputs = imageList->Size();

174
    if( !HasInputVector )
175 176 177 178 179
      {
      otbAppLogFATAL( "Missing input vector data files" );
      }

    if( !vectorFileList.empty() && nbInputs > vectorFileList.size() )
180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202
      {
      otbAppLogFATAL( "Missing input vector data files to match number of images (" << nbInputs << ")." );
      }

    // check if validation vectors are given
    std::vector<std::string> validationVectorFileList;
    bool dedicatedValidation = false;
    if( IsParameterEnabled( "io.valid" ) && HasValue( "io.valid" ) )
      {
      validationVectorFileList = GetParameterStringList( "io.valid" );
      if( nbInputs > validationVectorFileList.size() )
        {
        otbAppLogFATAL( "Missing validation vector data files to match number of images (" << nbInputs << ")." );
        }

      dedicatedValidation = true;
      }

    fileNames.CreateTemporaryFileNames( GetParameterString( "io.out" ), nbInputs, dedicatedValidation );

    // Compute final maximum sampling rates for both training and validation samples
    SamplingRates rates = ComputeFinalMaximumSamplingRates( dedicatedValidation );

203 204
    ExtractTrainData(imageList, fileNames, vectorFileList, rates);
    ExtractValidationData(imageList, fileNames, validationVectorFileList, rates, HasInputVector);
205 206

    // Then train the model with extracted samples
207
    TrainModel( imageList, fileNames.sampleTrainOutputs, fileNames.sampleValidOutputs );
208 209

    // cleanup
210
    if( GetParameterInt( "cleanup" ) )
211 212 213 214 215 216
      {
      otbAppLogINFO( <<"Final clean-up ..." );
      fileNames.clear();
      }
  }

217 218 219 220 221
private :

  void UpdatePolygonClassStatisticsParameters()
  {
    std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" );
222
    GetInternalApplication( "polystat" )->SetParameterString( "vec", vectorFileList[0]);
223 224 225
    UpdateInternalParameters( "polystat" );
  }

226
};

} // end namespace Wrapper
} // end namespace otb
// Export the application class through OTB's application export macro.
// NOTE(review): presumably this generates the entry point used to load the
// application as a plugin — confirm against the macro definition.
OTB_APPLICATION_EXPORT( otb::Wrapper::TrainImagesClassifier )