From d4a174c0b59ac06f076c79fa7799e6888a940362 Mon Sep 17 00:00:00 2001
From: Jordi Inglada <jordi.inglada@cesbio.cnes.fr>
Date: Wed, 28 Feb 2018 18:29:59 +0100
Subject: [PATCH] ENH: implement sample augmentation as a filter

---
 .../app/otbSampleAugmentation.cxx             | 181 ++----------
 .../include/otbSampleAugmentation.h           |   4 +-
 .../include/otbSampleAugmentationFilter.h     | 168 +++++++++++
 .../include/otbSampleAugmentationFilter.txx   | 268 ++++++++++++++++++
 4 files changed, 463 insertions(+), 158 deletions(-)
 create mode 100644 Modules/Applications/AppClassification/include/otbSampleAugmentationFilter.h
 create mode 100644 Modules/Applications/AppClassification/include/otbSampleAugmentationFilter.txx

diff --git a/Modules/Applications/AppClassification/app/otbSampleAugmentation.cxx b/Modules/Applications/AppClassification/app/otbSampleAugmentation.cxx
index a12e3912b1..eed67c8fb8 100644
--- a/Modules/Applications/AppClassification/app/otbSampleAugmentation.cxx
+++ b/Modules/Applications/AppClassification/app/otbSampleAugmentation.cxx
@@ -21,7 +21,7 @@
 #include "otbWrapperApplication.h"
 #include "otbWrapperApplicationFactory.h"
 #include "otbOGRDataSourceWrapper.h"
-#include "otbSampleAugmentation.h"
+#include "otbSampleAugmentationFilter.h"
 
 namespace otb
 {
@@ -44,9 +44,9 @@ public:
   itkTypeMacro(SampleAugmentation, otb::Application);
 
   /** Filters typedef */
-  using SampleType = sampleAugmentation::SampleType;
-  using SampleVectorType = sampleAugmentation::SampleVectorType;
-
+  using FilterType = otb::SampleAugmentationFilter;
+  using SampleType = FilterType::SampleType;
+  using SampleVectorType = FilterType::SampleVectorType;
 
 private:
   SampleAugmentation() {}
@@ -220,143 +220,49 @@ private:
                          GetSelectedItems( "exclude" ));
   for(const auto& ef : excludedFeatures)
     otbAppLogINFO("Excluding feature " << ef << '\n');
-  auto inSamples = extractSamples(vectors, this->GetParameterInt("layer"),
-                                  fieldName,
-                                  this->GetParameterInt("label"),
-                                  excludedFeatures);
+
   int seed = std::time(nullptr);
   if(IsParameterEnabled("seed")) seed = this->GetParameterInt("seed");
-  SampleVectorType newSamples;
+
+
+  FilterType::Pointer filter = FilterType::New();
+  filter->SetInput(vectors);
+  filter->SetLayer(this->GetParameterInt("layer"));
+  filter->SetNumberOfSamples(this->GetParameterInt("samples"));
+  filter->SetOutputSamples(output);
+  filter->SetClassFieldName(fieldName);
+  filter->SetLabel(this->GetParameterInt("label"));
+  filter->SetExcludedFeatures(excludedFeatures);
+  filter->SetSeed(seed);
   switch (this->GetParameterInt("strategy"))
     {
     // replicate
     case 0:
     {
     otbAppLogINFO("Augmentation strategy : replicate");
-    sampleAugmentation::replicateSamples(inSamples, this->GetParameterInt("samples"),
-                                         newSamples);
+    filter->SetStrategy(FilterType::Strategy::Replicate);
     }
-    break;
+      break;
     // jitter
     case 1:
     {
     otbAppLogINFO("Augmentation strategy : jitter");
-    sampleAugmentation::jitterSamples(inSamples, this->GetParameterInt("samples"),
-                                      newSamples,
-                                      this->GetParameterFloat("strategy.jitter.stdfactor"),
-                                      seed);
+    filter->SetStrategy(FilterType::Strategy::Jitter);
+    filter->SetStdFactor(this->GetParameterFloat("stdfactor"));
     }
     break;
     case 2:
     {
     otbAppLogINFO("Augmentation strategy : smote");
-    sampleAugmentation::smote(inSamples, this->GetParameterInt("samples"),
-                              newSamples,
-                              this->GetParameterInt("strategy.smote.neighbors"),
-                              seed);
+    filter->SetStrategy(FilterType::Strategy::Smote);
+    filter->SetSmoteNeighbors(this->GetParameterInt("neighbors"));
     }
     break;
     }
