otbSampleSelection.cxx 18.6 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
/*=========================================================================

 Program:   ORFEO Toolbox
 Language:  C++
 Date:      $Date$
 Version:   $Revision$


 Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
 See OTBCopyright.txt for details.


 This software is distributed WITHOUT ANY WARRANTY; without even
 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
 PURPOSE.  See the above copyright notices for more information.

 =========================================================================*/
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"
#include "otbSamplingRateCalculator.h"
21
#include "otbOGRDataToSamplePositionFilter.h"
22
#include "otbStatisticsXMLFileReader.h"
23
#include "otbRandomSampler.h"
24
25
#include "otbGeometriesProjectionFilter.h"
#include "otbGeometriesSet.h"
26
27
28
29
30
31

namespace otb
{
namespace Wrapper
{

32
33
34
35
36
37
/** Utility function to negate std::isalnum.
 *
 *  Used as a std::remove_if predicate to strip non-alphanumeric characters
 *  from field names when building choice keys.
 *
 *  The argument is converted to unsigned char before calling std::isalnum:
 *  passing a negative value other than EOF is undefined behavior, and a plain
 *  char holds negative values for non-ASCII bytes when char is signed. */
bool IsNotAlphaNum(char c)
  {
  return !std::isalnum(static_cast<unsigned char>(c));
  }

38
class SampleSelection : public Application
39
40
41
{
public:
  /** Standard class typedefs. */
42
  typedef SampleSelection        Self;
43
44
45
46
47
48
49
  typedef Application                   Superclass;
  typedef itk::SmartPointer<Self>       Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;

  /** Standard macro */
  itkNewMacro(Self);

50
  itkTypeMacro(SampleSelection, otb::Application);
51
52

  /** typedef */
53
54
55
56
57
58
59
60
61
  typedef otb::OGRDataToSamplePositionFilter<
    FloatVectorImageType,
    UInt8ImageType,
    otb::PeriodicSampler>                           PeriodicSamplerType;
  typedef otb::OGRDataToSamplePositionFilter<
    FloatVectorImageType,
    UInt8ImageType,
    otb::RandomSampler>                             RandomSamplerType;
  typedef otb::SamplingRateCalculator               RateCalculatorType;
62
63
  
  typedef std::map<std::string, unsigned long>      ClassCountMapType;
64
  typedef RateCalculatorType::MapRateType           MapRateType;
65
  typedef itk::VariableLengthVector<float> MeasurementType;
66
  typedef otb::StatisticsXMLFileReader<MeasurementType> XMLReaderType;
67

68
69
70
71
  typedef otb::GeometriesSet GeometriesType;

  typedef otb::GeometriesProjectionFilter ProjectionFilterType;

72
private:
73
  SampleSelection()
74
    {
75
76
77
78
    m_Periodic = PeriodicSamplerType::New();
    m_Random = RandomSamplerType::New();
    m_ReaderStat = XMLReaderType::New();
    m_RateCalculator = RateCalculatorType::New();
79
80
81
82
    }

  void DoInit()
  {
83
    SetName("SampleSelection");
84
85
86
    SetDescription("Selects samples from a training vector data set.");

    // Documentation
87
    SetDocName("Sample Selection");
88
89
90
91
92
    SetDocLongDescription("The application selects a set of samples from geometries "
      "intended for training (they should have a field giving the associated "
      "class). \n\nFirst of all, the geometries must be analyzed by the PolygonClassStatistics application "
      "to compute statistics about the geometries, which are summarized in an xml file. "
      "\nThen, this xml file must be given as input to this application (parameter instats).\n\n"
93
94
95
      "The input support image and the input training vectors shall be given in "
      "parameters 'in' and 'vec' respectively. Only the sampling grid (origin, size, spacing)"
      "will be read in the input image.\n"
96
97
98
99
100
      "There are several strategies to select samples (parameter strategy) : \n\n"
      "  - smallest (default) : select the same number of sample in each class" 
      " so that the smallest one is fully sampled.\n"
      "  - constant : select the same number of samples N in each class" 
      " (with N below or equal to the size of the smallest class).\n"
101
      "  - byclass : set the required number for each class manually, with an input CSV file"
102
      " (first column is class name, second one is the required samples number).\n\n"
103
104
      "  - percent: set a target global percentage of samples to use. Class proportions will be respected. \n\n"
      "  - total: set a target total number of samples to use. Class proportions will be respected. \n\n"
105
      "There is also a choice on the sampling type to performs : \n\n"
106
      "  - periodic : select samples uniformly distributed\n"
107
      "  - random : select samples randomly distributed\n\n"
108
      "Once the strategy and type are selected, the application outputs samples positions"
109
110
      "(parameter out).\n\n"
      
111
      "The other parameters to look at are : \n\n"
112
113
      "  - layer : index specifying from which layer to pick geometries.\n"
      "  - field : set the field name containing the class.\n"
114
      "  - mask : an optional raster mask can be used to discard samples.\n"
115
      "  - outrates : allows outputting a CSV file that summarizes the sampling rates for each class.\n"
116
117
      
      "\nAs with the PolygonClassStatistics application, different types  of geometry are supported : "
118
      "polygons, lines, points. \nThe behavior of this application is different for each type of geometry : \n\n"
119
120
      "  - polygon: select points whose center is inside the polygon\n"
      "  - lines  : select points intersecting the line\n"
Christophe Palmann's avatar
Christophe Palmann committed
121
      "  - points : select closest point to the provided point\n");
122
123
124
125
126
127
128
129
130
131
132
133
    SetDocLimitations("None");
    SetDocAuthors("OTB-Team");
    SetDocSeeAlso(" ");

    AddDocTag(Tags::Learning);

    AddParameter(ParameterType_InputImage,  "in",   "InputImage");
    SetParameterDescription("in", "Support image that will be classified");

    AddParameter(ParameterType_InputImage,  "mask",   "InputMask");
    SetParameterDescription("mask", "Validity mask (only pixels corresponding to a mask value greater than 0 will be used for statistics)");
    MandatoryOff("mask");
134

135
136
    AddParameter(ParameterType_InputFilename, "vec", "Input vectors");
    SetParameterDescription("vec","Input geometries to analyse");
137

138
139
    AddParameter(ParameterType_OutputFilename, "out", "Output vectors");
    SetParameterDescription("out","Output resampled geometries");
140

141
142
    AddParameter(ParameterType_InputFilename, "instats", "Input Statistics");
    SetParameterDescription("instats","Input file storing statistics (XML format)");
143

144
    AddParameter(ParameterType_OutputFilename, "outrates", "Output rates");
145
    SetParameterDescription("outrates","Output rates (CSV formatted)");
146
147
    MandatoryOff("outrates");

148
149
150
151
152
153
    AddParameter(ParameterType_Choice, "sampler", "Sampler type");
    SetParameterDescription("sampler", "Type of sampling (periodic, pattern based, random)");

    AddChoice("sampler.periodic","Periodic sampler");
    SetParameterDescription("sampler.periodic","Takes samples regularly spaced");

154
155
156
157
158
    AddParameter(ParameterType_Int, "sampler.periodic.jitter","Jitter amplitude");
    SetParameterDescription("sampler.periodic.jitter", "Jitter amplitude added during sample selection (0 = no jitter)");
    SetDefaultParameterInt("sampler.periodic.jitter",0);
    MandatoryOff("sampler.periodic.jitter");

159
    AddChoice("sampler.random","Random sampler");
160
    SetParameterDescription("sampler.random","The positions to select are randomly shuffled.");
161

162
    AddParameter(ParameterType_Choice, "strategy", "Sampling strategy");
163

164
165
    AddChoice("strategy.byclass","Set samples count for each class");
    SetParameterDescription("strategy.byclass","Set samples count for each class");
166

167
168
    AddParameter(ParameterType_InputFilename, "strategy.byclass.in", "Number of samples by class");
    SetParameterDescription("strategy.byclass.in", "Number of samples by class "
169
170
      "(CSV format with class name in 1st column and required samples in the 2nd.");

171
172
    AddChoice("strategy.constant","Set the same samples counts for all classes");
    SetParameterDescription("strategy.constant","Set the same samples counts for all classes");
173

174
175
    AddParameter(ParameterType_Int, "strategy.constant.nb", "Number of samples for all classes");
    SetParameterDescription("strategy.constant.nb", "Number of samples for all classes");
176

177
178
179
180
181
182
183
184
    AddChoice("strategy.percent","Use a percentage of the samples available for each class");
    SetParameterDescription("strategy.percent","Use a percentage of the samples available for each class");

    AddParameter(ParameterType_Float,"strategy.percent.p","The percentage to use");
    SetParameterDescription("strategy.percent.p","The percentage to use");
    SetMinimumParameterFloatValue("strategy.percent.p",0);
    SetMaximumParameterFloatValue("strategy.percent.p",1);
    SetDefaultParameterFloat("strategy.percent.p",0.5);
185
186
187
188
189
190
191
192

    AddChoice("strategy.total","Set the total number of samples to generate, and use class proportions.");
    SetParameterDescription("strategy.total","Set the total number of samples to generate, and use class proportions.");

    AddParameter(ParameterType_Int,"strategy.total.v","The number of samples to generate");
    SetParameterDescription("strategy.total.v","The number of samples to generate");
    SetMinimumParameterIntValue("strategy.total.v",1);
    SetDefaultParameterInt("strategy.total.v",1000);
193
    
194
195
    AddChoice("strategy.smallest","Set same number of samples for all classes, with the smallest class fully sampled");
    SetParameterDescription("strategy.smallest","Set same number of samples for all classes, with the smallest class fully sampled");
196

197
198
    AddChoice("strategy.all","Take all samples");
    SetParameterDescription("strategy.all","Take all samples");
199

200
    // Default strategy : smallest
201
    SetParameterString("strategy","smallest");
202

203
    AddParameter(ParameterType_ListView, "field", "Field Name");
204
    SetParameterDescription("field","Name of the field carrying the class name in the input vectors.");
205
    SetListViewSingleSelectionMode("field",true);
206

207
208
209
210
    AddParameter(ParameterType_Int, "layer", "Layer Index");
    SetParameterDescription("layer", "Layer index to read in the input vector file.");
    MandatoryOff("layer");
    SetDefaultParameterInt("layer",0);
211

212
213
    AddRAMParameter();

214
215
    AddRANDParameter();

216
217
218
219
220
221
222
223
224
225
    // Doc example parameter settings
    SetDocExampleParameterValue("in", "support_image.tif");
    SetDocExampleParameterValue("vec", "variousVectors.sqlite");
    SetDocExampleParameterValue("field", "label");
    SetDocExampleParameterValue("instats","apTvClPolygonClassStatisticsOut.xml");
    SetDocExampleParameterValue("out","resampledVectors.sqlite");
  }

  void DoUpdateParameters()
  {
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
 if ( HasValue("vec") )
      {
      std::string vectorFile = GetParameterString("vec");
      ogr::DataSource::Pointer ogrDS =
        ogr::DataSource::New(vectorFile, ogr::DataSource::Modes::Read);
      ogr::Layer layer = ogrDS->GetLayer(this->GetParameterInt("layer"));
      ogr::Feature feature = layer.ogr().GetNextFeature();

      ClearChoices("field");
      
      for(int iField=0; iField<feature.ogr().GetFieldCount(); iField++)
        {
        std::string key, item = feature.ogr().GetFieldDefnRef(iField)->GetNameRef();
        key = item;
        std::string::iterator end = std::remove_if(key.begin(),key.end(),IsNotAlphaNum);
        std::transform(key.begin(), end, key.begin(), tolower);
        
        OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(iField)->GetType();
        
245
        if(fieldType == OFTInteger || ogr::version_proxy::IsOFTInteger64(fieldType))
246
247
248
249
250
251
          {
          std::string tmpKey="field."+key.substr(0, end - key.begin());
          AddChoice(tmpKey,item);
          }
        }
      }
252
253
  }

254
255
256
257
258
259
  void DoExecute()
    {
    // Clear state
    m_RateCalculator->ClearRates();
    m_Periodic->GetFilter()->ClearOutputs();
    m_Random->GetFilter()->ClearOutputs();
260
261
262
263

    // Setup ram
    m_Periodic->GetStreamer()->SetAutomaticAdaptativeStreaming(GetParameterInt("ram"));
    m_Random->GetStreamer()->SetAutomaticAdaptativeStreaming(GetParameterInt("ram"));
264
265
266
267
268
269
270
271
272
273
274

    // Get field name
    std::vector<int> selectedCFieldIdx = GetSelectedItems("field");
    
    if(selectedCFieldIdx.empty())
      {
      otbAppLogFATAL(<<"No field has been selected for data labelling!");
      }
    
    std::vector<std::string> cFieldNames = GetChoiceNames("cfield");  
    std::string fieldName = cFieldNames[selectedCFieldIdx.front()];
275
276
277
278
279
    
    m_ReaderStat->SetFileName(this->GetParameterString("instats"));
    ClassCountMapType classCount = m_ReaderStat->GetStatisticMapByName<ClassCountMapType>("samplesPerClass");
    m_RateCalculator->SetClassCount(classCount);
    
280
    switch (this->GetParameterInt("strategy"))
281
      {
282
283
284
      // byclass
      case 0:
        {
285
        otbAppLogINFO("Sampling strategy : set number of samples for each class");
286
        ClassCountMapType requiredCount = 
287
          otb::SamplingRateCalculator::ReadRequiredSamples(this->GetParameterString("strategy.byclass.in"));
288
289
290
291
292
        m_RateCalculator->SetNbOfSamplesByClass(requiredCount);
        }
      break;
      // constant
      case 1:
293
294
        {
        otbAppLogINFO("Sampling strategy : set a constant number of samples for all classes");
295
        m_RateCalculator->SetNbOfSamplesAllClasses(GetParameterInt("strategy.constant.nb"));
296
        }
297
      break;
298
      // percent
299
      case 2:
300
      {
301
      otbAppLogINFO("Sampling strategy: set a percentage of samples for each class.");
302
303
304
      m_RateCalculator->SetPercentageOfSamples(this->GetParameterFloat("strategy.percent.p"));
      }
      break;
305
      // total
306
      case 3:
307
308
309
310
311
312
313
314
      {
      otbAppLogINFO("Sampling strategy: set the total number of samples to generate, use classes proportions.");
      m_RateCalculator->SetTotalNumberOfSamples(this->GetParameterInt("strategy.total.v"));
      }
      break;

      // smallest class
      case 4:
315
316
        {
        otbAppLogINFO("Sampling strategy : fit the number of samples based on the smallest class");
317
        m_RateCalculator->SetMinimumNbOfSamplesByClass();
318
319
320
        }
      break;
      // all samples
321
      case 5:
322
323
324
325
        {
        otbAppLogINFO("Sampling strategy : take all samples");
        m_RateCalculator->SetAllSamples();
        }
326
327
      break;
      default:
328
        otbAppLogFATAL("Strategy mode unknown :"<<this->GetParameterString("strategy"));
329
      break;
330
      }
331
332
      
    if (IsParameterEnabled("outrates") && HasValue("outrates"))
333
      {
334
      m_RateCalculator->Write(this->GetParameterString("outrates"));
335
      }
336
337
    
    MapRateType rates = m_RateCalculator->GetRatesByClass();
338
339
340
    std::ostringstream oss;
    oss << " className  requiredSamples  totalSamples  rate" << std::endl;
    MapRateType::const_iterator itRates = rates.begin();
341
    unsigned int overflowCount = 0;
342
343
344
    for(; itRates != rates.end(); ++itRates)
      {
      otb::SamplingRateCalculator::TripletType tpt = itRates->second;
345
346
347
348
349
350
351
      oss << itRates->first << "\t" << tpt.Required << "\t" << tpt.Tot << "\t" << tpt.Rate;
      if (tpt.Required > tpt.Tot)
        {
        overflowCount++;
        oss << "\t[OVERFLOW]";
        }
      oss << std::endl;
352
353
      }
    otbAppLogINFO("Sampling rates : " << oss.str());
354
355
356
357
358
    if (overflowCount)
      {
      std::string plural(overflowCount>1?"s":"");
      otbAppLogWARNING(<< overflowCount << " case"<<plural<<" of overflow detected! (requested number of samples higher than total available samples)");
      }
359
360
361
362
363

    // Open input geometries
    otb::ogr::DataSource::Pointer vectors =
      otb::ogr::DataSource::New(this->GetParameterString("vec"));

364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
    // Reproject geometries
    FloatVectorImageType::Pointer inputImg = this->GetParameterImage("in");
    std::string imageProjectionRef = inputImg->GetProjectionRef();
    FloatVectorImageType::ImageKeywordlistType imageKwl =
      inputImg->GetImageKeywordlist();
    std::string vectorProjectionRef =
      vectors->GetLayer(GetParameterInt("layer")).GetProjectionRef();

    otb::ogr::DataSource::Pointer reprojVector = vectors;
    GeometriesType::Pointer inputGeomSet;
    ProjectionFilterType::Pointer geometriesProjFilter;
    GeometriesType::Pointer outputGeomSet;
    bool doReproj = true;
    // don't reproject for these cases
    if (vectorProjectionRef.empty() ||
        (imageProjectionRef == vectorProjectionRef) ||
        (imageProjectionRef.empty() && imageKwl.GetSize() == 0))
      doReproj = false;
  
    if (doReproj)
      {
      inputGeomSet = GeometriesType::New(vectors);
      reprojVector = otb::ogr::DataSource::New();
      outputGeomSet = GeometriesType::New(reprojVector);
388
      // Filter instantiation
389
390
391
392
393
394
395
396
397
398
399
400
      geometriesProjFilter = ProjectionFilterType::New();
      geometriesProjFilter->SetInput(inputGeomSet);
      if (imageProjectionRef.empty())
        {
        geometriesProjFilter->SetOutputKeywordList(inputImg->GetImageKeywordlist()); // nec qd capteur
        }
      geometriesProjFilter->SetOutputProjectionRef(imageProjectionRef);
      geometriesProjFilter->SetOutput(outputGeomSet);
      otbAppLogINFO("Reprojecting input vectors...");
      geometriesProjFilter->Update();
      }

401
402
403
404
405
    // Create output dataset for sample positions
    otb::ogr::DataSource::Pointer outputSamples =
      otb::ogr::DataSource::New(this->GetParameterString("out"),otb::ogr::DataSource::Modes::Overwrite);
    
    switch (this->GetParameterInt("sampler"))
406
      {
407
408
409
      // periodic
      case 0:
        {
410
411
412
413
        PeriodicSamplerType::SamplerParameterType param;
        param.Offset = 0;
        param.MaxJitter = this->GetParameterInt("sampler.periodic.jitter");

414
        m_Periodic->SetInput(this->GetParameterImage("in"));
415
        m_Periodic->SetOGRData(reprojVector);
416
        m_Periodic->SetOutputPositionContainerAndRates(outputSamples, rates);
417
        m_Periodic->SetFieldName(fieldName);
418
        m_Periodic->SetLayerIndex(this->GetParameterInt("layer"));
419
        m_Periodic->SetSamplerParameters(param);
420
421
422
423
        if (IsParameterEnabled("mask") && HasValue("mask"))
          {
          m_Periodic->SetMask(this->GetParameterImage<UInt8ImageType>("mask"));
          }
424
        m_Periodic->GetStreamer()->SetAutomaticTiledStreaming(this->GetParameterInt("ram"));
425
        AddProcess(m_Periodic->GetStreamer(),"Selecting positions with periodic sampler...");
426
427
428
429
        m_Periodic->Update();
        }
      break;
      // random
430
      case 1:
431
432
        {
        m_Random->SetInput(this->GetParameterImage("in"));
433
        m_Random->SetOGRData(reprojVector);
434
        m_Random->SetOutputPositionContainerAndRates(outputSamples, rates);
435
        m_Random->SetFieldName(fieldName);
436
437
438
439
440
        m_Random->SetLayerIndex(this->GetParameterInt("layer"));
        if (IsParameterEnabled("mask") && HasValue("mask"))
          {
          m_Random->SetMask(this->GetParameterImage<UInt8ImageType>("mask"));
          }
441
        m_Random->GetStreamer()->SetAutomaticTiledStreaming(this->GetParameterInt("ram"));
442
        AddProcess(m_Random->GetStreamer(),"Selecting positions with random sampler...");
443
444
445
446
447
448
        m_Random->Update();
        }
      break;
      default:
        otbAppLogFATAL("Sampler type unknown :"<<this->GetParameterString("sampler"));
      break;
449
450
451
      }
  }

452
453
454
455
456
457
  RateCalculatorType::Pointer m_RateCalculator;
  
  PeriodicSamplerType::Pointer m_Periodic;
  RandomSamplerType::Pointer m_Random;
  
  XMLReaderType::Pointer m_ReaderStat;
458
459
460
461
462
};

} // end of namespace Wrapper
} // end of namespace otb

463
// Expose the SampleSelection application through the OTB application
// factory export macro (declared via otbWrapperApplicationFactory.h).
OTB_APPLICATION_EXPORT(otb::Wrapper::SampleSelection)