/*=========================================================================

 Program:   ORFEO Toolbox
 Language:  C++
 Date:      $Date$
 Version:   $Revision$


 Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
 See OTBCopyright.txt for details.


 This software is distributed WITHOUT ANY WARRANTY; without even
 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
 PURPOSE.  See the above copyright notices for more information.

 =========================================================================*/
#include "otbWrapperCompositeApplication.h"
#include "otbWrapperApplicationFactory.h"

#include "otbOGRDataToSamplePositionFilter.h"
#include "otbSamplingRateCalculator.h"

namespace otb
{
namespace Wrapper
{

class TrainImagesClassifier: public CompositeApplication
29
30
31
32
{
public:
  /** Standard class typedefs. */
  typedef TrainImagesClassifier Self;
33
  typedef CompositeApplication Superclass;
34
35
36
37
38
39
  typedef itk::SmartPointer<Self> Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;

  /** Standard macro */
  itkNewMacro(Self)

40
  itkTypeMacro(TrainImagesClassifier, otb::Wrapper::CompositeApplication)
41

42
43
44
45
46
47
48
49
  /** filters typedefs*/
  typedef otb::OGRDataToSamplePositionFilter<
    FloatVectorImageType,
    UInt8ImageType,
    otb::PeriodicSampler>                           PeriodicSamplerType;

  typedef otb::SamplingRateCalculator::MapRateType  MapRateType;

50
51
52
protected:

private:
53

54
55
56
57
58
bool RemoveFile(std::string &filePath)
{
  bool res = true;
  if(itksys::SystemTools::FileExists(filePath.c_str()))
    {
59
60
61
62
63
64
65
66
67
68
69
    size_t posExt = filePath.rfind('.');
    if (posExt != std::string::npos &&
        filePath.compare(posExt,std::string::npos,".shp") == 0)
      {
      std::string shxPath = filePath.substr(0,posExt) + std::string(".shx");
      std::string dbfPath = filePath.substr(0,posExt) + std::string(".dbf");
      std::string prjPath = filePath.substr(0,posExt) + std::string(".prj");
      RemoveFile(shxPath);
      RemoveFile(dbfPath);
      RemoveFile(prjPath);
      }
70
71
72
73
74
75
76
77
78
    res = itksys::SystemTools::RemoveFile(filePath.c_str());
    if (!res)
      {
      otbAppLogINFO(<<"Unable to remove file  "<<filePath);
      }
    }
  return res;
}

79
void DoInit() ITK_OVERRIDE
80
81
82
{
  SetName("TrainImagesClassifier");
  SetDescription(
83
    "Train a classifier from multiple pairs of images and training vector data.");
84
85

  // Documentation
86
  SetDocName("Train a classifier from multiple images");
87
88
89
90
91
  SetDocLongDescription(
    "This application performs a classifier training from multiple pairs of input images and training vector data. "
    "Samples are composed of pixel values in each band optionally centered and reduced using an XML statistics file produced by "
    "the ComputeImagesStatistics application.\n The training vector data must contain polygons with a positive integer field "
    "representing the class label. The name of this field can be set using the \"Class label field\" parameter. Training and validation "
92
93
    "sample lists are built such that each class is equally represented in both lists. One parameter allows controlling the ratio "
    "between the number of samples in training and validation sets. Two parameters allow managing the size of the training and "
94
95
96
    "validation sets per class and per image.\n Several classifier parameters can be set depending on the chosen classifier. In the "
    "validation process, the confusion matrix is organized the following way: rows = reference labels, columns = produced labels. "
    "In the header of the optional confusion matrix output file, the validation (reference) and predicted (produced) class labels"
97
98
    " are ordered according to the rows/columns of the confusion matrix.\n This application is based on LibSVM and OpenCV Machine Learning "
    "(2.3.1 and later).");
99
100
101
102
  SetDocLimitations("None");
  SetDocAuthors("OTB-Team");
  SetDocSeeAlso("OpenCV documentation for machine learning http://docs.opencv.org/modules/ml/doc/ml.html ");

103
104
  AddDocTag(Tags::Learning);

105
  ClearApplications();
106
107
108
109
110
111
  AddApplication("PolygonClassStatistics", "polystat","Polygon analysis");
  AddApplication("MultiImageSamplingRate", "rates", "Sampling rates");
  AddApplication("SampleSelection", "select", "Sample selection");
  AddApplication("SampleExtraction","extraction", "Sample extraction");
  AddApplication("TrainVectorClassifier", "training", "Model training");

112
113
  //Group IO
  AddParameter(ParameterType_Group, "io", "Input and output data");
114
  SetParameterDescription("io", "This group of parameters allows setting input and output data.");
115

116
117
118
119
120
  AddParameter(ParameterType_InputImageList, "io.il", "Input Image List");
  SetParameterDescription("io.il", "A list of input images.");
  AddParameter(ParameterType_InputVectorDataList, "io.vd", "Input Vector Data List");
  SetParameterDescription("io.vd", "A list of vector data to select the training samples.");

121
122
123
124
  AddParameter(ParameterType_InputVectorDataList, "io.valid", "Validation Vector Data List");
  SetParameterDescription("io.valid", "A list of vector data to select the training samples.");
  MandatoryOff("io.valid");

125
126
127
128
  ShareParameter("io.imstat","training.io.stats");
  ShareParameter("io.confmatout","training.io.confmatout");
  ShareParameter("io.out","training.io.out");

129
  ShareParameter("elev","polystat.elev");
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157

  // Sampling settings
  AddParameter(ParameterType_Group, "sample", "Training and validation samples parameters");
  SetParameterDescription("sample",
    "This group of parameters allows you to set training and validation sample lists parameters.");
  AddParameter(ParameterType_Int, "sample.mt", "Maximum training sample size per class");
  SetDefaultParameterInt("sample.mt", 1000);
  SetParameterDescription("sample.mt", "Maximum size per class (in pixels) of "
    "the training sample list (default = 1000) (no limit = -1). If equal to -1,"
    " then the maximal size of the available training sample list per class "
    "will be equal to the surface area of the smallest class multiplied by the"
    " training sample ratio.");
  AddParameter(ParameterType_Int, "sample.mv", "Maximum validation sample size per class");
  SetDefaultParameterInt("sample.mv", 1000);
  SetParameterDescription("sample.mv", "Maximum size per class (in pixels) of "
    "the validation sample list (default = 1000) (no limit = -1). If equal to -1,"
    " then the maximal size of the available validation sample list per class "
    "will be equal to the surface area of the smallest class multiplied by the "
    "validation sample ratio.");
  AddParameter(ParameterType_Int, "sample.bm", "Bound sample number by minimum");
  SetDefaultParameterInt("sample.bm", 1);
  SetParameterDescription("sample.bm", "Bound the number of samples for each "
    "class by the number of available samples by the smaller class. Proportions "
    "between training and validation are respected. Default is true (=1).");
  AddParameter(ParameterType_Float, "sample.vtr", "Training and validation sample ratio");
  SetParameterDescription("sample.vtr",
    "Ratio between training and validation samples (0.0 = all training, 1.0 = "
    "all validation) (default = 0.5).");
158
  SetParameterFloat("sample.vtr",0.5, false);
159
160
  SetMaximumParameterFloatValue("sample.vtr",1.0);
  SetMinimumParameterFloatValue("sample.vtr",0.0);
161

162
  ShareParameter("sample.vfn","polystat.field");
163
164
165
166

  // hide sampling parameters
  //ShareParameter("sample.strategy","rates.strategy");
  //ShareParameter("sample.mim","rates.mim");
167
168
169
170
171
172
173
174
175

  // Classifier settings
  ShareParameter("classifier","training.classifier");

  ShareParameter("rand","training.rand");

  // Synchronization between applications
  Connect("select.field", "polystat.field");
  Connect("select.layer", "polystat.layer");
176
  Connect("select.elev",  "polystat.elev");
177
178

  Connect("extraction.in",    "select.in");
179
  Connect("extraction.vec",   "select.out");
180
181
182
183
  Connect("extraction.field", "polystat.field");
  Connect("extraction.layer", "polystat.layer");

  Connect("training.cfield", "polystat.field");
184

185
186
187
188
  ShareParameter("ram","polystat.ram");
  Connect("select.ram", "polystat.ram");
  Connect("extraction.ram", "polystat.ram");

189
190
  Connect("select.rand", "training.rand");

191
192
193
194
195
  AddParameter(ParameterType_Empty,"cleanup","Temporary files cleaning");
  EnableParameter("cleanup");
  SetParameterDescription("cleanup","If activated, the application will try to clean all temporary files it created");
  MandatoryOff("cleanup");

196
197
198
199
200
201
202
203
204
205
206
207
208
209
  // Doc example parameter settings
  SetDocExampleParameterValue("io.il", "QB_1_ortho.tif");
  SetDocExampleParameterValue("io.vd", "VectorData_QB1.shp");
  SetDocExampleParameterValue("io.imstat", "EstimateImageStatisticsQB1.xml");
  SetDocExampleParameterValue("sample.mv", "100");
  SetDocExampleParameterValue("sample.mt", "100");
  SetDocExampleParameterValue("sample.vtr", "0.5");
  SetDocExampleParameterValue("sample.vfn", "Class");
  SetDocExampleParameterValue("classifier", "libsvm");
  SetDocExampleParameterValue("classifier.libsvm.k", "linear");
  SetDocExampleParameterValue("classifier.libsvm.c", "1");
  SetDocExampleParameterValue("classifier.libsvm.opt", "false");
  SetDocExampleParameterValue("io.out", "svmModelQB1.txt");
  SetDocExampleParameterValue("io.confmatout", "svmConfusionMatrixQB1.csv");
210
}
211

212
void DoUpdateParameters() ITK_OVERRIDE
213
{
214
215
216
  if ( HasValue("io.vd") )
    {
    std::vector<std::string> vectorFileList = GetParameterStringList("io.vd");
217
    GetInternalApplication("polystat")->SetParameterString("vec",vectorFileList[0], false);
218
219
    UpdateInternalParameters("polystat");
    }
220
}
221

222
void DoExecute() ITK_OVERRIDE
223
{
224
225
226
  FloatVectorImageListType* imageList = GetParameterImageList("io.il");
  std::vector<std::string> vectorFileList = GetParameterStringList("io.vd");
  unsigned int nbInputs = imageList->Size();
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
  if (nbInputs > vectorFileList.size())
    {
    otbAppLogFATAL("Missing input vector data files to match number of images ("<<nbInputs<<").");
    }

  // check if validation vectors are given
  std::vector<std::string> validationVectorFileList;
  bool dedicatedValidation = false;
  if (IsParameterEnabled("io.valid") && HasValue("io.valid"))
    {
    dedicatedValidation = true;
    validationVectorFileList = GetParameterStringList("io.valid");
    if (nbInputs > validationVectorFileList.size())
      {
      otbAppLogFATAL("Missing validation vector data files to match number of images ("<<nbInputs<<").");
      }
    }

245
  // Prepare temporary file names
246
  std::string outModel(GetParameterString("io.out"));
247
248
249
250
  std::vector<std::string> polyStatTrainOutputs;
  std::vector<std::string> polyStatValidOutputs;
  std::vector<std::string> ratesTrainOutputs;
  std::vector<std::string> ratesValidOutputs;
251
252
253
  std::vector<std::string> sampleOutputs;
  std::vector<std::string> sampleTrainOutputs;
  std::vector<std::string> sampleValidOutputs;
254
255
256
257
258
259
260
261
262
263
  std::string rateTrainOut;
  if (dedicatedValidation)
    {
    rateTrainOut = outModel + "_ratesTrain.csv";
    }
  else
    {
    rateTrainOut = outModel + "_rates.csv";
    }
  std::string rateValidOut(outModel + "_ratesValid.csv");
264
265
266
267
268
  for (unsigned int i=0 ; i<nbInputs ; i++)
    {
    std::ostringstream oss;
    oss <<i+1;
    std::string strIndex(oss.str());
269
270
271
272
273
274
    if (dedicatedValidation)
      {
      polyStatTrainOutputs.push_back(outModel + "_statsTrain_" + strIndex + ".xml");
      polyStatValidOutputs.push_back(outModel + "_statsValid_" + strIndex + ".xml");
      ratesTrainOutputs.push_back(outModel + "_ratesTrain_" + strIndex + ".csv");
      ratesValidOutputs.push_back(outModel + "_ratesValid_" + strIndex + ".csv");
275
      sampleOutputs.push_back(outModel + "_samplesTrain_" + strIndex + ".shp");
276
277
278
279
280
      }
    else
      {
      polyStatTrainOutputs.push_back(outModel + "_stats_" + strIndex + ".xml");
      ratesTrainOutputs.push_back(outModel + "_rates_" + strIndex + ".csv");
281
      sampleOutputs.push_back(outModel + "_samples_" + strIndex + ".shp");
282
      }
283
284
    sampleTrainOutputs.push_back(outModel + "_samplesTrain_" + strIndex + ".shp");
    sampleValidOutputs.push_back(outModel + "_samplesValid_" + strIndex + ".shp");
285
    }
286

287
  // ---------------------------------------------------------------------------
288
289
290
291
  // Polygons stats
  for (unsigned int i=0 ; i<nbInputs ; i++)
    {
    GetInternalApplication("polystat")->SetParameterInputImage("in",imageList->GetNthElement(i));
292
293
    GetInternalApplication("polystat")->SetParameterString("vec",vectorFileList[i], false);
    GetInternalApplication("polystat")->SetParameterString("out",polyStatTrainOutputs[i], false);
294
    ExecuteInternal("polystat");
295
296
297
    // analyse polygons given for validation
    if (dedicatedValidation)
      {
298
299
      GetInternalApplication("polystat")->SetParameterString("vec",validationVectorFileList[i], false);
      GetInternalApplication("polystat")->SetParameterString("out",polyStatValidOutputs[i], false);
300
301
      ExecuteInternal("polystat");
      }
302
303
    }

304
305
  // ---------------------------------------------------------------------------
  // Compute sampling rates
306
  GetInternalApplication("rates")->SetParameterString("mim","proportional", false);
307
  double vtr = GetParameterFloat("sample.vtr");
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
  long mt = GetParameterInt("sample.mt");
  long mv = GetParameterInt("sample.mv");
  // compute final maximum training and final maximum validation
  // By default take all samples (-1 means all samples)
  long fmt = -1;
  long fmv = -1;
  if (GetParameterInt("sample.bm") == 0)
    {
    if (dedicatedValidation)
      {
      // fmt and fmv will be used separately
      fmt = mt;
      fmv = mv;
      if (mt > -1 && mv <= -1 && vtr < 0.99999)
        {
        fmv = static_cast<long>((double) mt * vtr / (1.0 - vtr));
        }
      if (mt <= -1 && mv > -1 && vtr > 0.00001)
        {
        fmt = static_cast<long>((double) mv * (1.0 - vtr) / vtr);
        }
      }
    else
      {
      // only fmt will be used for both training and validation samples
      // So we try to compute the total number of samples given input
      // parameters mt, mv and vtr.
      if (mt > -1 && mv > -1)
        {
        fmt = mt + mv;
        }
      if (mt > -1 && mv <= -1 && vtr < 0.99999)
        {
        fmt = static_cast<long>((double) mt / (1.0 - vtr));
        }
      if (mt <= -1 && mv > -1 && vtr > 0.00001)
        {
        fmt = static_cast<long>((double) mv / vtr);
        }
      }
    }

  // Sampling rates for training
351
352
  GetInternalApplication("rates")->SetParameterStringList("il",polyStatTrainOutputs, false);
  GetInternalApplication("rates")->SetParameterString("out",rateTrainOut, false);
353
354
  if (GetParameterInt("sample.bm") != 0)
    {
355
    GetInternalApplication("rates")->SetParameterString("strategy","smallest", false);
356
357
358
    }
  else
    {
359
    if (fmt > -1)
360
      {
361
      GetInternalApplication("rates")->SetParameterString("strategy","constant", false);
362
      GetInternalApplication("rates")->SetParameterInt("strategy.constant.nb",fmt);
363
364
365
      }
    else
      {
366
      GetInternalApplication("rates")->SetParameterString("strategy","all", false);
367
368
369
370
371
372
      }
    }
  ExecuteInternal("rates");
  // Sampling rates for validation
  if (dedicatedValidation)
    {
373
374
    GetInternalApplication("rates")->SetParameterStringList("il",polyStatValidOutputs, false);
    GetInternalApplication("rates")->SetParameterString("out",rateValidOut, false);
375
376
    if (GetParameterInt("sample.bm") != 0)
      {
377
      GetInternalApplication("rates")->SetParameterString("strategy","smallest", false);
378
379
380
381
      }
    else
      {
      if (fmv > -1)
382
        {
383
        GetInternalApplication("rates")->SetParameterString("strategy","constant", false);
384
        GetInternalApplication("rates")->SetParameterInt("strategy.constant.nb",fmv);
385
386
387
        }
      else
        {
388
        GetInternalApplication("rates")->SetParameterString("strategy","all", false);
389
390
        }
      }
391
    ExecuteInternal("rates");
392
393
    }

394
  // ---------------------------------------------------------------------------
395
  // Select & extract samples
396
  GetInternalApplication("select")->SetParameterString("sampler", "periodic", false);
397
  GetInternalApplication("select")->SetParameterInt("sampler.periodic.jitter",50);
398
399
400
  GetInternalApplication("select")->SetParameterString("strategy","byclass", false);
  GetInternalApplication("extraction")->SetParameterString("outfield", "prefix", false);
  GetInternalApplication("extraction")->SetParameterString("outfield.prefix.name","value_", false);
401
402
403
  for (unsigned int i=0 ; i<nbInputs ; i++)
    {
    GetInternalApplication("select")->SetParameterInputImage("in",imageList->GetNthElement(i));
404
405
406
407
    GetInternalApplication("select")->SetParameterString("vec",vectorFileList[i], false);
    GetInternalApplication("select")->SetParameterString("out",sampleOutputs[i], false);
    GetInternalApplication("select")->SetParameterString("instats",polyStatTrainOutputs[i], false);
    GetInternalApplication("select")->SetParameterString("strategy.byclass.in",ratesTrainOutputs[i], false);
408
409
410
411
412
    // select sample positions
    ExecuteInternal("select");
    // extract sample descriptors
    ExecuteInternal("extraction");

413
    if (dedicatedValidation)
414
      {
415
416
417
418
      GetInternalApplication("select")->SetParameterString("vec",validationVectorFileList[i], false);
      GetInternalApplication("select")->SetParameterString("out",sampleValidOutputs[i], false);
      GetInternalApplication("select")->SetParameterString("instats",polyStatValidOutputs[i], false);
      GetInternalApplication("select")->SetParameterString("strategy.byclass.in",ratesValidOutputs[i], false);
419
420
421
422
      // select sample positions
      ExecuteInternal("select");
      // extract sample descriptors
      ExecuteInternal("extraction");
423
      }
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
    else
      {
      // Split between training and validation
      ogr::DataSource::Pointer source = ogr::DataSource::New(sampleOutputs[i], ogr::DataSource::Modes::Read);
      ogr::DataSource::Pointer destTrain = ogr::DataSource::New(sampleTrainOutputs[i], ogr::DataSource::Modes::Overwrite);
      ogr::DataSource::Pointer destValid = ogr::DataSource::New(sampleValidOutputs[i], ogr::DataSource::Modes::Overwrite);
      // read sampling rates from ratesTrainOutputs[i]
      SamplingRateCalculator::Pointer rateCalculator = SamplingRateCalculator::New();
      rateCalculator->Read(ratesTrainOutputs[i]);
      // Compute sampling rates for train and valid
      const MapRateType &inputRates = rateCalculator->GetRatesByClass();
      MapRateType trainRates;
      MapRateType validRates;
      otb::SamplingRateCalculator::TripletType tpt;
      for (MapRateType::const_iterator it = inputRates.begin() ;
           it != inputRates.end() ;
           ++it)
        {
        unsigned long total = std::min(it->second.Required,it->second.Tot );
        unsigned long neededValid = static_cast<unsigned long>((double) total * vtr );
        unsigned long neededTrain = total - neededValid;
        tpt.Tot = total;
        tpt.Required = neededTrain;
        tpt.Rate = (1.0 - vtr);
        trainRates[it->first] = tpt;
        tpt.Tot = neededValid;
        tpt.Required = neededValid;
        tpt.Rate = 1.0;
        validRates[it->first] = tpt;
        }
454

455
456
457
458
459
460
461
462
463
      // Use an otb::OGRDataToSamplePositionFilter with 2 outputs
      PeriodicSamplerType::SamplerParameterType param;
      param.Offset = 0;
      param.MaxJitter = 0;
      PeriodicSamplerType::Pointer splitter = PeriodicSamplerType::New();
      splitter->SetInput(imageList->GetNthElement(i));
      splitter->SetOGRData(source);
      splitter->SetOutputPositionContainerAndRates(destTrain, trainRates, 0);
      splitter->SetOutputPositionContainerAndRates(destValid, validRates, 1);
464
      splitter->SetFieldName(this->GetParameterStringList("sample.vfn")[0]);
465
      splitter->SetLayerIndex(0);
466
      splitter->SetOriginFieldName(std::string(""));
467
468
469
470
471
      splitter->SetSamplerParameters(param);
      splitter->GetStreamer()->SetAutomaticTiledStreaming(this->GetParameterInt("ram"));
      AddProcess(splitter->GetStreamer(),"Split samples between training and validation...");
      splitter->Update();
      }
472
473
    }

474
475
  // ---------------------------------------------------------------------------
  // Train model
476
477
  GetInternalApplication("training")->SetParameterStringList("io.vd",sampleTrainOutputs, false);
  GetInternalApplication("training")->SetParameterStringList("valid.vd",sampleValidOutputs, false);
478
479
480
481
482
483
484
485
486
487
  UpdateInternalParameters("training");
  // set field names
  FloatVectorImageType::Pointer image = imageList->GetNthElement(0);
  unsigned int nbBands = image->GetNumberOfComponentsPerPixel();
  std::vector<std::string> selectedNames;
  for (unsigned int i=0 ; i<nbBands ; i++)
    {
    std::ostringstream oss;
    oss << i;
    selectedNames.push_back("value_"+oss.str());
488
    }
489
  GetInternalApplication("training")->SetParameterStringList("feat",selectedNames, false);
490
  ExecuteInternal("training");
491
492
493
494
495

  // cleanup
  if(IsParameterEnabled("cleanup"))
    {
    otbAppLogINFO(<<"Final clean-up ...");
496
497
498
499
500
501
502
503
504
    for(unsigned int i=0 ; i<polyStatTrainOutputs.size() ; i++)
      RemoveFile(polyStatTrainOutputs[i]);
    for(unsigned int i=0 ; i<polyStatValidOutputs.size() ; i++)
      RemoveFile(polyStatValidOutputs[i]);
    for(unsigned int i=0 ; i<ratesTrainOutputs.size() ; i++)
      RemoveFile(ratesTrainOutputs[i]);
    for(unsigned int i=0 ; i<ratesValidOutputs.size() ; i++)
      RemoveFile(ratesValidOutputs[i]);
    for(unsigned int i=0 ; i<sampleOutputs.size() ; i++)
505
      RemoveFile(sampleOutputs[i]);
506
    for(unsigned int i=0 ; i<sampleTrainOutputs.size() ; i++)
507
      RemoveFile(sampleTrainOutputs[i]);
508
    for(unsigned int i=0 ; i<sampleValidOutputs.size() ; i++)
509
510
      RemoveFile(sampleValidOutputs[i]);
    }
511
}
512

513
};

} // end namespace Wrapper
} // end namespace otb

OTB_APPLICATION_EXPORT(otb::Wrapper::TrainImagesClassifier)