// otbSampleAugmentation.cxx
/*
 * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
 *
 * This file is part of Orfeo Toolbox
 *
 *     https://www.orfeo-toolbox.org/
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"
#include "otbOGRDataSourceWrapper.h"
#include "otbSampleAugmentationFilter.h"

#include <cctype>
#include <ctime>
#include <string>
#include <vector>

namespace otb
{
namespace Wrapper
{

class SampleAugmentation : public Application
{
public:
  /** Standard class typedefs. */
  typedef SampleAugmentation              Self;
  typedef Application                   Superclass;
  typedef itk::SmartPointer<Self>       Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;

  /** Standard macro */
  itkNewMacro(Self);

  itkTypeMacro(SampleAugmentation, otb::Application);

  /** Filters typedef */
47 48 49
  using FilterType = otb::SampleAugmentationFilter;
  using SampleType = FilterType::SampleType;
  using SampleVectorType = FilterType::SampleVectorType;
50 51 52 53 54 55 56 57 58 59

private:
  SampleAugmentation() {}

  void DoInit()
  {
    SetName("SampleAugmentation");
    SetDescription("Generates synthetic samples from a sample data file.");

    // Documentation
Jordi Inglada's avatar
Jordi Inglada committed
60
    SetDocName("Sample Augmentation");
61 62 63 64 65 66 67 68 69 70
    SetDocLongDescription("The application takes a sample data file as "
                          "generated by the SampleExtraction application and "
                          "generates synthetic samples to increase the number of "
                          "available samples.");
    SetDocLimitations("None");
    SetDocAuthors("OTB-Team");
    SetDocSeeAlso(" ");

    AddDocTag(Tags::Learning);

71 72
    AddParameter(ParameterType_InputFilename, "in", "Input samples");
    SetParameterDescription("in","Vector data file containing samples (OGR format)");
73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97

    AddParameter(ParameterType_OutputFilename, "out", "Output samples");
    SetParameterDescription("out","Output vector data file storing new samples"
                            "(OGR format). If not given, the input vector data file is updated");
    MandatoryOff("out");

    AddParameter(ParameterType_ListView, "field", "Field Name");
    SetParameterDescription("field","Name of the field carrying the class name in the input vectors.");
    SetListViewSingleSelectionMode("field",true);
    
    AddParameter(ParameterType_Int, "layer", "Layer Index");
    SetParameterDescription("layer", "Layer index to read in the input vector file.");
    MandatoryOff("layer");
    SetDefaultParameterInt("layer",0);

    AddParameter(ParameterType_Int, "label", "Label of the class to be augmented");
    SetParameterDescription("label", "Label of the class of the input file for which "
                            "new samples will be generated.");
    SetDefaultParameterInt("label",1);

    AddParameter(ParameterType_Int, "samples", "Number of generated samples");
    SetParameterDescription("samples", "Number of synthetic samples that will "
                            "be generated.");
    SetDefaultParameterInt("samples",100);

98 99 100 101
    AddParameter(ParameterType_ListView, "exclude", "Field names for excluded features.");
    SetParameterDescription("exclude",
                            "List of field names in the input vector data that will not be generated in the output file.");

102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
    AddParameter(ParameterType_Choice, "strategy", "Augmentation strategy");

    AddChoice("strategy.replicate","Replicate input samples");
    SetParameterDescription("strategy.replicate","The new samples are generated "
                            "by replicating input samples which are randomly "
                            "selected with replacement.");

    AddChoice("strategy.jitter","Jitter input samples");
    SetParameterDescription("strategy.jitter","The new samples are generated "
                            "by adding gaussian noise to input samples which are "
                            "randomly selected with replacement.");
    AddParameter(ParameterType_Float, "strategy.jitter.stdfactor", 
                 "Factor for dividing the standard deviation of each feature");
    SetParameterDescription("strategy.jitter.stdfactor", 
                            "The noise added to the input samples will have the "
                            "standard deviation of the input features divided "
                            "by the value of this parameter. ");
119
    SetDefaultParameterFloat("strategy.jitter.stdfactor",10);
120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137

    AddChoice("strategy.smote","Smote input samples");
    SetParameterDescription("strategy.smote","The new samples are generated "
                            "by using the SMOTE algorithm (http://dx.doi.org/10.1613/jair.953) "
                            "on input samples which are "
                            "randomly selected with replacement.");
    AddParameter(ParameterType_Int, "strategy.smote.neighbors", 
                 "Number of nearest neighbors.");
    SetParameterDescription("strategy.smote.neighbors", 
                            "Number of nearest neighbors to be used in the "
                            "SMOTE algorithm");
    SetDefaultParameterFloat("strategy.smote.neighbors", 5);

    AddParameter(ParameterType_Int, "seed", 
                 "Random seed.");
    SetParameterDescription("seed", 
                            "Seed for the random number generator.");
    MandatoryOff("seed");
138 139

    // Doc example parameter settings
140
    SetDocExampleParameterValue("in", "samples.sqlite");
141 142 143 144
    SetDocExampleParameterValue("field", "class");
    SetDocExampleParameterValue("label", "3");
    SetDocExampleParameterValue("samples", "100");
    SetDocExampleParameterValue("out","augmented_samples.sqlite");
145
    SetDocExampleParameterValue( "exclude", "OGC_FID name class originfid" );
146 147
    SetDocExampleParameterValue("strategy", "smote");
    SetDocExampleParameterValue("strategy.smote.neighbors", "5");
148 149 150 151 152 153

    SetOfficialDocLink();
  }

  void DoUpdateParameters()
  {
154
    if ( HasValue("in") )
155
      {
156
      std::string vectorFile = GetParameterString("in");
157 158 159 160 161
      ogr::DataSource::Pointer ogrDS =
        ogr::DataSource::New(vectorFile, ogr::DataSource::Modes::Read);
      ogr::Layer layer = ogrDS->GetLayer(this->GetParameterInt("layer"));
      ogr::Feature feature = layer.ogr().GetNextFeature();

162
      ClearChoices("exclude");
163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
      ClearChoices("field");
      
      for(int iField=0; iField<feature.ogr().GetFieldCount(); iField++)
        {
        std::string key, item = feature.ogr().GetFieldDefnRef(iField)->GetNameRef();
        key = item;
        std::string::iterator end = std::remove_if(key.begin(),key.end(),
                                                   [](auto c){return !std::isalnum(c);});
        std::transform(key.begin(), end, key.begin(), tolower);
        
        OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(iField)->GetType();
        
        if(fieldType == OFTString || fieldType == OFTInteger || ogr::version_proxy::IsOFTInteger64(fieldType))
          {
          std::string tmpKey="field."+key.substr(0, end - key.begin());
          AddChoice(tmpKey,item);
          }
180 181 182 183 184
        if( fieldType == OFTInteger || ogr::version_proxy::IsOFTInteger64( fieldType ) || fieldType == OFTReal )
          {
          std::string tmpKey = "exclude." + key.substr( 0, static_cast<unsigned long>( end - key.begin() ) );
          AddChoice( tmpKey, item );
          }
185 186 187 188 189 190 191 192 193 194
        }
      }
  }

  void DoExecute()
    {
    ogr::DataSource::Pointer vectors;
    ogr::DataSource::Pointer output;
    if (IsParameterEnabled("out") && HasValue("out"))
      {
195
      vectors = ogr::DataSource::New(this->GetParameterString("in"));
196 197 198 199 200 201
      output = ogr::DataSource::New(this->GetParameterString("out"),
                                    ogr::DataSource::Modes::Overwrite);
      }
    else
      {
      // Update mode
202 203
      vectors = ogr::DataSource::New(this->GetParameterString("in"),
                                     ogr::DataSource::Modes::Update_LayerUpdate);
204 205 206 207 208 209 210 211 212 213 214 215 216 217
      output = vectors;
      }

    // Retrieve the field name
    std::vector<int> selectedCFieldIdx = GetSelectedItems("field");

    if(selectedCFieldIdx.empty())
      {
      otbAppLogFATAL(<<"No field has been selected for data labelling!");
      }

  std::vector<std::string> cFieldNames = GetChoiceNames("field");  
  std::string fieldName = cFieldNames[selectedCFieldIdx.front()];
    
218 219 220 221
  std::vector<std::string> excludedFields = 
    GetExcludedFields( GetChoiceNames( "exclude" ), 
                       GetSelectedItems( "exclude" ));
  for(const auto& ef : excludedFields)
222
    otbAppLogINFO("Excluding feature " << ef << '\n');
223

224 225
  int seed = std::time(nullptr);
  if(IsParameterEnabled("seed")) seed = this->GetParameterInt("seed");
226 227 228 229 230 231 232 233 234


  FilterType::Pointer filter = FilterType::New();
  filter->SetInput(vectors);
  filter->SetLayer(this->GetParameterInt("layer"));
  filter->SetNumberOfSamples(this->GetParameterInt("samples"));
  filter->SetOutputSamples(output);
  filter->SetClassFieldName(fieldName);
  filter->SetLabel(this->GetParameterInt("label"));
235
  filter->SetExcludedFields(excludedFields);
236
  filter->SetSeed(seed);
237 238 239 240 241 242
  switch (this->GetParameterInt("strategy"))
    {
    // replicate
    case 0:
    {
    otbAppLogINFO("Augmentation strategy : replicate");
243
    filter->SetStrategy(FilterType::Strategy::Replicate);
244
    }
245
      break;
246 247 248 249
    // jitter
    case 1:
    {
    otbAppLogINFO("Augmentation strategy : jitter");
250
    filter->SetStrategy(FilterType::Strategy::Jitter);
251
    filter->SetStdFactor(this->GetParameterFloat("strategy.jitter.stdfactor"));
252 253 254 255 256
    }
    break;
    case 2:
    {
    otbAppLogINFO("Augmentation strategy : smote");
257
    filter->SetStrategy(FilterType::Strategy::Smote);
258
    filter->SetSmoteNeighbors(this->GetParameterInt("strategy.smote.neighbors"));
259 260 261
    }
    break;
    }
262
  filter->Update();
263 264 265
  output->SyncToDisk();
    }

266

267 268
  std::vector<std::string> GetExcludedFields(const std::vector<std::string>& fieldNames,
                                             const std::vector<int>& selectedIdx)
269 270 271 272 273 274 275 276 277
  {
    auto nbFeatures = static_cast<unsigned int>(selectedIdx.size());
    std::vector<std::string> result( nbFeatures );
    for( unsigned int i = 0; i < nbFeatures; ++i )
      {
      result[i] = fieldNames[selectedIdx[i]];
      }
    return result;
  }
278 279

};

} // end of namespace Wrapper
} // end of namespace otb

OTB_APPLICATION_EXPORT(otb::Wrapper::SampleAugmentation)