otbSampleSelection.cxx 19.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
 *
 * This file is part of Orfeo Toolbox
 *
 *     https://www.orfeo-toolbox.org/
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
20 21 22 23

#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"
#include "otbSamplingRateCalculator.h"
24
#include "otbOGRDataToSamplePositionFilter.h"
25
#include "otbStatisticsXMLFileReader.h"
26
#include "otbRandomSampler.h"
27 28
#include "otbGeometriesProjectionFilter.h"
#include "otbGeometriesSet.h"
29
#include "otbWrapperElevationParametersHandler.h"
30 31 32 33 34 35

namespace otb
{
namespace Wrapper
{

36 37 38 39 40 41
/**
 * Predicate negating std::isalnum.
 *
 * \param c character to test
 * \return true when \p c is NOT an alphanumeric character
 *
 * Used with std::remove_if to strip non-alphanumeric characters from
 * field names before they are turned into parameter keys.
 */
bool IsNotAlphaNum(char c)
  {
  // std::isalnum takes an int that must be representable as unsigned char
  // (or EOF). A plain char may be negative for non-ASCII bytes, which is
  // undefined behaviour, hence the cast.
  return !std::isalnum(static_cast<unsigned char>(c));
  }

42
class SampleSelection : public Application
43 44 45
{
public:
  /** Standard class typedefs. */
46
  typedef SampleSelection        Self;
47 48 49 50 51 52 53
  typedef Application                   Superclass;
  typedef itk::SmartPointer<Self>       Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;

  /** Standard macro */
  itkNewMacro(Self);

54
  itkTypeMacro(SampleSelection, otb::Application);
55 56

  /** typedef */
57 58 59 60 61 62 63 64 65
  typedef otb::OGRDataToSamplePositionFilter<
    FloatVectorImageType,
    UInt8ImageType,
    otb::PeriodicSampler>                           PeriodicSamplerType;
  typedef otb::OGRDataToSamplePositionFilter<
    FloatVectorImageType,
    UInt8ImageType,
    otb::RandomSampler>                             RandomSamplerType;
  typedef otb::SamplingRateCalculator               RateCalculatorType;
66 67
  
  typedef std::map<std::string, unsigned long>      ClassCountMapType;
68
  typedef RateCalculatorType::MapRateType           MapRateType;
69
  typedef itk::VariableLengthVector<float> MeasurementType;
70
  typedef otb::StatisticsXMLFileReader<MeasurementType> XMLReaderType;
71

72 73 74 75
  typedef otb::GeometriesSet GeometriesType;

  typedef otb::GeometriesProjectionFilter ProjectionFilterType;

76
private:
77
  SampleSelection()
78
    {
79 80
    m_ReaderStat = XMLReaderType::New();
    m_RateCalculator = RateCalculatorType::New();
81 82
    }

83
  void DoInit() override
84
  {
85
    SetName("SampleSelection");
86 87 88
    SetDescription("Selects samples from a training vector data set.");

    // Documentation
89
    SetDocName("Sample Selection");
90 91 92 93 94
    SetDocLongDescription("The application selects a set of samples from geometries "
      "intended for training (they should have a field giving the associated "
      "class). \n\nFirst of all, the geometries must be analyzed by the PolygonClassStatistics application "
      "to compute statistics about the geometries, which are summarized in an xml file. "
      "\nThen, this xml file must be given as input to this application (parameter instats).\n\n"
95 96 97
      "The input support image and the input training vectors shall be given in "
      "parameters 'in' and 'vec' respectively. Only the sampling grid (origin, size, spacing)"
      "will be read in the input image.\n"
98 99 100 101 102
      "There are several strategies to select samples (parameter strategy) : \n\n"
      "  - smallest (default) : select the same number of sample in each class" 
      " so that the smallest one is fully sampled.\n"
      "  - constant : select the same number of samples N in each class" 
      " (with N below or equal to the size of the smallest class).\n"
103
      "  - byclass : set the required number for each class manually, with an input CSV file"
104
      " (first column is class name, second one is the required samples number).\n\n"
105 106
      "  - percent: set a target global percentage of samples to use. Class proportions will be respected. \n\n"
      "  - total: set a target total number of samples to use. Class proportions will be respected. \n\n"
107
      "There is also a choice on the sampling type to performs : \n\n"
108
      "  - periodic : select samples uniformly distributed\n"
109
      "  - random : select samples randomly distributed\n\n"
110
      "Once the strategy and type are selected, the application outputs samples positions"
111 112
      "(parameter out).\n\n"
      
113
      "The other parameters to look at are : \n\n"
114 115
      "  - layer : index specifying from which layer to pick geometries.\n"
      "  - field : set the field name containing the class.\n"
116
      "  - mask : an optional raster mask can be used to discard samples.\n"
117
      "  - outrates : allows outputting a CSV file that summarizes the sampling rates for each class.\n"
118 119
      
      "\nAs with the PolygonClassStatistics application, different types  of geometry are supported : "
120
      "polygons, lines, points. \nThe behavior of this application is different for each type of geometry : \n\n"
121 122
      "  - polygon: select points whose center is inside the polygon\n"
      "  - lines  : select points intersecting the line\n"
123
      "  - points : select closest point to the provided point");
124 125 126 127 128 129 130 131 132 133 134 135
    SetDocLimitations("None");
    SetDocAuthors("OTB-Team");
    SetDocSeeAlso(" ");

    AddDocTag(Tags::Learning);

    AddParameter(ParameterType_InputImage,  "in",   "InputImage");
    SetParameterDescription("in", "Support image that will be classified");

    AddParameter(ParameterType_InputImage,  "mask",   "InputMask");
    SetParameterDescription("mask", "Validity mask (only pixels corresponding to a mask value greater than 0 will be used for statistics)");
    MandatoryOff("mask");
136

137 138
    AddParameter(ParameterType_InputFilename, "vec", "Input vectors");
    SetParameterDescription("vec","Input geometries to analyse");
139

140 141
    AddParameter(ParameterType_OutputFilename, "out", "Output vectors");
    SetParameterDescription("out","Output resampled geometries");
142

143 144
    AddParameter(ParameterType_InputFilename, "instats", "Input Statistics");
    SetParameterDescription("instats","Input file storing statistics (XML format)");
145

146
    AddParameter(ParameterType_OutputFilename, "outrates", "Output rates");
147
    SetParameterDescription("outrates","Output rates (CSV formatted)");
148 149
    MandatoryOff("outrates");

150 151 152 153 154 155
    AddParameter(ParameterType_Choice, "sampler", "Sampler type");
    SetParameterDescription("sampler", "Type of sampling (periodic, pattern based, random)");

    AddChoice("sampler.periodic","Periodic sampler");
    SetParameterDescription("sampler.periodic","Takes samples regularly spaced");

156 157 158 159 160
    AddParameter(ParameterType_Int, "sampler.periodic.jitter","Jitter amplitude");
    SetParameterDescription("sampler.periodic.jitter", "Jitter amplitude added during sample selection (0 = no jitter)");
    SetDefaultParameterInt("sampler.periodic.jitter",0);
    MandatoryOff("sampler.periodic.jitter");

161
    AddChoice("sampler.random","Random sampler");
162
    SetParameterDescription("sampler.random","The positions to select are randomly shuffled.");
163

164
    AddParameter(ParameterType_Choice, "strategy", "Sampling strategy");
165

166 167
    AddChoice("strategy.byclass","Set samples count for each class");
    SetParameterDescription("strategy.byclass","Set samples count for each class");
168

169 170
    AddParameter(ParameterType_InputFilename, "strategy.byclass.in", "Number of samples by class");
    SetParameterDescription("strategy.byclass.in", "Number of samples by class "
171 172
      "(CSV format with class name in 1st column and required samples in the 2nd.");

173 174
    AddChoice("strategy.constant","Set the same samples counts for all classes");
    SetParameterDescription("strategy.constant","Set the same samples counts for all classes");
175

176 177
    AddParameter(ParameterType_Int, "strategy.constant.nb", "Number of samples for all classes");
    SetParameterDescription("strategy.constant.nb", "Number of samples for all classes");
178

179 180 181 182 183 184 185 186
    AddChoice("strategy.percent","Use a percentage of the samples available for each class");
    SetParameterDescription("strategy.percent","Use a percentage of the samples available for each class");

    AddParameter(ParameterType_Float,"strategy.percent.p","The percentage to use");
    SetParameterDescription("strategy.percent.p","The percentage to use");
    SetMinimumParameterFloatValue("strategy.percent.p",0);
    SetMaximumParameterFloatValue("strategy.percent.p",1);
    SetDefaultParameterFloat("strategy.percent.p",0.5);
187 188 189 190 191 192 193 194

    AddChoice("strategy.total","Set the total number of samples to generate, and use class proportions.");
    SetParameterDescription("strategy.total","Set the total number of samples to generate, and use class proportions.");

    AddParameter(ParameterType_Int,"strategy.total.v","The number of samples to generate");
    SetParameterDescription("strategy.total.v","The number of samples to generate");
    SetMinimumParameterIntValue("strategy.total.v",1);
    SetDefaultParameterInt("strategy.total.v",1000);
195
    
196 197
    AddChoice("strategy.smallest","Set same number of samples for all classes, with the smallest class fully sampled");
    SetParameterDescription("strategy.smallest","Set same number of samples for all classes, with the smallest class fully sampled");
198

199 200
    AddChoice("strategy.all","Take all samples");
    SetParameterDescription("strategy.all","Take all samples");
201

202
    // Default strategy : smallest
203
    SetParameterString("strategy","smallest");
204

205
    AddParameter(ParameterType_ListView, "field", "Field Name");
206
    SetParameterDescription("field","Name of the field carrying the class name in the input vectors.");
207
    SetListViewSingleSelectionMode("field",true);
208

209 210 211 212
    AddParameter(ParameterType_Int, "layer", "Layer Index");
    SetParameterDescription("layer", "Layer index to read in the input vector file.");
    MandatoryOff("layer");
    SetDefaultParameterInt("layer",0);
213

214 215
    ElevationParametersHandler::AddElevationParameters(this, "elev");

216 217
    AddRAMParameter();

218 219
    AddRANDParameter();

220 221 222 223 224 225
    // Doc example parameter settings
    SetDocExampleParameterValue("in", "support_image.tif");
    SetDocExampleParameterValue("vec", "variousVectors.sqlite");
    SetDocExampleParameterValue("field", "label");
    SetDocExampleParameterValue("instats","apTvClPolygonClassStatisticsOut.xml");
    SetDocExampleParameterValue("out","resampledVectors.sqlite");
226

227
    SetOfficialDocLink();
228 229
  }

230
  void DoUpdateParameters() override
231
  {
232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250
 if ( HasValue("vec") )
      {
      std::string vectorFile = GetParameterString("vec");
      ogr::DataSource::Pointer ogrDS =
        ogr::DataSource::New(vectorFile, ogr::DataSource::Modes::Read);
      ogr::Layer layer = ogrDS->GetLayer(this->GetParameterInt("layer"));
      ogr::Feature feature = layer.ogr().GetNextFeature();

      ClearChoices("field");
      
      for(int iField=0; iField<feature.ogr().GetFieldCount(); iField++)
        {
        std::string key, item = feature.ogr().GetFieldDefnRef(iField)->GetNameRef();
        key = item;
        std::string::iterator end = std::remove_if(key.begin(),key.end(),IsNotAlphaNum);
        std::transform(key.begin(), end, key.begin(), tolower);
        
        OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(iField)->GetType();
        
251
        if(fieldType == OFTString || fieldType == OFTInteger || fieldType == OFTInteger64)
252 253 254 255 256 257
          {
          std::string tmpKey="field."+key.substr(0, end - key.begin());
          AddChoice(tmpKey,item);
          }
        }
      }
258 259
  }

260
  void DoExecute() override
261 262 263
    {
    // Clear state
    m_RateCalculator->ClearRates();
264

265 266
    otb::Wrapper::ElevationParametersHandler::SetupDEMHandlerFromElevationParameters(this,"elev");

267 268 269 270 271 272 273 274
    // Get field name
    std::vector<int> selectedCFieldIdx = GetSelectedItems("field");
    
    if(selectedCFieldIdx.empty())
      {
      otbAppLogFATAL(<<"No field has been selected for data labelling!");
      }
    
Julien Michel's avatar
Julien Michel committed
275
    std::vector<std::string> cFieldNames = GetChoiceNames("field");  
276
    std::string fieldName = cFieldNames[selectedCFieldIdx.front()];
277 278 279 280 281
    
    m_ReaderStat->SetFileName(this->GetParameterString("instats"));
    ClassCountMapType classCount = m_ReaderStat->GetStatisticMapByName<ClassCountMapType>("samplesPerClass");
    m_RateCalculator->SetClassCount(classCount);
    
282
    switch (this->GetParameterInt("strategy"))
283
      {
284 285 286
      // byclass
      case 0:
        {
287
        otbAppLogINFO("Sampling strategy : set number of samples for each class");
288
        ClassCountMapType requiredCount = 
289
          otb::SamplingRateCalculator::ReadRequiredSamples(this->GetParameterString("strategy.byclass.in"));
290 291 292 293 294
        m_RateCalculator->SetNbOfSamplesByClass(requiredCount);
        }
      break;
      // constant
      case 1:
295 296
        {
        otbAppLogINFO("Sampling strategy : set a constant number of samples for all classes");
297
        m_RateCalculator->SetNbOfSamplesAllClasses(GetParameterInt("strategy.constant.nb"));
298
        }
299
      break;
300
      // percent
301
      case 2:
302
      {
303
      otbAppLogINFO("Sampling strategy: set a percentage of samples for each class.");
304 305 306
      m_RateCalculator->SetPercentageOfSamples(this->GetParameterFloat("strategy.percent.p"));
      }
      break;
307
      // total
308
      case 3:
309 310 311 312 313 314 315 316
      {
      otbAppLogINFO("Sampling strategy: set the total number of samples to generate, use classes proportions.");
      m_RateCalculator->SetTotalNumberOfSamples(this->GetParameterInt("strategy.total.v"));
      }
      break;

      // smallest class
      case 4:
317 318
        {
        otbAppLogINFO("Sampling strategy : fit the number of samples based on the smallest class");
319
        m_RateCalculator->SetMinimumNbOfSamplesByClass();
320 321 322
        }
      break;
      // all samples
323
      case 5:
324 325 326 327
        {
        otbAppLogINFO("Sampling strategy : take all samples");
        m_RateCalculator->SetAllSamples();
        }
328 329
      break;
      default:
330
        otbAppLogFATAL("Strategy mode unknown :"<<this->GetParameterString("strategy"));
331
      break;
332
      }
333 334
      
    if (IsParameterEnabled("outrates") && HasValue("outrates"))
335
      {
336
      m_RateCalculator->Write(this->GetParameterString("outrates"));
337
      }
338 339
    
    MapRateType rates = m_RateCalculator->GetRatesByClass();
340 341 342
    std::ostringstream oss;
    oss << " className  requiredSamples  totalSamples  rate" << std::endl;
    MapRateType::const_iterator itRates = rates.begin();
343
    unsigned int overflowCount = 0;
344 345 346
    for(; itRates != rates.end(); ++itRates)
      {
      otb::SamplingRateCalculator::TripletType tpt = itRates->second;
347 348 349 350 351 352 353
      oss << itRates->first << "\t" << tpt.Required << "\t" << tpt.Tot << "\t" << tpt.Rate;
      if (tpt.Required > tpt.Tot)
        {
        overflowCount++;
        oss << "\t[OVERFLOW]";
        }
      oss << std::endl;
354 355
      }
    otbAppLogINFO("Sampling rates : " << oss.str());
356 357 358 359 360
    if (overflowCount)
      {
      std::string plural(overflowCount>1?"s":"");
      otbAppLogWARNING(<< overflowCount << " case"<<plural<<" of overflow detected! (requested number of samples higher than total available samples)");
      }
361 362 363 364 365

    // Open input geometries
    otb::ogr::DataSource::Pointer vectors =
      otb::ogr::DataSource::New(this->GetParameterString("vec"));

366 367 368 369 370 371 372 373 374 375 376 377 378
    // Reproject geometries
    FloatVectorImageType::Pointer inputImg = this->GetParameterImage("in");
    std::string imageProjectionRef = inputImg->GetProjectionRef();
    FloatVectorImageType::ImageKeywordlistType imageKwl =
      inputImg->GetImageKeywordlist();
    std::string vectorProjectionRef =
      vectors->GetLayer(GetParameterInt("layer")).GetProjectionRef();

    otb::ogr::DataSource::Pointer reprojVector = vectors;
    GeometriesType::Pointer inputGeomSet;
    ProjectionFilterType::Pointer geometriesProjFilter;
    GeometriesType::Pointer outputGeomSet;
    bool doReproj = true;
379 380 381 382
    const OGRSpatialReference imgOGRSref = 
        OGRSpatialReference( imageProjectionRef.c_str() );
    const OGRSpatialReference vectorOGRSref = 
        OGRSpatialReference( vectorProjectionRef.c_str() );
383
    // don't reproject for these cases
384 385 386
    if (  vectorProjectionRef.empty()
       || ( imgOGRSref.IsSame( &vectorOGRSref ) )
       || ( imageProjectionRef.empty() && imageKwl.GetSize() == 0) )
387 388 389 390 391 392 393
      doReproj = false;
  
    if (doReproj)
      {
      inputGeomSet = GeometriesType::New(vectors);
      reprojVector = otb::ogr::DataSource::New();
      outputGeomSet = GeometriesType::New(reprojVector);
394
      // Filter instantiation
395 396 397 398 399 400 401 402 403 404 405 406
      geometriesProjFilter = ProjectionFilterType::New();
      geometriesProjFilter->SetInput(inputGeomSet);
      if (imageProjectionRef.empty())
        {
        geometriesProjFilter->SetOutputKeywordList(inputImg->GetImageKeywordlist()); // nec qd capteur
        }
      geometriesProjFilter->SetOutputProjectionRef(imageProjectionRef);
      geometriesProjFilter->SetOutput(outputGeomSet);
      otbAppLogINFO("Reprojecting input vectors...");
      geometriesProjFilter->Update();
      }

407 408 409 410 411
    // Create output dataset for sample positions
    otb::ogr::DataSource::Pointer outputSamples =
      otb::ogr::DataSource::New(this->GetParameterString("out"),otb::ogr::DataSource::Modes::Overwrite);
    
    switch (this->GetParameterInt("sampler"))
412
      {
413 414 415
      // periodic
      case 0:
        {
416 417 418
        PeriodicSamplerType::SamplerParameterType param;
        param.Offset = 0;
        param.MaxJitter = this->GetParameterInt("sampler.periodic.jitter");
419
        param.MaxBufferSize = 100000000UL;
420 421 422 423 424 425 426
        PeriodicSamplerType::Pointer periodicFilt = PeriodicSamplerType::New();
        periodicFilt->SetInput(this->GetParameterImage("in"));
        periodicFilt->SetOGRData(reprojVector);
        periodicFilt->SetOutputPositionContainerAndRates(outputSamples, rates);
        periodicFilt->SetFieldName(fieldName);
        periodicFilt->SetLayerIndex(this->GetParameterInt("layer"));
        periodicFilt->SetSamplerParameters(param);
427 428
        if (IsParameterEnabled("mask") && HasValue("mask"))
          {
429
          periodicFilt->SetMask(this->GetParameterImage<UInt8ImageType>("mask"));
430
          }
431 432 433
        periodicFilt->GetStreamer()->SetAutomaticTiledStreaming(this->GetParameterInt("ram"));
        AddProcess(periodicFilt->GetStreamer(),"Selecting positions with periodic sampler...");
        periodicFilt->Update();
434 435 436
        }
      break;
      // random
437
      case 1:
438
        {
439 440 441 442 443 444
        RandomSamplerType::Pointer randomFilt = RandomSamplerType::New();
        randomFilt->SetInput(this->GetParameterImage("in"));
        randomFilt->SetOGRData(reprojVector);
        randomFilt->SetOutputPositionContainerAndRates(outputSamples, rates);
        randomFilt->SetFieldName(fieldName);
        randomFilt->SetLayerIndex(this->GetParameterInt("layer"));
445 446
        if (IsParameterEnabled("mask") && HasValue("mask"))
          {
447
          randomFilt->SetMask(this->GetParameterImage<UInt8ImageType>("mask"));
448
          }
449 450 451 452 453
        randomFilt->GetStreamer()->SetAutomaticTiledStreaming(this->GetParameterInt("ram"));
        AddProcess(randomFilt->GetStreamer(),"Selecting positions with random sampler...");
        randomFilt->Update();

        randomFilt = RandomSamplerType::New();
454 455 456 457 458
        }
      break;
      default:
        otbAppLogFATAL("Sampler type unknown :"<<this->GetParameterString("sampler"));
      break;
459 460 461
      }
  }

462 463
  /** Sampling rate calculator; its rates are cleared at the start of each DoExecute(). */
  RateCalculatorType::Pointer m_RateCalculator;
  /** Reader for the statistics XML file given by the "instats" parameter. */
  XMLReaderType::Pointer m_ReaderStat;
464 465 466 467 468
};

} // end of namespace Wrapper
} // end of namespace otb

469
// Export the SampleSelection application class through the OTB application export macro.
OTB_APPLICATION_EXPORT(otb::Wrapper::SampleSelection)