otbSampleSelection.cxx 18.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
/*=========================================================================

 Program:   ORFEO Toolbox
 Language:  C++
 Date:      $Date$
 Version:   $Revision$


 Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
 See OTBCopyright.txt for details.


 This software is distributed WITHOUT ANY WARRANTY; without even
 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
 PURPOSE.  See the above copyright notices for more information.

 =========================================================================*/
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"
#include "otbSamplingRateCalculator.h"
21
#include "otbOGRDataToSamplePositionFilter.h"
22
#include "otbStatisticsXMLFileReader.h"
23
#include "otbRandomSampler.h"
24
25
#include "otbGeometriesProjectionFilter.h"
#include "otbGeometriesSet.h"
26
#include "otbWrapperElevationParametersHandler.h"
27
28
29
30
31
32

namespace otb
{
namespace Wrapper
{

33
34
35
36
37
38
/**
 * Utility predicate negating std::isalnum.
 *
 * \param c character to test
 * \return true when \a c is NOT an alphanumeric character
 *
 * The character is converted to unsigned char before being handed to
 * std::isalnum: passing a negative value (e.g. a non-ASCII byte when char is
 * signed) to the <cctype> classification functions is undefined behaviour.
 */
bool IsNotAlphaNum(char c)
  {
  return !std::isalnum(static_cast<unsigned char>(c));
  }

39
class SampleSelection : public Application
40
41
42
{
public:
  /** Standard class typedefs. */
43
  typedef SampleSelection        Self;
44
45
46
47
48
49
50
  typedef Application                   Superclass;
  typedef itk::SmartPointer<Self>       Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;

  /** Standard macro */
  itkNewMacro(Self);

51
  itkTypeMacro(SampleSelection, otb::Application);
52
53

  /** typedef */
54
55
56
57
58
59
60
61
62
  typedef otb::OGRDataToSamplePositionFilter<
    FloatVectorImageType,
    UInt8ImageType,
    otb::PeriodicSampler>                           PeriodicSamplerType;
  typedef otb::OGRDataToSamplePositionFilter<
    FloatVectorImageType,
    UInt8ImageType,
    otb::RandomSampler>                             RandomSamplerType;
  typedef otb::SamplingRateCalculator               RateCalculatorType;
63
64
  
  typedef std::map<std::string, unsigned long>      ClassCountMapType;
65
  typedef RateCalculatorType::MapRateType           MapRateType;
66
  typedef itk::VariableLengthVector<float> MeasurementType;
67
  typedef otb::StatisticsXMLFileReader<MeasurementType> XMLReaderType;
68

69
70
71
72
  typedef otb::GeometriesSet GeometriesType;

  typedef otb::GeometriesProjectionFilter ProjectionFilterType;

73
private:
74
  SampleSelection()
75
    {
76
77
    m_ReaderStat = XMLReaderType::New();
    m_RateCalculator = RateCalculatorType::New();
78
79
80
81
    }

  void DoInit()
  {
82
    SetName("SampleSelection");
83
84
85
    SetDescription("Selects samples from a training vector data set.");

    // Documentation
86
    SetDocName("Sample Selection");
87
88
89
90
91
    SetDocLongDescription("The application selects a set of samples from geometries "
      "intended for training (they should have a field giving the associated "
      "class). \n\nFirst of all, the geometries must be analyzed by the PolygonClassStatistics application "
      "to compute statistics about the geometries, which are summarized in an xml file. "
      "\nThen, this xml file must be given as input to this application (parameter instats).\n\n"
92
93
94
      "The input support image and the input training vectors shall be given in "
      "parameters 'in' and 'vec' respectively. Only the sampling grid (origin, size, spacing)"
      "will be read in the input image.\n"
95
96
97
98
99
      "There are several strategies to select samples (parameter strategy) : \n\n"
      "  - smallest (default) : select the same number of sample in each class" 
      " so that the smallest one is fully sampled.\n"
      "  - constant : select the same number of samples N in each class" 
      " (with N below or equal to the size of the smallest class).\n"
100
      "  - byclass : set the required number for each class manually, with an input CSV file"
101
      " (first column is class name, second one is the required samples number).\n\n"
102
103
      "  - percent: set a target global percentage of samples to use. Class proportions will be respected. \n\n"
      "  - total: set a target total number of samples to use. Class proportions will be respected. \n\n"
104
      "There is also a choice on the sampling type to performs : \n\n"
105
      "  - periodic : select samples uniformly distributed\n"
106
      "  - random : select samples randomly distributed\n\n"
107
      "Once the strategy and type are selected, the application outputs samples positions"
108
109
      "(parameter out).\n\n"
      
110
      "The other parameters to look at are : \n\n"
111
112
      "  - layer : index specifying from which layer to pick geometries.\n"
      "  - field : set the field name containing the class.\n"
113
      "  - mask : an optional raster mask can be used to discard samples.\n"
114
      "  - outrates : allows outputting a CSV file that summarizes the sampling rates for each class.\n"
115
116
      
      "\nAs with the PolygonClassStatistics application, different types  of geometry are supported : "
117
      "polygons, lines, points. \nThe behavior of this application is different for each type of geometry : \n\n"
118
119
      "  - polygon: select points whose center is inside the polygon\n"
      "  - lines  : select points intersecting the line\n"
Christophe Palmann's avatar
Christophe Palmann committed
120
      "  - points : select closest point to the provided point\n");
121
122
123
124
125
126
127
128
129
130
131
132
    SetDocLimitations("None");
    SetDocAuthors("OTB-Team");
    SetDocSeeAlso(" ");

    AddDocTag(Tags::Learning);

    AddParameter(ParameterType_InputImage,  "in",   "InputImage");
    SetParameterDescription("in", "Support image that will be classified");

    AddParameter(ParameterType_InputImage,  "mask",   "InputMask");
    SetParameterDescription("mask", "Validity mask (only pixels corresponding to a mask value greater than 0 will be used for statistics)");
    MandatoryOff("mask");
133

134
135
    AddParameter(ParameterType_InputFilename, "vec", "Input vectors");
    SetParameterDescription("vec","Input geometries to analyse");
136

137
138
    AddParameter(ParameterType_OutputFilename, "out", "Output vectors");
    SetParameterDescription("out","Output resampled geometries");
139

140
141
    AddParameter(ParameterType_InputFilename, "instats", "Input Statistics");
    SetParameterDescription("instats","Input file storing statistics (XML format)");
142

143
    AddParameter(ParameterType_OutputFilename, "outrates", "Output rates");
144
    SetParameterDescription("outrates","Output rates (CSV formatted)");
145
146
    MandatoryOff("outrates");

147
148
149
150
151
152
    AddParameter(ParameterType_Choice, "sampler", "Sampler type");
    SetParameterDescription("sampler", "Type of sampling (periodic, pattern based, random)");

    AddChoice("sampler.periodic","Periodic sampler");
    SetParameterDescription("sampler.periodic","Takes samples regularly spaced");

153
154
155
156
157
    AddParameter(ParameterType_Int, "sampler.periodic.jitter","Jitter amplitude");
    SetParameterDescription("sampler.periodic.jitter", "Jitter amplitude added during sample selection (0 = no jitter)");
    SetDefaultParameterInt("sampler.periodic.jitter",0);
    MandatoryOff("sampler.periodic.jitter");

158
    AddChoice("sampler.random","Random sampler");
159
    SetParameterDescription("sampler.random","The positions to select are randomly shuffled.");
160

161
    AddParameter(ParameterType_Choice, "strategy", "Sampling strategy");
162

163
164
    AddChoice("strategy.byclass","Set samples count for each class");
    SetParameterDescription("strategy.byclass","Set samples count for each class");
165

166
167
    AddParameter(ParameterType_InputFilename, "strategy.byclass.in", "Number of samples by class");
    SetParameterDescription("strategy.byclass.in", "Number of samples by class "
168
169
      "(CSV format with class name in 1st column and required samples in the 2nd.");

170
171
    AddChoice("strategy.constant","Set the same samples counts for all classes");
    SetParameterDescription("strategy.constant","Set the same samples counts for all classes");
172

173
174
    AddParameter(ParameterType_Int, "strategy.constant.nb", "Number of samples for all classes");
    SetParameterDescription("strategy.constant.nb", "Number of samples for all classes");
175

176
177
178
179
180
181
182
183
    AddChoice("strategy.percent","Use a percentage of the samples available for each class");
    SetParameterDescription("strategy.percent","Use a percentage of the samples available for each class");

    AddParameter(ParameterType_Float,"strategy.percent.p","The percentage to use");
    SetParameterDescription("strategy.percent.p","The percentage to use");
    SetMinimumParameterFloatValue("strategy.percent.p",0);
    SetMaximumParameterFloatValue("strategy.percent.p",1);
    SetDefaultParameterFloat("strategy.percent.p",0.5);
184
185
186
187
188
189
190
191

    AddChoice("strategy.total","Set the total number of samples to generate, and use class proportions.");
    SetParameterDescription("strategy.total","Set the total number of samples to generate, and use class proportions.");

    AddParameter(ParameterType_Int,"strategy.total.v","The number of samples to generate");
    SetParameterDescription("strategy.total.v","The number of samples to generate");
    SetMinimumParameterIntValue("strategy.total.v",1);
    SetDefaultParameterInt("strategy.total.v",1000);
192
    
193
194
    AddChoice("strategy.smallest","Set same number of samples for all classes, with the smallest class fully sampled");
    SetParameterDescription("strategy.smallest","Set same number of samples for all classes, with the smallest class fully sampled");
195

196
197
    AddChoice("strategy.all","Take all samples");
    SetParameterDescription("strategy.all","Take all samples");
198

199
    // Default strategy : smallest
200
    SetParameterString("strategy","smallest", false);
201

202
    AddParameter(ParameterType_ListView, "field", "Field Name");
203
    SetParameterDescription("field","Name of the field carrying the class name in the input vectors.");
204
    SetListViewSingleSelectionMode("field",true);
205

206
207
208
209
    AddParameter(ParameterType_Int, "layer", "Layer Index");
    SetParameterDescription("layer", "Layer index to read in the input vector file.");
    MandatoryOff("layer");
    SetDefaultParameterInt("layer",0);
210

211
212
    ElevationParametersHandler::AddElevationParameters(this, "elev");

213
214
    AddRAMParameter();

215
216
    AddRANDParameter();

217
218
219
220
221
222
223
224
225
226
    // Doc example parameter settings
    SetDocExampleParameterValue("in", "support_image.tif");
    SetDocExampleParameterValue("vec", "variousVectors.sqlite");
    SetDocExampleParameterValue("field", "label");
    SetDocExampleParameterValue("instats","apTvClPolygonClassStatisticsOut.xml");
    SetDocExampleParameterValue("out","resampledVectors.sqlite");
  }

  void DoUpdateParameters()
  {
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
 if ( HasValue("vec") )
      {
      std::string vectorFile = GetParameterString("vec");
      ogr::DataSource::Pointer ogrDS =
        ogr::DataSource::New(vectorFile, ogr::DataSource::Modes::Read);
      ogr::Layer layer = ogrDS->GetLayer(this->GetParameterInt("layer"));
      ogr::Feature feature = layer.ogr().GetNextFeature();

      ClearChoices("field");
      
      for(int iField=0; iField<feature.ogr().GetFieldCount(); iField++)
        {
        std::string key, item = feature.ogr().GetFieldDefnRef(iField)->GetNameRef();
        key = item;
        std::string::iterator end = std::remove_if(key.begin(),key.end(),IsNotAlphaNum);
        std::transform(key.begin(), end, key.begin(), tolower);
        
        OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(iField)->GetType();
        
246
        if(fieldType == OFTString || fieldType == OFTInteger || ogr::version_proxy::IsOFTInteger64(fieldType))
247
248
249
250
251
252
          {
          std::string tmpKey="field."+key.substr(0, end - key.begin());
          AddChoice(tmpKey,item);
          }
        }
      }
253
254
  }

255
256
257
258
  void DoExecute()
    {
    // Clear state
    m_RateCalculator->ClearRates();
259

260
261
    otb::Wrapper::ElevationParametersHandler::SetupDEMHandlerFromElevationParameters(this,"elev");

262
263
264
265
266
267
268
269
    // Get field name
    std::vector<int> selectedCFieldIdx = GetSelectedItems("field");
    
    if(selectedCFieldIdx.empty())
      {
      otbAppLogFATAL(<<"No field has been selected for data labelling!");
      }
    
Julien Michel's avatar
Julien Michel committed
270
    std::vector<std::string> cFieldNames = GetChoiceNames("field");  
271
    std::string fieldName = cFieldNames[selectedCFieldIdx.front()];
272
273
274
275
276
    
    m_ReaderStat->SetFileName(this->GetParameterString("instats"));
    ClassCountMapType classCount = m_ReaderStat->GetStatisticMapByName<ClassCountMapType>("samplesPerClass");
    m_RateCalculator->SetClassCount(classCount);
    
277
    switch (this->GetParameterInt("strategy"))
278
      {
279
280
281
      // byclass
      case 0:
        {
282
        otbAppLogINFO("Sampling strategy : set number of samples for each class");
283
        ClassCountMapType requiredCount = 
284
          otb::SamplingRateCalculator::ReadRequiredSamples(this->GetParameterString("strategy.byclass.in"));
285
286
287
288
289
        m_RateCalculator->SetNbOfSamplesByClass(requiredCount);
        }
      break;
      // constant
      case 1:
290
291
        {
        otbAppLogINFO("Sampling strategy : set a constant number of samples for all classes");
292
        m_RateCalculator->SetNbOfSamplesAllClasses(GetParameterInt("strategy.constant.nb"));
293
        }
294
      break;
295
      // percent
296
      case 2:
297
      {
298
      otbAppLogINFO("Sampling strategy: set a percentage of samples for each class.");
299
300
301
      m_RateCalculator->SetPercentageOfSamples(this->GetParameterFloat("strategy.percent.p"));
      }
      break;
302
      // total
303
      case 3:
304
305
306
307
308
309
310
311
      {
      otbAppLogINFO("Sampling strategy: set the total number of samples to generate, use classes proportions.");
      m_RateCalculator->SetTotalNumberOfSamples(this->GetParameterInt("strategy.total.v"));
      }
      break;

      // smallest class
      case 4:
312
313
        {
        otbAppLogINFO("Sampling strategy : fit the number of samples based on the smallest class");
314
        m_RateCalculator->SetMinimumNbOfSamplesByClass();
315
316
317
        }
      break;
      // all samples
318
      case 5:
319
320
321
322
        {
        otbAppLogINFO("Sampling strategy : take all samples");
        m_RateCalculator->SetAllSamples();
        }
323
324
      break;
      default:
325
        otbAppLogFATAL("Strategy mode unknown :"<<this->GetParameterString("strategy"));
326
      break;
327
      }
328
329
      
    if (IsParameterEnabled("outrates") && HasValue("outrates"))
330
      {
331
      m_RateCalculator->Write(this->GetParameterString("outrates"));
332
      }
333
334
    
    MapRateType rates = m_RateCalculator->GetRatesByClass();
335
336
337
    std::ostringstream oss;
    oss << " className  requiredSamples  totalSamples  rate" << std::endl;
    MapRateType::const_iterator itRates = rates.begin();
338
    unsigned int overflowCount = 0;
339
340
341
    for(; itRates != rates.end(); ++itRates)
      {
      otb::SamplingRateCalculator::TripletType tpt = itRates->second;
342
343
344
345
346
347
348
      oss << itRates->first << "\t" << tpt.Required << "\t" << tpt.Tot << "\t" << tpt.Rate;
      if (tpt.Required > tpt.Tot)
        {
        overflowCount++;
        oss << "\t[OVERFLOW]";
        }
      oss << std::endl;
349
350
      }
    otbAppLogINFO("Sampling rates : " << oss.str());
351
352
353
354
355
    if (overflowCount)
      {
      std::string plural(overflowCount>1?"s":"");
      otbAppLogWARNING(<< overflowCount << " case"<<plural<<" of overflow detected! (requested number of samples higher than total available samples)");
      }
356
357
358
359
360

    // Open input geometries
    otb::ogr::DataSource::Pointer vectors =
      otb::ogr::DataSource::New(this->GetParameterString("vec"));

361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
    // Reproject geometries
    FloatVectorImageType::Pointer inputImg = this->GetParameterImage("in");
    std::string imageProjectionRef = inputImg->GetProjectionRef();
    FloatVectorImageType::ImageKeywordlistType imageKwl =
      inputImg->GetImageKeywordlist();
    std::string vectorProjectionRef =
      vectors->GetLayer(GetParameterInt("layer")).GetProjectionRef();

    otb::ogr::DataSource::Pointer reprojVector = vectors;
    GeometriesType::Pointer inputGeomSet;
    ProjectionFilterType::Pointer geometriesProjFilter;
    GeometriesType::Pointer outputGeomSet;
    bool doReproj = true;
    // don't reproject for these cases
    if (vectorProjectionRef.empty() ||
        (imageProjectionRef == vectorProjectionRef) ||
        (imageProjectionRef.empty() && imageKwl.GetSize() == 0))
      doReproj = false;
  
    if (doReproj)
      {
      inputGeomSet = GeometriesType::New(vectors);
      reprojVector = otb::ogr::DataSource::New();
      outputGeomSet = GeometriesType::New(reprojVector);
385
      // Filter instantiation
386
387
388
389
390
391
392
393
394
395
396
397
      geometriesProjFilter = ProjectionFilterType::New();
      geometriesProjFilter->SetInput(inputGeomSet);
      if (imageProjectionRef.empty())
        {
        geometriesProjFilter->SetOutputKeywordList(inputImg->GetImageKeywordlist()); // nec qd capteur
        }
      geometriesProjFilter->SetOutputProjectionRef(imageProjectionRef);
      geometriesProjFilter->SetOutput(outputGeomSet);
      otbAppLogINFO("Reprojecting input vectors...");
      geometriesProjFilter->Update();
      }

398
399
400
401
402
    // Create output dataset for sample positions
    otb::ogr::DataSource::Pointer outputSamples =
      otb::ogr::DataSource::New(this->GetParameterString("out"),otb::ogr::DataSource::Modes::Overwrite);
    
    switch (this->GetParameterInt("sampler"))
403
      {
404
405
406
      // periodic
      case 0:
        {
407
408
409
        PeriodicSamplerType::SamplerParameterType param;
        param.Offset = 0;
        param.MaxJitter = this->GetParameterInt("sampler.periodic.jitter");
410
        param.MaxBufferSize = 100000000UL;
411
412
413
414
415
416
417
        PeriodicSamplerType::Pointer periodicFilt = PeriodicSamplerType::New();
        periodicFilt->SetInput(this->GetParameterImage("in"));
        periodicFilt->SetOGRData(reprojVector);
        periodicFilt->SetOutputPositionContainerAndRates(outputSamples, rates);
        periodicFilt->SetFieldName(fieldName);
        periodicFilt->SetLayerIndex(this->GetParameterInt("layer"));
        periodicFilt->SetSamplerParameters(param);
418
419
        if (IsParameterEnabled("mask") && HasValue("mask"))
          {
420
          periodicFilt->SetMask(this->GetParameterImage<UInt8ImageType>("mask"));
421
          }
422
423
424
        periodicFilt->GetStreamer()->SetAutomaticTiledStreaming(this->GetParameterInt("ram"));
        AddProcess(periodicFilt->GetStreamer(),"Selecting positions with periodic sampler...");
        periodicFilt->Update();
425
426
427
        }
      break;
      // random
428
      case 1:
429
        {
430
431
432
433
434
435
        RandomSamplerType::Pointer randomFilt = RandomSamplerType::New();
        randomFilt->SetInput(this->GetParameterImage("in"));
        randomFilt->SetOGRData(reprojVector);
        randomFilt->SetOutputPositionContainerAndRates(outputSamples, rates);
        randomFilt->SetFieldName(fieldName);
        randomFilt->SetLayerIndex(this->GetParameterInt("layer"));
436
437
        if (IsParameterEnabled("mask") && HasValue("mask"))
          {
438
          randomFilt->SetMask(this->GetParameterImage<UInt8ImageType>("mask"));
439
          }
440
441
442
443
444
        randomFilt->GetStreamer()->SetAutomaticTiledStreaming(this->GetParameterInt("ram"));
        AddProcess(randomFilt->GetStreamer(),"Selecting positions with random sampler...");
        randomFilt->Update();

        randomFilt = RandomSamplerType::New();
445
446
447
448
449
        }
      break;
      default:
        otbAppLogFATAL("Sampler type unknown :"<<this->GetParameterString("sampler"));
      break;
450
451
452
      }
  }

453
454
  RateCalculatorType::Pointer m_RateCalculator;
  XMLReaderType::Pointer m_ReaderStat;
455
456
457
458
459
};

} // end of namespace Wrapper
} // end of namespace otb

460
OTB_APPLICATION_EXPORT(otb::Wrapper::SampleSelection)