otbVectorDimensionalityReduction.cxx 17.8 KB
Newer Older
1
/*
2
 * Copyright (C) 2005-2020 Centre National d'Etudes Spatiales (CNES)
Guillaume Pasero's avatar
Guillaume Pasero committed
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
 *
 * This file is part of Orfeo Toolbox
 *
 *     https://www.orfeo-toolbox.org/
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
20
21
22
23
24
25
26
27
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"
#include "otbOGRDataSourceWrapper.h"
#include "otbOGRFeatureWrapper.h"
#include "itkVariableLengthVector.h"
#include "otbStatisticsXMLFileReader.h"
#include "itkListSample.h"
#include "otbShiftScaleSampleListFilter.h"
#include "otbDimensionalityReductionModelFactory.h"

#include <algorithm>
#include <cctype>
#include <time.h>
30

31
32
33
34
namespace otb
{
namespace Wrapper
{
35
36
37
38
39
/**
 * \class VectorDimensionalityReduction
 *
 * Apply a dimensionality reduction model on a vector file
 */
40
class VectorDimensionalityReduction : public Application
41
{
42
public:
43
  /** Standard class typedefs. */
44
  typedef VectorDimensionalityReduction Self;
45
46
  typedef Application                   Superclass;
  typedef itk::SmartPointer<Self>       Pointer;
47
48
49
50
51
52
  typedef itk::SmartPointer<const Self> ConstPointer;

  /** Standard macro */
  itkNewMacro(Self);
  itkTypeMacro(Self, Application)

53
54
55
56
57
58
59
      /** Filters typedef */
      typedef float ValueType;
  typedef itk::VariableLengthVector<ValueType>         InputSampleType;
  typedef itk::Statistics::ListSample<InputSampleType> ListSampleType;
  typedef MachineLearningModel<itk::VariableLengthVector<ValueType>, itk::VariableLengthVector<ValueType>> DimensionalityReductionModelType;
  typedef DimensionalityReductionModelFactory<ValueType, ValueType>                                        DimensionalityReductionModelFactoryType;
  typedef DimensionalityReductionModelType::Pointer ModelPointerType;
60
61

  /** Statistics Filters typedef */
62
63
64
  typedef itk::VariableLengthVector<ValueType>          MeasurementType;
  typedef otb::StatisticsXMLFileReader<MeasurementType> StatisticsReader;
  typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType;
65
66

protected:
67
  /** Destructor: asks the dimensionality reduction model factory to clean its registered factories. */
  ~VectorDimensionalityReduction() override
  {
    DimensionalityReductionModelFactoryType::CleanFactories();
  }
71

Victor Poughon's avatar
Victor Poughon committed
72
private:
73
  void DoInit() override
74
  {
75
    SetName("VectorDimensionalityReduction");
76
77
78
    SetDescription(
        "Performs dimensionality reduction of the input vector data "
        "according to a model file.");
79
    SetDocAuthors("OTB-Team");
80
81
82
83
    SetDocLongDescription(
        "This application performs a vector data "
        "dimensionality reduction based on a model file produced by the "
        "TrainDimensionalityReduction application.");
84
    SetDocSeeAlso("TrainDimensionalityReduction");
Guillaume Pasero's avatar
Guillaume Pasero committed
85
    SetDocLimitations("None");
86
87
88
    AddDocTag(Tags::Learning);

    AddParameter(ParameterType_InputVectorData, "in", "Name of the input vector data");
89
    SetParameterDescription("in", "The input vector data to reduce.");
90
91

    AddParameter(ParameterType_InputFilename, "instat", "Statistics file");
92
93
94
95
    SetParameterDescription("instat",
                            "An XML file containing mean and standard "
                            "deviation to center and reduce samples before dimensionality reduction "
                            "(produced by ComputeImagesStatistics application).");
96
97
98
    MandatoryOff("instat");

    AddParameter(ParameterType_InputFilename, "model", "Model file");
99
100
101
102
103
104
105
106
107
108
109
    SetParameterDescription("model",
                            "A model file (produced by the "
                            "TrainDimensionalityReduction application,");

    AddParameter(ParameterType_OutputFilename, "out",
                 "Output vector data file "
                 "containing the reduced vector");
    SetParameterDescription("out",
                            "Output vector data file storing sample "
                            "values (OGR format). If not given, the input vector data file is used. "
                            "In overwrite mode, the original features will be lost.");
110
    MandatoryOff("out");
111

112
    AddParameter(ParameterType_Field, "feat", "Input features to use for reduction");
Victor Poughon's avatar
Victor Poughon committed
113
    SetParameterDescription("feat", "List of field names in the input vector data used as features for reduction.");
114
    SetVectorData("feat", "in");
115

Victor Poughon's avatar
Victor Poughon committed
116
    AddParameter(ParameterType_Choice, "featout", "Output feature");
117
118
119
120
121
122
    SetParameterDescription("featout", "Naming of output features");

    AddChoice("featout.prefix", "Prefix");
    SetParameterDescription("featout.prefix", "Use a name prefix");

    AddParameter(ParameterType_String, "featout.prefix.name", "Feature name prefix");
123
124
125
126
    SetParameterDescription("featout.prefix.name",
                            "Name prefix for output "
                            "features. This prefix is followed by the numeric index of each output feature.");
    SetParameterString("featout.prefix.name", "reduced_", false);
127

128
    AddChoice("featout.list", "List");
129
130
    SetParameterDescription("featout.list", "Use a list with all names");

131
    AddParameter(ParameterType_Field, "featout.list.names", "Feature name list");
Victor Poughon's avatar
Victor Poughon committed
132
133
134
    SetParameterDescription("featout.list.names",
                            "List of field names for the output "
                            "features which result from the reduction.");
135
    SetVectorData("featout.list.names", "in");
Victor Poughon's avatar
Victor Poughon committed
136
137
138
139
140
141

    AddParameter(ParameterType_Int, "pcadim", "Principal component dimension");
    SetParameterDescription("pcadim",
                            "This optional parameter can be set to "
                            "reduce the number of eignevectors used in the PCA model file. This "
                            "parameter can't be used for other models");
142
    MandatoryOff("pcadim");
Victor Poughon's avatar
Victor Poughon committed
143
144

    AddParameter(ParameterType_Choice, "mode", "Writing mode");
145
146
147
148
    SetParameterDescription("mode",
                            "This parameter determines if the output "
                            "file is overwritten or updated [overwrite/update]. If an output file "
                            "name is given, the original file is copied before creating the new features.");
149
150

    AddChoice("mode.overwrite", "Overwrite");
Victor Poughon's avatar
Victor Poughon committed
151
    SetParameterDescription("mode.overwrite", "Overwrite mode");
152
153
154

    AddChoice("mode.update", "Update");
    SetParameterDescription("mode.update", "Update mode");
155
156
157
158
159
160
161

    // Doc example parameter settings
    SetDocExampleParameterValue("in", "vectorData.shp");
    SetDocExampleParameterValue("instat", "meanVar.xml");
    SetDocExampleParameterValue("model", "model.txt");
    SetDocExampleParameterValue("out", "vectorDataOut.shp");
    SetDocExampleParameterValue("feat", "perimeter area width");
Victor Poughon's avatar
Victor Poughon committed
162
    SetOfficialDocLink();
163
  }
164

165
  void DoUpdateParameters() override
166
167
  {
    if (HasValue("in"))
168
    {
169
      std::string                   shapefile = GetParameterString("in");
170
      otb::ogr::DataSource::Pointer ogrDS;
171
172
173
174
175
      OGRSpatialReference           oSRS("");
      std::vector<std::string>      options;
      ogrDS                     = otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Read);
      otb::ogr::Layer layer     = ogrDS->GetLayer(0);
      OGRFeatureDefn& layerDefn = layer.GetLayerDefn();
176
177
      ClearChoices("feat");

178
179
180
181
182
183
184
185
      for (int iField = 0; iField < layerDefn.GetFieldCount(); iField++)
      {
        std::string           item = layerDefn.GetFieldDefn(iField)->GetNameRef();
        std::string           key(item);
        std::string::iterator end = std::remove_if(key.begin(), key.end(), [](char c) { return !std::isalnum(c); });
        std::transform(key.begin(), end, key.begin(), tolower);
        std::string tmpKey = "feat." + key.substr(0, static_cast<unsigned long>(end - key.begin()));
        AddChoice(tmpKey, item);
186
187
      }
    }
188
  }
