otbSampleSelection.cxx 18.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
/*=========================================================================

 Program:   ORFEO Toolbox
 Language:  C++
 Date:      $Date$
 Version:   $Revision$


 Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
 See OTBCopyright.txt for details.


 This software is distributed WITHOUT ANY WARRANTY; without even
 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
 PURPOSE.  See the above copyright notices for more information.

 =========================================================================*/
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"
#include "otbSamplingRateCalculator.h"
21
#include "otbOGRDataToSamplePositionFilter.h"
22
#include "otbStatisticsXMLFileReader.h"
23
#include "otbRandomSampler.h"
24 25
#include "otbGeometriesProjectionFilter.h"
#include "otbGeometriesSet.h"
26 27 28 29 30 31

namespace otb
{
namespace Wrapper
{

32 33 34 35 36 37
/**
 * Utility predicate negating std::isalnum.
 *
 * \param c character to test
 * \return true when \a c is NOT an alphanumeric character
 *
 * The character is converted to unsigned char before the call: passing a
 * negative value to std::isalnum is undefined behaviour, which would happen
 * for non-ASCII bytes on platforms where plain char is signed.
 */
bool IsNotAlphaNum(char c)
  {
  return !std::isalnum(static_cast<unsigned char>(c));
  }

38
class SampleSelection : public Application
39 40 41
{
public:
  /** Standard class typedefs. */
42
  typedef SampleSelection        Self;
43 44 45 46 47 48 49
  typedef Application                   Superclass;
  typedef itk::SmartPointer<Self>       Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;

  /** Standard macro */
  itkNewMacro(Self);

50
  itkTypeMacro(SampleSelection, otb::Application);
51 52

  /** typedef */
53 54 55 56 57 58 59 60 61
  typedef otb::OGRDataToSamplePositionFilter<
    FloatVectorImageType,
    UInt8ImageType,
    otb::PeriodicSampler>                           PeriodicSamplerType;
  typedef otb::OGRDataToSamplePositionFilter<
    FloatVectorImageType,
    UInt8ImageType,
    otb::RandomSampler>                             RandomSamplerType;
  typedef otb::SamplingRateCalculator               RateCalculatorType;
62 63
  
  typedef std::map<std::string, unsigned long>      ClassCountMapType;
64
  typedef RateCalculatorType::MapRateType           MapRateType;
65
  typedef itk::VariableLengthVector<float> MeasurementType;
66
  typedef otb::StatisticsXMLFileReader<MeasurementType> XMLReaderType;
67

68 69 70 71
  typedef otb::GeometriesSet GeometriesType;

  typedef otb::GeometriesProjectionFilter ProjectionFilterType;

72
private:
73
  SampleSelection()
74
    {
75 76 77 78
    m_Periodic = PeriodicSamplerType::New();
    m_Random = RandomSamplerType::New();
    m_ReaderStat = XMLReaderType::New();
    m_RateCalculator = RateCalculatorType::New();
79 80 81 82
    }

  void DoInit()
  {
83
    SetName("SampleSelection");
84 85 86
    SetDescription("Selects samples from a training vector data set.");

    // Documentation
87
    SetDocName("Sample Selection");
88 89 90 91 92
    SetDocLongDescription("The application selects a set of samples from geometries "
      "intended for training (they should have a field giving the associated "
      "class). \n\nFirst of all, the geometries must be analyzed by the PolygonClassStatistics application "
      "to compute statistics about the geometries, which are summarized in an xml file. "
      "\nThen, this xml file must be given as input to this application (parameter instats).\n\n"
93 94 95
      "The input support image and the input training vectors shall be given in "
      "parameters 'in' and 'vec' respectively. Only the sampling grid (origin, size, spacing)"
      "will be read in the input image.\n"
96 97 98 99 100
      "There are several strategies to select samples (parameter strategy) : \n\n"
      "  - smallest (default) : select the same number of sample in each class" 
      " so that the smallest one is fully sampled.\n"
      "  - constant : select the same number of samples N in each class" 
      " (with N below or equal to the size of the smallest class).\n"
101
      "  - byclass : set the required number for each class manually, with an input CSV file"
102
      " (first column is class name, second one is the required samples number).\n\n"
103 104
      "  - percent: set a target global percentage of samples to use. Class proportions will be respected. \n\n"
      "  - total: set a target total number of samples to use. Class proportions will be respected. \n\n"
105
      "There is also a choice on the sampling type to performs : \n\n"
106
      "  - periodic : select samples uniformly distributed\n"
107
      "  - random : select samples randomly distributed\n\n"
108
      "Once the strategy and type are selected, the application outputs samples positions"
109 110
      "(parameter out).\n\n"
      
111
      "The other parameters to look at are : \n\n"
112 113
      "  - layer : index specifying from which layer to pick geometries.\n"
      "  - field : set the field name containing the class.\n"
114
      "  - mask : an optional raster mask can be used to discard samples.\n"
115
      "  - outrates : allows outputting a CSV file that summarizes the sampling rates for each class.\n"
116 117
      
      "\nAs with the PolygonClassStatistics application, different types  of geometry are supported : "
118
      "polygons, lines, points. \nThe behavior of this application is different for each type of geometry : \n\n"
119 120
      "  - polygon: select points whose center is inside the polygon\n"
      "  - lines  : select points intersecting the line\n"
Christophe Palmann's avatar
Christophe Palmann committed
121
      "  - points : select closest point to the provided point\n");
122 123 124 125 126 127 128 129 130 131 132 133
    SetDocLimitations("None");
    SetDocAuthors("OTB-Team");
    SetDocSeeAlso(" ");

    AddDocTag(Tags::Learning);

    AddParameter(ParameterType_InputImage,  "in",   "InputImage");
    SetParameterDescription("in", "Support image that will be classified");

    AddParameter(ParameterType_InputImage,  "mask",   "InputMask");
    SetParameterDescription("mask", "Validity mask (only pixels corresponding to a mask value greater than 0 will be used for statistics)");
    MandatoryOff("mask");
134

135 136
    AddParameter(ParameterType_InputFilename, "vec", "Input vectors");
    SetParameterDescription("vec","Input geometries to analyse");
137

138 139
    AddParameter(ParameterType_OutputFilename, "out", "Output vectors");
    SetParameterDescription("out","Output resampled geometries");
140

141 142
    AddParameter(ParameterType_InputFilename, "instats", "Input Statistics");
    SetParameterDescription("instats","Input file storing statistics (XML format)");
143

144
    AddParameter(ParameterType_OutputFilename, "outrates", "Output rates");
145
    SetParameterDescription("outrates","Output rates (CSV formatted)");
146 147
    MandatoryOff("outrates");

148 149 150 151 152 153
    AddParameter(ParameterType_Choice, "sampler", "Sampler type");
    SetParameterDescription("sampler", "Type of sampling (periodic, pattern based, random)");

    AddChoice("sampler.periodic","Periodic sampler");
    SetParameterDescription("sampler.periodic","Takes samples regularly spaced");

154 155 156 157 158
    AddParameter(ParameterType_Int, "sampler.periodic.jitter","Jitter amplitude");
    SetParameterDescription("sampler.periodic.jitter", "Jitter amplitude added during sample selection (0 = no jitter)");
    SetDefaultParameterInt("sampler.periodic.jitter",0);
    MandatoryOff("sampler.periodic.jitter");

159
    AddChoice("sampler.random","Random sampler");
160
    SetParameterDescription("sampler.random","The positions to select are randomly shuffled.");
161

162
    AddParameter(ParameterType_Choice, "strategy", "Sampling strategy");
163

164 165
    AddChoice("strategy.byclass","Set samples count for each class");
    SetParameterDescription("strategy.byclass","Set samples count for each class");
166

167 168
    AddParameter(ParameterType_InputFilename, "strategy.byclass.in", "Number of samples by class");
    SetParameterDescription("strategy.byclass.in", "Number of samples by class "
169 170
      "(CSV format with class name in 1st column and required samples in the 2nd.");

171 172
    AddChoice("strategy.constant","Set the same samples counts for all classes");
    SetParameterDescription("strategy.constant","Set the same samples counts for all classes");
173

174 175
    AddParameter(ParameterType_Int, "strategy.constant.nb", "Number of samples for all classes");
    SetParameterDescription("strategy.constant.nb", "Number of samples for all classes");
176

177 178 179 180 181 182 183 184
    AddChoice("strategy.percent","Use a percentage of the samples available for each class");
    SetParameterDescription("strategy.percent","Use a percentage of the samples available for each class");

    AddParameter(ParameterType_Float,"strategy.percent.p","The percentage to use");
    SetParameterDescription("strategy.percent.p","The percentage to use");
    SetMinimumParameterFloatValue("strategy.percent.p",0);
    SetMaximumParameterFloatValue("strategy.percent.p",1);
    SetDefaultParameterFloat("strategy.percent.p",0.5);
185 186 187 188 189 190 191 192

    AddChoice("strategy.total","Set the total number of samples to generate, and use class proportions.");
    SetParameterDescription("strategy.total","Set the total number of samples to generate, and use class proportions.");

    AddParameter(ParameterType_Int,"strategy.total.v","The number of samples to generate");
    SetParameterDescription("strategy.total.v","The number of samples to generate");
    SetMinimumParameterIntValue("strategy.total.v",1);
    SetDefaultParameterInt("strategy.total.v",1000);
193
    
194 195
    AddChoice("strategy.smallest","Set same number of samples for all classes, with the smallest class fully sampled");
    SetParameterDescription("strategy.smallest","Set same number of samples for all classes, with the smallest class fully sampled");
196

197 198
    AddChoice("strategy.all","Take all samples");
    SetParameterDescription("strategy.all","Take all samples");
199

200
    // Default strategy : smallest
201
    SetParameterString("strategy","smallest");
202

203
    AddParameter(ParameterType_ListView, "field", "Field Name");
204
    SetParameterDescription("field","Name of the field carrying the class name in the input vectors.");
205
    SetListViewSingleSelectionMode("field",true);
206

207 208 209 210
    AddParameter(ParameterType_Int, "layer", "Layer Index");
    SetParameterDescription("layer", "Layer index to read in the input vector file.");
    MandatoryOff("layer");
    SetDefaultParameterInt("layer",0);
211

212 213
    AddRAMParameter();

214 215
    AddRANDParameter();

216 217 218 219 220 221 222 223 224 225
    // Doc example parameter settings
    SetDocExampleParameterValue("in", "support_image.tif");
    SetDocExampleParameterValue("vec", "variousVectors.sqlite");
    SetDocExampleParameterValue("field", "label");
    SetDocExampleParameterValue("instats","apTvClPolygonClassStatisticsOut.xml");
    SetDocExampleParameterValue("out","resampledVectors.sqlite");
  }

  void DoUpdateParameters()
  {
226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244
 if ( HasValue("vec") )
      {
      std::string vectorFile = GetParameterString("vec");
      ogr::DataSource::Pointer ogrDS =
        ogr::DataSource::New(vectorFile, ogr::DataSource::Modes::Read);
      ogr::Layer layer = ogrDS->GetLayer(this->GetParameterInt("layer"));
      ogr::Feature feature = layer.ogr().GetNextFeature();

      ClearChoices("field");
      
      for(int iField=0; iField<feature.ogr().GetFieldCount(); iField++)
        {
        std::string key, item = feature.ogr().GetFieldDefnRef(iField)->GetNameRef();
        key = item;
        std::string::iterator end = std::remove_if(key.begin(),key.end(),IsNotAlphaNum);
        std::transform(key.begin(), end, key.begin(), tolower);
        
        OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(iField)->GetType();
        
245
        if(fieldType == OFTInteger || ogr::version_proxy::IsOFTInteger64(fieldType))
246 247 248 249 250 251
          {
          std::string tmpKey="field."+key.substr(0, end - key.begin());
          AddChoice(tmpKey,item);
          }
        }
      }
252 253
  }

254 255 256 257 258 259
  void DoExecute()
    {
    // Clear state
    m_RateCalculator->ClearRates();
    m_Periodic->GetFilter()->ClearOutputs();
    m_Random->GetFilter()->ClearOutputs();
260 261 262 263

    // Setup ram
    m_Periodic->GetStreamer()->SetAutomaticAdaptativeStreaming(GetParameterInt("ram"));
    m_Random->GetStreamer()->SetAutomaticAdaptativeStreaming(GetParameterInt("ram"));
264 265 266 267 268 269 270 271 272 273 274

    // Get field name
    std::vector<int> selectedCFieldIdx = GetSelectedItems("field");
    
    if(selectedCFieldIdx.empty())
      {
      otbAppLogFATAL(<<"No field has been selected for data labelling!");
      }
    
    std::vector<std::string> cFieldNames = GetChoiceNames("cfield");  
    std::string fieldName = cFieldNames[selectedCFieldIdx.front()];
275 276 277 278 279
    
    m_ReaderStat->SetFileName(this->GetParameterString("instats"));
    ClassCountMapType classCount = m_ReaderStat->GetStatisticMapByName<ClassCountMapType>("samplesPerClass");
    m_RateCalculator->SetClassCount(classCount);
    
280
    switch (this->GetParameterInt("strategy"))
281
      {
282 283 284
      // byclass
      case 0:
        {
285
        otbAppLogINFO("Sampling strategy : set number of samples for each class");
286
        ClassCountMapType requiredCount = 
287
          otb::SamplingRateCalculator::ReadRequiredSamples(this->GetParameterString("strategy.byclass.in"));
288 289 290 291 292
        m_RateCalculator->SetNbOfSamplesByClass(requiredCount);
        }
      break;
      // constant
      case 1:
293 294
        {
        otbAppLogINFO("Sampling strategy : set a constant number of samples for all classes");
295
        m_RateCalculator->SetNbOfSamplesAllClasses(GetParameterInt("strategy.constant.nb"));
296
        }
297
      break;
298
      // percent
299
      case 2:
300
      {
301
      otbAppLogINFO("Sampling strategy: set a percentage of samples for each class.");
302 303 304
      m_RateCalculator->SetPercentageOfSamples(this->GetParameterFloat("strategy.percent.p"));
      }
      break;
305
      // total
306
      case 3:
307 308 309 310 311 312 313 314
      {
      otbAppLogINFO("Sampling strategy: set the total number of samples to generate, use classes proportions.");
      m_RateCalculator->SetTotalNumberOfSamples(this->GetParameterInt("strategy.total.v"));
      }
      break;

      // smallest class
      case 4:
315 316
        {
        otbAppLogINFO("Sampling strategy : fit the number of samples based on the smallest class");
317
        m_RateCalculator->SetMinimumNbOfSamplesByClass();
318 319 320
        }
      break;
      // all samples
321
      case 5:
322 323 324 325
        {
        otbAppLogINFO("Sampling strategy : take all samples");
        m_RateCalculator->SetAllSamples();
        }
326 327
      break;
      default:
328
        otbAppLogFATAL("Strategy mode unknown :"<<this->GetParameterString("strategy"));
329
      break;
330
      }
331 332
      
    if (IsParameterEnabled("outrates") && HasValue("outrates"))
333
      {
334
      m_RateCalculator->Write(this->GetParameterString("outrates"));
335
      }
336 337
    
    MapRateType rates = m_RateCalculator->GetRatesByClass();
338 339 340
    std::ostringstream oss;
    oss << " className  requiredSamples  totalSamples  rate" << std::endl;
    MapRateType::const_iterator itRates = rates.begin();
341
    unsigned int overflowCount = 0;
342 343 344
    for(; itRates != rates.end(); ++itRates)
      {
      otb::SamplingRateCalculator::TripletType tpt = itRates->second;
345 346 347 348 349 350 351
      oss << itRates->first << "\t" << tpt.Required << "\t" << tpt.Tot << "\t" << tpt.Rate;
      if (tpt.Required > tpt.Tot)
        {
        overflowCount++;
        oss << "\t[OVERFLOW]";
        }
      oss << std::endl;
352 353
      }
    otbAppLogINFO("Sampling rates : " << oss.str());
354 355 356 357 358
    if (overflowCount)
      {
      std::string plural(overflowCount>1?"s":"");
      otbAppLogWARNING(<< overflowCount << " case"<<plural<<" of overflow detected! (requested number of samples higher than total available samples)");
      }
359 360 361 362 363

    // Open input geometries
    otb::ogr::DataSource::Pointer vectors =
      otb::ogr::DataSource::New(this->GetParameterString("vec"));

364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387
    // Reproject geometries
    FloatVectorImageType::Pointer inputImg = this->GetParameterImage("in");
    std::string imageProjectionRef = inputImg->GetProjectionRef();
    FloatVectorImageType::ImageKeywordlistType imageKwl =
      inputImg->GetImageKeywordlist();
    std::string vectorProjectionRef =
      vectors->GetLayer(GetParameterInt("layer")).GetProjectionRef();

    otb::ogr::DataSource::Pointer reprojVector = vectors;
    GeometriesType::Pointer inputGeomSet;
    ProjectionFilterType::Pointer geometriesProjFilter;
    GeometriesType::Pointer outputGeomSet;
    bool doReproj = true;
    // don't reproject for these cases
    if (vectorProjectionRef.empty() ||
        (imageProjectionRef == vectorProjectionRef) ||
        (imageProjectionRef.empty() && imageKwl.GetSize() == 0))
      doReproj = false;
  
    if (doReproj)
      {
      inputGeomSet = GeometriesType::New(vectors);
      reprojVector = otb::ogr::DataSource::New();
      outputGeomSet = GeometriesType::New(reprojVector);
388
      // Filter instantiation
389 390 391 392 393 394 395 396 397 398 399 400
      geometriesProjFilter = ProjectionFilterType::New();
      geometriesProjFilter->SetInput(inputGeomSet);
      if (imageProjectionRef.empty())
        {
        geometriesProjFilter->SetOutputKeywordList(inputImg->GetImageKeywordlist()); // nec qd capteur
        }
      geometriesProjFilter->SetOutputProjectionRef(imageProjectionRef);
      geometriesProjFilter->SetOutput(outputGeomSet);
      otbAppLogINFO("Reprojecting input vectors...");
      geometriesProjFilter->Update();
      }

401 402 403 404 405
    // Create output dataset for sample positions
    otb::ogr::DataSource::Pointer outputSamples =
      otb::ogr::DataSource::New(this->GetParameterString("out"),otb::ogr::DataSource::Modes::Overwrite);
    
    switch (this->GetParameterInt("sampler"))
406
      {
407 408 409
      // periodic
      case 0:
        {
410 411 412 413
        PeriodicSamplerType::SamplerParameterType param;
        param.Offset = 0;
        param.MaxJitter = this->GetParameterInt("sampler.periodic.jitter");

414
        m_Periodic->SetInput(this->GetParameterImage("in"));
415
        m_Periodic->SetOGRData(reprojVector);
416
        m_Periodic->SetOutputPositionContainerAndRates(outputSamples, rates);
417
        m_Periodic->SetFieldName(fieldName);
418
        m_Periodic->SetLayerIndex(this->GetParameterInt("layer"));
419
        m_Periodic->SetSamplerParameters(param);
420 421 422 423
        if (IsParameterEnabled("mask") && HasValue("mask"))
          {
          m_Periodic->SetMask(this->GetParameterImage<UInt8ImageType>("mask"));
          }
424
        m_Periodic->GetStreamer()->SetAutomaticTiledStreaming(this->GetParameterInt("ram"));
425
        AddProcess(m_Periodic->GetStreamer(),"Selecting positions with periodic sampler...");
426 427 428 429
        m_Periodic->Update();
        }
      break;
      // random
430
      case 1:
431 432
        {
        m_Random->SetInput(this->GetParameterImage("in"));
433
        m_Random->SetOGRData(reprojVector);
434
        m_Random->SetOutputPositionContainerAndRates(outputSamples, rates);
435
        m_Random->SetFieldName(fieldName);
436 437 438 439 440
        m_Random->SetLayerIndex(this->GetParameterInt("layer"));
        if (IsParameterEnabled("mask") && HasValue("mask"))
          {
          m_Random->SetMask(this->GetParameterImage<UInt8ImageType>("mask"));
          }
441
        m_Random->GetStreamer()->SetAutomaticTiledStreaming(this->GetParameterInt("ram"));
442
        AddProcess(m_Random->GetStreamer(),"Selecting positions with random sampler...");
443 444 445 446 447 448
        m_Random->Update();
        }
      break;
      default:
        otbAppLogFATAL("Sampler type unknown :"<<this->GetParameterString("sampler"));
      break;
449 450 451
      }
  }

452 453 454 455 456 457
  /** Computes the per-class sampling rates from the class counts (see DoExecute). */
  RateCalculatorType::Pointer m_RateCalculator;
  
  /** Sample position filter used when the "sampler" choice is periodic. */
  PeriodicSamplerType::Pointer m_Periodic;
  /** Sample position filter used when the "sampler" choice is random. */
  RandomSamplerType::Pointer m_Random;
  
  /** Reader for the statistics XML file given in "instats". */
  XMLReaderType::Pointer m_ReaderStat;
458 459 460 461 462
};

} // end of namespace Wrapper
} // end of namespace otb

463
// Export macro for the application class; presumably provided by
// otbWrapperApplicationFactory.h to make SampleSelection loadable by the
// OTB application framework - confirm against the macro definition.
OTB_APPLICATION_EXPORT(otb::Wrapper::SampleSelection)