-  writeSamples(vectors, output, newSamples, this->GetParameterInt("layer"),
-               fieldName,
-               this->GetParameterInt("label"),
-               excludedFeatures);
+  filter->Update();
   output->SyncToDisk();
     }
 
-/** Extracts the samples of a single class from the vector data to a
-* vector and excludes some unwanted features.
-*/
-  SampleVectorType extractSamples(const ogr::DataSource::Pointer vectors, 
-                                  size_t layerName,
-                                  const std::string& classField, const int label,
-                                  const std::vector<std::string>& excludedFeatures = {})
-  {
-    ogr::Layer layer = vectors->GetLayer(layerName);
-    ogr::Feature feature = layer.ogr().GetNextFeature();
-    if(feature.addr() == 0)
-      {
-      otbAppLogFATAL("Layer " << layerName << " of input sample file is empty.\n");
-      }
-    int cFieldIndex = feature.ogr().GetFieldIndex( classField.c_str() );
-    if( cFieldIndex < 0 )
-      {
-      otbAppLogFATAL( "The field name for class label (" << classField
-                      << ") has not been found in the vector file " );
-      }
-
-    auto numberOfFields = feature.ogr().GetFieldCount();
-    auto excludedIds = getExcludedFeaturesIds(excludedFeatures, layer);
-    otbAppLogINFO("The vector file contains " << numberOfFields << " fields.\n");
-    SampleVectorType samples;
-    bool goesOn{feature.addr() != 0};
-    while( goesOn )
-      {
-      // Retrieve all the features for each field in the ogr layer.
-      if(feature.ogr().GetFieldAsInteger(classField.c_str()) == label)
-        {
-
-        SampleType mv;
-        for(auto idx=0; idx<numberOfFields; ++idx)
-          {
-          if(excludedIds.find(idx) == excludedIds.cend() &&
-             isNumericField(feature, idx))
-            mv.push_back(feature.ogr().GetFieldAsDouble(idx));
-          }
-        samples.push_back(mv); 
-        }
-      feature = layer.ogr().GetNextFeature();
-      goesOn = feature.addr() != 0;
-      }
-    return samples;
-  }
-
-  void writeSamples(const ogr::DataSource::Pointer& vectors,
-                    ogr::DataSource::Pointer& output, 
-                    const SampleVectorType& samples,
-                    const size_t layerName,
-                    const std::string& classField, int label,
-                    const std::vector<std::string>& excludedFeatures = {})
-  {
-
-    auto inputLayer = vectors->GetLayer(layerName);
-    auto excludedIds = getExcludedFeaturesIds(excludedFeatures, inputLayer);
-
-    OGRSpatialReference * oSRS = nullptr;
-    if (inputLayer.GetSpatialRef())
-      {
-      oSRS = inputLayer.GetSpatialRef()->Clone();
-      }
-    OGRFeatureDefn &layerDefn = inputLayer.GetLayerDefn();
-
-    auto outputLayer = output->CreateLayer(inputLayer.GetName(), oSRS, 
-                                           inputLayer.GetGeomType());
-    for (int k=0 ; k < layerDefn.GetFieldCount() ; k++)
-      {
-      OGRFieldDefn originDefn(layerDefn.GetFieldDefn(k));
-      ogr::FieldDefn fieldDefn(originDefn);
-      outputLayer.CreateField(fieldDefn);
-      }
-
-    auto featureCount = outputLayer.GetFeatureCount(false);
-    auto templateFeature = selectTemplateFeature(inputLayer, classField, label);
-    for(const auto& sample : samples)
-         {
-         ogr::Feature dstFeature(outputLayer.GetLayerDefn());
-         dstFeature.SetFrom( templateFeature, TRUE );
-         dstFeature.SetFID(++featureCount);
-         auto sampleFieldCounter = 0;
-         for (int k=0 ; k < layerDefn.GetFieldCount() ; k++)
-           {
-           if(excludedIds.find(k) == excludedIds.cend() &&
-              isNumericField(dstFeature, k))
-             {
-             dstFeature.ogr().SetField(k, sample[sampleFieldCounter++]);
-             }
-           }
-         outputLayer.CreateFeature( dstFeature );
-         }
-  }
 
   std::vector<std::string> GetExcludedFeatures(const std::vector<std::string>& fieldNames,
                                                const std::vector<int>& selectedIdx)