189

190
  void DoExecute() override
191
  {
192
193
    clock_t tic = clock();

194
195
196
197
198
199
    std::string                   shapefile    = GetParameterString("in");
    otb::ogr::DataSource::Pointer source       = otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Read);
    otb::ogr::Layer               layer        = source->GetLayer(0);
    ListSampleType::Pointer       input        = ListSampleType::New();
    std::vector<int>              inputIndexes = GetSelectedItems("feat");
    int                           nbFeatures   = inputIndexes.size();
200
201

    input->SetMeasurementVectorSize(nbFeatures);
202
    otb::ogr::Layer::const_iterator it    = layer.cbegin();
203
204
    otb::ogr::Layer::const_iterator itEnd = layer.cend();

205
206
207
    // Get the list of non-selected field indexes
    // /!\ The 'feat' is assumed to expose all available fields, hence the
    // mapping between GetSelectedItems() and OGR field indexes
208
209
210
    OGRFeatureDefn& inLayerDefn = layer.GetLayerDefn();
    std::set<int>   otherInputFields;
    for (int i = 0; i < inLayerDefn.GetFieldCount(); i++)
211
      otherInputFields.insert(i);
212
    for (int k = 0; k < nbFeatures; k++)
213
214
      otherInputFields.erase(inputIndexes[k]);

215
216
    for (; it != itEnd; ++it)
    {
217
218
      MeasurementType mv;
      mv.SetSize(nbFeatures);
219
220
221

      for (int idx = 0; idx < nbFeatures; ++idx)
      {
222
223
        switch ((*it)[inputIndexes[idx]].GetType())
        {
224
225
226
227
228
229
230
231
232
233
234
        case OFTInteger:
          mv[idx] = static_cast<ValueType>((*it)[inputIndexes[idx]].GetValue<int>());
          break;
        case OFTInteger64:
          mv[idx] = static_cast<ValueType>((*it)[inputIndexes[idx]].GetValue<int>());
          break;
        case OFTReal:
          mv[idx] = static_cast<ValueType>((*it)[inputIndexes[idx]].GetValue<double>());
          break;
        default:
          itkExceptionMacro(<< "incorrect field type: " << (*it)[inputIndexes[idx]].GetType() << ".");
235
236
        }
      }
237
238
      input->PushBack(mv);
    }
239

240
    /** Statistics for shift/scale */
241
242
243
244
    MeasurementType meanMeasurementVector;
    MeasurementType stddevMeasurementVector;

    if (HasValue("instat") && IsParameterEnabled("instat"))
245
    {
246
      StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
247
      std::string               XMLfile          = GetParameterString("instat");
248
      statisticsReader->SetFileName(XMLfile);
249
      meanMeasurementVector   = statisticsReader->GetStatisticVectorByName("mean");
250
      stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
251
252
      otbAppLogINFO("Mean used: " << meanMeasurementVector);
      otbAppLogINFO("Standard deviation used: " << stddevMeasurementVector);
253
    }
254
    else
255
    {
256
257
258
259
      meanMeasurementVector.SetSize(nbFeatures);
      meanMeasurementVector.Fill(0.);
      stddevMeasurementVector.SetSize(nbFeatures);
      stddevMeasurementVector.Fill(1.);
260
    }
261
262
263
264
265
266
267

    ShiftScaleFilterType::Pointer trainingShiftScaleFilter = ShiftScaleFilterType::New();
    trainingShiftScaleFilter->SetInput(input);
    trainingShiftScaleFilter->SetShifts(meanMeasurementVector);
    trainingShiftScaleFilter->SetScales(stddevMeasurementVector);
    trainingShiftScaleFilter->Update();