@@ -369,45 +275,8 @@ private:
       }
     return result;
   }
-  ogr::Feature selectTemplateFeature(const ogr::Layer& inputLayer, 
-                                     const std::string& classField, int label)
-  {
-    auto featureIt = inputLayer.begin();
-    bool goesOn{(*featureIt).addr() != 0};
-    while( goesOn )
-      {
-      if((*featureIt).ogr().GetFieldAsInteger(classField.c_str()) == label)
-        {
-        return *featureIt;
-        }
-      ++featureIt;
-      }
-    return *(inputLayer.begin());
-  }
-  std::set<size_t> getExcludedFeaturesIds(const std::vector<std::string>& excludedFeatures,
-                                          const ogr::Layer& inputLayer)
-  {
-    auto feature = *(inputLayer).begin();
-    std::set<size_t> excludedIds;
-    if( excludedFeatures.size() != 0)
-      {
-      for(const auto& fieldName : excludedFeatures)
-        {
-        auto idx = feature.ogr().GetFieldIndex( fieldName.c_str() );
-        excludedIds.insert(idx);
-        }
-      }
-    return excludedIds;
-  }
-  bool isNumericField(const ogr::Feature& feature,
-                      const int idx)
-  {
-    OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(idx)->GetType();
-    return (fieldType == OFTInteger 
-            || ogr::version_proxy::IsOFTInteger64( fieldType ) 
-            || fieldType == OFTReal);
-  }
-  };
+
+};
 
 } // end of namespace Wrapper
 } // end of namespace otb
diff --git a/Modules/Applications/AppClassification/include/otbSampleAugmentation.h b/Modules/Applications/AppClassification/include/otbSampleAugmentation.h
index 43fd6657a0..432dbe8a26 100644
--- a/Modules/Applications/AppClassification/include/otbSampleAugmentation.h
+++ b/Modules/Applications/AppClassification/include/otbSampleAugmentation.h
@@ -199,7 +199,7 @@ void smote(const SampleVectorType& inSamples,
       }
 }
 
-}
-}
+}//end namespaces sampleAugmentation
+}//end namespace otb
 
 #endif