268
    otbAppLogINFO("Loading model");
269
    /** Read the model */
270
271
    m_Model = DimensionalityReductionModelFactoryType::CreateDimensionalityReductionModel(GetParameterString("model"),
                                                                                          DimensionalityReductionModelFactoryType::ReadMode);
272
    if (m_Model.IsNull())
273
274
275
    {
      otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type");
    }
276
    m_Model->Load(GetParameterString("model"));
277
    if (HasValue("pcadim") && IsParameterEnabled("pcadim"))
278
    {
279
280
      std::string modelName(m_Model->GetNameOfClass());
      if (modelName != "PCAModel")
281
282
      {
        otbAppLogFATAL(<< "Can't set 'pcadim' on a model : " << modelName);
283
      }
284
285
286
      m_Model->SetDimension(GetParameterInt("pcadim"));
    }
    otbAppLogINFO("Model loaded, dimension : " << m_Model->GetDimension());
287

288
    /** Perform Dimensionality Reduction */
289
    ListSampleType::Pointer listSample = trainingShiftScaleFilter->GetOutput();
290
    ListSampleType::Pointer target     = m_Model->PredictBatch(listSample);
291

292
    /** Create/Update Output Shape file */
293
    ogr::DataSource::Pointer output;
294
295
    ogr::DataSource::Pointer buffer     = ogr::DataSource::New();
    bool                     updateMode = false;
296
297

    if (IsParameterEnabled("out") && HasValue("out"))
298
    {
299
      // Create new OGRDataSource
300
301
302
303
      if (GetParameterString("mode") == "overwrite")
      {
        output                   = ogr::DataSource::New(GetParameterString("out"), ogr::DataSource::Modes::Overwrite);
        otb::ogr::Layer newLayer = output->CreateLayer(GetParameterString("out"), const_cast<OGRSpatialReference*>(layer.GetSpatialRef()), layer.GetGeomType());
304
305
        // Copy existing fields except the ones selected for reduction
        for (const int& k : otherInputFields)
306
        {
307
308
309
          OGRFieldDefn fieldDefn(inLayerDefn.GetFieldDefn(k));
          newLayer.CreateField(fieldDefn);
        }
310
311
312
313
      }
      else if (GetParameterString("mode") == "update")
      {
        // output = ogr::DataSource::New(GetParameterString("out"), ogr::DataSource::Modes::Update_LayerCreateOnly);
314
        // Update mode
315
316
317
        otb::ogr::DataSource::Pointer source_output = otb::ogr::DataSource::New(GetParameterString("out"), otb::ogr::DataSource::Modes::Read);
        layer                                       = source_output->GetLayer(0);
        updateMode                                  = true;
318
319
320
321
        otbAppLogINFO("Update input vector data.");

        // fill temporary buffer for the transfer
        otb::ogr::Layer inputLayer = layer;
322
323
        layer                      = buffer->CopyLayer(inputLayer, std::string("Buffer"));
        // close input data source
324
325
        source_output->Clear();
        // Re-open input data source in update mode
326
327
        output = otb::ogr::DataSource::New(GetParameterString("out"), otb::ogr::DataSource::Modes::Update_LayerUpdate);
      }
328
      else
329
330
      {
        otbAppLogFATAL(<< "Error when creating the output file" << GetParameterString("mode") << " : unsupported writing mode type");
331
      }
332
    }
333

334
335
    otb::ogr::Layer outLayer = output->GetLayer(0);
    OGRErr          errStart = outLayer.ogr().StartTransaction();
336
337

    if (errStart != OGRERR_NONE)
338
    {
339
      otbAppLogFATAL(<< "Unable to start transaction for OGR layer " << outLayer.ogr().GetName() << ".");
340
    }
341
342
343

    // Build the list of output fields
    std::vector<std::string> outFields;
344
345
346
    if (GetParameterString("featout") == "prefix")
    {
      std::string        prefix = GetParameterString("featout.prefix.name");
347
      std::ostringstream oss;
348
349
      for (unsigned int i = 0; i < m_Model->GetDimension(); i++)
      {
350
        oss.str(prefix);
351
        oss.seekp(0, std::ios_base::end);
352
353
354
        oss << i;
        outFields.push_back(oss.str());
      }
355
356
357
    }
    else if (GetParameterString("featout") == "list")
    {
358
359
360
      outFields = GetParameterStringList("featout.list.names");
      if (outFields.size() != m_Model->GetDimension())
      {
361
        otbAppLogFATAL(<< "Wrong number of output field names, expected " << m_Model->GetDimension() << " , got " << outFields.size());
362
      }
363
364
365
366
367
    }
    else
    {
      otbAppLogFATAL(<< "Unsupported output feature mode : " << GetParameterString("featout"));
    }
368
369

    // Add the field of prediction in the output layer if field not exist
370
371
372
373
374
    for (unsigned int i = 0; i < outFields.size(); i++)
    {
      OGRFeatureDefn& layerDefn = outLayer.GetLayerDefn();
      int             idx       = layerDefn.GetFieldIndex(outFields[i].c_str());

375
      if (idx >= 0)
376
      {
377
        if (layerDefn.GetFieldDefn(idx)->GetType() != OFTReal)
378
379
          otbAppLogFATAL("Field name " << outFields[i] << " already exists with a different type!");
      }
380
      else
381
382
      {
        OGRFieldDefn   predictedField(outFields[i].c_str(), OFTReal);
383
384
385
        ogr::FieldDefn predictedFieldDef(predictedField);
        outLayer.CreateField(predictedFieldDef);
      }
386
    }
387
388

    // Fill output layer
389
390
391
392
393
    unsigned int count = 0;
    it                 = layer.cbegin();
    itEnd              = layer.cend();
    for (; it != itEnd; ++it, ++count)
    {
394
395
      ogr::Feature dstFeature(outLayer.GetLayerDefn());

396
      dstFeature.SetFrom(*it, TRUE);
397
398
      dstFeature.SetFID(it->GetFID());

399
400
      for (std::size_t i = 0; i < outFields.size(); ++i)
      {
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
        switch (dstFeature[outFields[i]].GetType())
        {
        case OFTInteger:
          dstFeature[outFields[i]].SetValue<int>(target->GetMeasurementVector(count)[0]);
          break;
        case OFTInteger64:
          dstFeature[outFields[i]].SetValue<int>(target->GetMeasurementVector(count)[0]);
          break;
        case OFTReal:
          dstFeature[outFields[i]].SetValue<double>(target->GetMeasurementVector(count)[0]);
          break;
        case OFTString:
          dstFeature[outFields[i]].SetValue<std::string>(std::to_string(target->GetMeasurementVector(count)[0]));
          break;
        default:
          itkExceptionMacro(<< "incorrect field type: " << dstFeature[outFields[i]].GetType() << ".");
        }
418
      }
419
      if (updateMode)
420
      {
421
        outLayer.SetFeature(dstFeature);
422
      }
423
      else
424
      {
425
426
        outLayer.CreateFeature(dstFeature);
      }
427
    }
428

429
430
    if (outLayer.ogr().TestCapability("Transactions"))
    {
431
432
      const OGRErr errCommitX = outLayer.ogr().CommitTransaction();
      if (errCommitX != OGRERR_NONE)
433
434
      {
        otbAppLogFATAL(<< "Unable to commit transaction for OGR layer " << outLayer.ogr().GetName() << ".");
435
      }
436
    }
437
438
    output->SyncToDisk();
    clock_t toc = clock();
439
440
    otbAppLogINFO("Elapsed: " << ((double)(toc - tic) / CLOCKS_PER_SEC) << " seconds.");
  }
441
442

  /** Dimensionality reduction model; created and loaded from the 'model' file in DoExecute(). */
  ModelPointerType m_Model;
443
};
444
445
446
447

} // end of namespace Wrapper
} // end of namespace otb

448
// Application export macro for VectorDimensionalityReduction
// (presumably registers the class with the OTB application factory — macro defined elsewhere).
OTB_APPLICATION_EXPORT(otb::Wrapper::VectorDimensionalityReduction)