diff --git a/Modules/Applications/AppClassification/include/otbSampleAugmentationFilter.h b/Modules/Applications/AppClassification/include/otbSampleAugmentationFilter.h
new file mode 100644
index 0000000000..754e96ef33
--- /dev/null
+++ b/Modules/Applications/AppClassification/include/otbSampleAugmentationFilter.h
@@ -0,0 +1,168 @@
+/*
+ * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
+ *
+ * This file is part of Orfeo Toolbox
+ *
+ *     https://www.orfeo-toolbox.org/
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef otbSampleAugmentationFilter_h
+#define otbSampleAugmentationFilter_h
+
+#include "itkProcessObject.h"
+#include "otbOGRDataSourceWrapper.h"
+#include "otbSampleAugmentation.h"
+
+namespace otb
+{
+
+/** \class SampleAugmentationFilter
+This class 
+ */
+
+class ITK_EXPORT SampleAugmentationFilter :
+    public itk::ProcessObject
+{
+public:
+
+  /** typedef for the classes standards. */
+  typedef SampleAugmentationFilter                 Self;
+  typedef itk::ProcessObject                              Superclass;
+  typedef itk::SmartPointer<Self>                         Pointer;
+  typedef itk::SmartPointer<const Self>                   ConstPointer;
+
+  /** Method for management of the object factory. */
+  itkNewMacro(Self);
+
+  /** Return the name of the class. */
+  itkTypeMacro(SampleAugmentationFilter, ProcessObject);
+
+  typedef ogr::DataSource                            OGRDataSourceType;
+  typedef typename OGRDataSourceType::Pointer        OGRDataSourcePointerType;
+  typedef ogr::Layer                                 OGRLayerType;
+
+  typedef itk::ProcessObject::DataObjectPointerArraySizeType DataObjectPointerArraySizeType;
+
+  using SampleType = sampleAugmentation::SampleType;
+  using SampleVectorType = sampleAugmentation::SampleVectorType;
+
+  enum class Strategy { Replicate, Jitter, Smote };
+
+  /** Set/Get the input OGRDataSource of this process object.  */
+  using Superclass::SetInput;
+  virtual void SetInput(const OGRDataSourceType* ds);
+  const OGRDataSourceType*  GetInput(unsigned int idx);
+
+  virtual void SetOutputSamples(ogr::DataSource* data);
+
+  /** Set the Field Name in which labels will be written. (default is "class")
+   * A field "ClassFieldName" of type integer is created in the output memory layer.
+   */
+  itkSetMacro(ClassFieldName, std::string);
+  /**
+   * Return the Field name in which labels have been written.
+   */
+  itkGetMacro(ClassFieldName, std::string);
+
+
+  itkSetMacro(Layer, size_t);
+  itkGetMacro(Layer, size_t);
+  itkSetMacro(Label, int);
+  itkGetMacro(Label, int);
+  void SetStrategy(Strategy s)
+  {
+    m_Strategy = s;
+  }
+  Strategy GetStrategy() const
+  {
+    return m_Strategy;
+  }
+  itkSetMacro(NumberOfSamples, int);
+  itkGetMacro(NumberOfSamples, int);
+  void SetExcludedFeatures(const std::vector<std::string>& ef)
+  {
+    m_ExcludedFeatures = ef;
+  }
+  std::vector<std::string> GetExcludedFeatures() const
+  {
+    return m_ExcludedFeatures;
+  }
+  itkSetMacro(StdFactor, double);
+  itkGetMacro(StdFactor, double);
+  itkSetMacro(SmoteNeighbors, size_t);
+  itkGetMacro(SmoteNeighbors, size_t);
+  itkSetMacro(Seed, int);
+  itkGetMacro(Seed, int);
+/**
+   * Get the output \c ogr::DataSource which is a "memory" datasource.
+   */
+  const OGRDataSourceType * GetOutput();
+
+protected:
+  SampleAugmentationFilter();
+  ~SampleAugmentationFilter() ITK_OVERRIDE {}
+
+  /** Generate Data method*/
+  void GenerateData() ITK_OVERRIDE;
+
+  /** DataObject pointer */
+  typedef itk::DataObject::Pointer DataObjectPointer;
+
+  DataObjectPointer MakeOutput(DataObjectPointerArraySizeType idx) ITK_OVERRIDE;
+  using Superclass::MakeOutput;
+
+
+  SampleVectorType extractSamples(const ogr::DataSource::Pointer vectors, 
+                                  size_t layerName,
+                                  const std::string& classField, const int label,
+                                  const std::vector<std::string>& excludedFeatures = {});
+
+  void sampleToOGRFeatures(const ogr::DataSource::Pointer& vectors,
+                           ogr::DataSource* output, 
+                           const SampleVectorType& samples,
+                           const size_t layerName,
+                           const std::string& classField, int label,
+                           const std::vector<std::string>& excludedFeatures = {});
+
+std::set<size_t> getExcludedFeaturesIds(const std::vector<std::string>& excludedFeatures,
+                                        const ogr::Layer& inputLayer);
+bool isNumericField(const ogr::Feature& feature, const int idx);
+
+ogr::Feature selectTemplateFeature(const ogr::Layer& inputLayer, 
+                                   const std::string& classField, int label);
+private:
+  SampleAugmentationFilter(const Self &);  //purposely not implemented
+  void operator =(const Self&);      //purposely not implemented
+
+  std::string m_ClassFieldName;
+  size_t m_Layer;
+  int m_Label;
+  std::vector<std::string> m_ExcludedFeatures;
+  Strategy m_Strategy;
+  int m_NumberOfSamples;
+  double m_StdFactor;
+  size_t m_SmoteNeighbors;
+  int m_Seed;
+
+};
+
+
+} // end namespace otb
+
+#ifndef OTB_MANUAL_INSTANTIATION
+#include "otbSampleAugmentationFilter.txx"
+#endif
+
+#endif
diff --git a/Modules/Applications/AppClassification/include/otbSampleAugmentationFilter.txx b/Modules/Applications/AppClassification/include/otbSampleAugmentationFilter.txx
new file mode 100644
index 0000000000..41590a0109
--- /dev/null
+++ b/Modules/Applications/AppClassification/include/otbSampleAugmentationFilter.txx
@@ -0,0 +1,268 @@
+/*
+ * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
+ *
+ * This file is part of Orfeo Toolbox
+ *
+ *     https://www.orfeo-toolbox.org/
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef otbSampleAugmentationFilter_txx
+#define otbSampleAugmentationFilter_txx
+
+#include "otbSampleAugmentationFilter.h"
+#include "stdint.h" //needed for uintptr_t
+
+namespace otb
+{
+
+SampleAugmentationFilter
+::SampleAugmentationFilter() : m_ClassFieldName("class")
+{
+  this->SetNumberOfRequiredInputs(1);
+  this->SetNumberOfRequiredOutputs(1);
+  this->ProcessObject::SetNthOutput(0, this->MakeOutput(0) );
+}
+
+
+typename SampleAugmentationFilter::DataObjectPointer
+SampleAugmentationFilter
+::MakeOutput(DataObjectPointerArraySizeType itkNotUsed(idx))
+{
+  return static_cast< DataObjectPointer >(OGRDataSourceType::New().GetPointer());
+}
+
+const typename SampleAugmentationFilter::OGRDataSourceType *
+SampleAugmentationFilter
+::GetOutput()
+{
+  return static_cast< const OGRDataSourceType *>(
+    this->ProcessObject::GetOutput(0));
+}
+
+void
+SampleAugmentationFilter
+::SetInput(const otb::ogr::DataSource* ds)
+{
+  this->Superclass::SetNthInput(0, const_cast<otb::ogr::DataSource *>(ds));
+}
+
+const typename SampleAugmentationFilter::OGRDataSourceType *
+SampleAugmentationFilter
+::GetInput(unsigned int idx)
+{
+  return static_cast<const OGRDataSourceType *>
+    (this->itk::ProcessObject::GetInput(idx));
+}
+
+void
+SampleAugmentationFilter
+::SetOutputSamples(ogr::DataSource* data)
+{
+  this->SetNthOutput(0,data);
+}
+
+
+void
+SampleAugmentationFilter
+::GenerateData(void)
+{
+
+  OGRDataSourcePointerType inputDS = dynamic_cast<OGRDataSourceType*>(this->itk::ProcessObject::GetInput(0));
+  auto outputDS = static_cast<ogr::DataSource *>(this->itk::ProcessObject::GetOutput(1));
+  auto inSamples = this->extractSamples(inputDS, m_Layer,
+                                        m_ClassFieldName,
+                                        m_Label,
+                                        m_ExcludedFeatures);
+  SampleVectorType newSamples;
+  switch (m_Strategy)
+    {
+    case Strategy::Replicate:
+    {
+    sampleAugmentation::replicateSamples(inSamples, m_NumberOfSamples,
+                                         newSamples);
+    }
+    break;
+    case Strategy::Jitter:
+    {
+    sampleAugmentation::jitterSamples(inSamples, m_NumberOfSamples,
+                                      newSamples,
+                                      m_StdFactor,
+                                      m_Seed);
+    }
+    break;
+    case Strategy::Smote:
+    {
+    sampleAugmentation::smote(inSamples, m_NumberOfSamples,
+                              newSamples,
+                              m_SmoteNeighbors,
+                              m_Seed);
+    }
+    break;
+    }
+  this->sampleToOGRFeatures(inputDS, outputDS, newSamples, m_Layer,
+                            m_ClassFieldName,
+                            m_Label,
+                            m_ExcludedFeatures);
+
+
+  //  this->SetNthOutput(0,outputDS);
+}
+
+/** Extracts the samples of a single class from the vector data to a
+* vector and excludes some unwanted features.
+*/
+SampleAugmentationFilter::SampleVectorType 
+SampleAugmentationFilter
+::extractSamples(const ogr::DataSource::Pointer vectors, 
+                 size_t layerName,
+                 const std::string& classField, const int label,
+                 const std::vector<std::string>& excludedFeatures)
+{
+  ogr::Layer layer = vectors->GetLayer(layerName);
+  ogr::Feature feature = layer.ogr().GetNextFeature();
+  if(feature.addr() == 0)
+    {
+    itkExceptionMacro("Layer " << layerName << " of input sample file is empty.\n");
+    }
+  int cFieldIndex = feature.ogr().GetFieldIndex( classField.c_str() );
+  if( cFieldIndex < 0 )
+    {
+    itkExceptionMacro( "The field name for class label (" << classField
+                       << ") has not been found in the vector file " );
+    }
+
+  auto numberOfFields = feature.ogr().GetFieldCount();
+  auto excludedIds = this->getExcludedFeaturesIds(excludedFeatures, layer);
+  SampleVectorType samples;
+  bool goesOn{feature.addr() != 0};
+  while( goesOn )
+    {
+    // Retrieve all the features for each field in the ogr layer.
+    if(feature.ogr().GetFieldAsInteger(classField.c_str()) == label)
+      {
+
+      SampleType mv;
+      for(auto idx=0; idx<numberOfFields; ++idx)
+        {
+        if(excludedIds.find(idx) == excludedIds.cend() &&
+           this->isNumericField(feature, idx))
+          mv.push_back(feature.ogr().GetFieldAsDouble(idx));
+        }
+      samples.push_back(mv); 
+      }
+    feature = layer.ogr().GetNextFeature();
+    goesOn = feature.addr() != 0;
+    }
+  return samples;
+}
+
+void 
+SampleAugmentationFilter
+::sampleToOGRFeatures(const ogr::DataSource::Pointer& vectors,
+                      ogr::DataSource* output, 
+               const SampleAugmentationFilter::SampleVectorType& samples,
+               const size_t layerName,
+                  const std::string& classField, int label,
+                  const std::vector<std::string>& excludedFeatures)
+{
+
+  auto inputLayer = vectors->GetLayer(layerName);
+  auto excludedIds = this->getExcludedFeaturesIds(excludedFeatures, inputLayer);
+
+  OGRSpatialReference * oSRS = nullptr;
+  if (inputLayer.GetSpatialRef())
+    {
+    oSRS = inputLayer.GetSpatialRef()->Clone();
+    }
+  OGRFeatureDefn &layerDefn = inputLayer.GetLayerDefn();
+
+  auto outputLayer = output->CreateLayer(inputLayer.GetName(), oSRS, 
+                                         inputLayer.GetGeomType());
+  for (int k=0 ; k < layerDefn.GetFieldCount() ; k++)
+    {
+    OGRFieldDefn originDefn(layerDefn.GetFieldDefn(k));
+    ogr::FieldDefn fieldDefn(originDefn);
+    outputLayer.CreateField(fieldDefn);
+    }
+
+  auto featureCount = outputLayer.GetFeatureCount(false);
+  auto templateFeature = this->selectTemplateFeature(inputLayer, classField, label);
+  for(const auto& sample : samples)
+    {
+    ogr::Feature dstFeature(outputLayer.GetLayerDefn());
+    dstFeature.SetFrom( templateFeature, TRUE );
+    dstFeature.SetFID(++featureCount);
+    auto sampleFieldCounter = 0;
+    for (int k=0 ; k < layerDefn.GetFieldCount() ; k++)
+      {
+      if(excludedIds.find(k) == excludedIds.cend() &&
+         this->isNumericField(dstFeature, k))
+        {
+        dstFeature.ogr().SetField(k, sample[sampleFieldCounter++]);
+        }
+      }
+    outputLayer.CreateFeature( dstFeature );
+    }
+}
+
+               std::set<size_t> 
+               SampleAugmentationFilter
+               ::getExcludedFeaturesIds(const std::vector<std::string>& excludedFeatures,
+                                        const ogr::Layer& inputLayer)
+                  {
+                    auto feature = *(inputLayer).begin();
+                    std::set<size_t> excludedIds;
+                    if( excludedFeatures.size() != 0)
+                      {
+                      for(const auto& fieldName : excludedFeatures)
+                        {
+                        auto idx = feature.ogr().GetFieldIndex( fieldName.c_str() );
+                        excludedIds.insert(idx);
+                        }
+                      }
+                    return excludedIds;
+                  }
+
+               bool 
+SampleAugmentationFilter
+::isNumericField(const ogr::Feature& feature,
+                 const int idx)
+{
+  OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(idx)->GetType();
+    return (fieldType == OFTInteger 
+            || ogr::version_proxy::IsOFTInteger64( fieldType ) 
+            || fieldType == OFTReal);
+}
+
+ogr::Feature
+SampleAugmentationFilter
+::selectTemplateFeature(const ogr::Layer& inputLayer, 
+                        const std::string& classField, int label)
+{
+  auto featureIt = inputLayer.begin();
+  bool goesOn{(*featureIt).addr() != 0};
+  while( goesOn )
+    {
+    if((*featureIt).ogr().GetFieldAsInteger(classField.c_str()) == label)
+      {
+      return *featureIt;
+      }
+    ++featureIt;
+    }
+  return *(inputLayer.begin());
+}
+} // end namespace otb
+
+#endif
-- 
GitLab