From 84514586f43a879ec3fd33a52bb557263bddc0f5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?C=C3=A9dric=20Traizet?= <cedric.traizet@c-s.fr>
Date: Mon, 29 Jul 2019 17:59:22 +0200
Subject: [PATCH] ENH: make VectorClassifier template

---
 .../app/otbVectorClassifier.cxx               | 408 +---------------
 .../include/otbVectorPrediction.h             | 444 ++++++++++++++++++
 2 files changed, 446 insertions(+), 406 deletions(-)
 create mode 100644 Modules/Applications/AppClassification/include/otbVectorPrediction.h

diff --git a/Modules/Applications/AppClassification/app/otbVectorClassifier.cxx b/Modules/Applications/AppClassification/app/otbVectorClassifier.cxx
index 1f0401aeec..d6f5e25e42 100644
--- a/Modules/Applications/AppClassification/app/otbVectorClassifier.cxx
+++ b/Modules/Applications/AppClassification/app/otbVectorClassifier.cxx
@@ -18,418 +18,14 @@
  * limitations under the License.
  */
 
-#include "otbWrapperApplication.h"
-#include "otbWrapperApplicationFactory.h"
-
-#include "otbOGRDataSourceWrapper.h"
-#include "otbOGRFeatureWrapper.h"
-
-#include "itkVariableLengthVector.h"
-#include "otbStatisticsXMLFileReader.h"
-
-#include "itkListSample.h"
-#include "otbShiftScaleSampleListFilter.h"
-
-#include "otbMachineLearningModelFactory.h"
-
-#include "otbMachineLearningModel.h"
-
-#include <time.h>
+#include "otbVectorPrediction.h"
 
 namespace otb
 {
 namespace Wrapper
 {
 
-/** Utility function to negate std::isalnum */
-bool IsNotAlphaNum(char c)
-  {
-  return !std::isalnum(c);
-  }
-
-class VectorClassifier : public Application
-{
-public:
-  /** Standard class typedefs. */
-  typedef VectorClassifier              Self;
-  typedef Application                   Superclass;
-  typedef itk::SmartPointer<Self>       Pointer;
-  typedef itk::SmartPointer<const Self> ConstPointer;
-
-  /** Standard macro */
-  itkNewMacro(Self);
-
-  itkTypeMacro(Self, Application)
-
-  /** Filters typedef */
-  typedef float                                         ValueType;
-  typedef unsigned int                                  LabelType;
-  typedef itk::FixedArray<LabelType,1>                  LabelSampleType;
-  typedef itk::Statistics::ListSample<LabelSampleType>  LabelListSampleType;
-
-  typedef otb::MachineLearningModel<ValueType,LabelType>          MachineLearningModelType;
-  typedef otb::MachineLearningModelFactory<ValueType, LabelType>  MachineLearningModelFactoryType;
-  typedef MachineLearningModelType::Pointer                       ModelPointerType;
-  typedef MachineLearningModelType::ConfidenceListSampleType      ConfidenceListSampleType;
-
-  /** Statistics Filters typedef */
-  typedef itk::VariableLengthVector<ValueType>                    MeasurementType;
-  typedef otb::StatisticsXMLFileReader<MeasurementType>           StatisticsReader;
-
-  typedef itk::VariableLengthVector<ValueType>                    InputSampleType;
-  typedef itk::Statistics::ListSample<InputSampleType>            ListSampleType;
-  typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType;
-
-  ~VectorClassifier() override
-    {
-    MachineLearningModelFactoryType::CleanFactories();
-    }
-
-private:
-  void DoInit() override
-  {
-    SetName("VectorClassifier");
-    SetDescription("Performs a classification of the input vector data according to a model file.");
-
-    SetDocAuthors("OTB-Team");
-    SetDocLongDescription("This application performs a vector data classification "
-      "based on a model file produced by the TrainVectorClassifier application."
-      "Features of the vector data output will contain the class labels decided by the classifier "
-      "(maximal class label = 65535). \n"
-      "There are two modes: \n"
-        "1) Update mode: add of the 'cfield' field containing the predicted class in the input file. \n"
-        "2) Write mode: copies the existing fields of the input file to the output file "
-           " and add the 'cfield' field containing the predicted class. \n"
-      "If you have declared the output file, the write mode applies. "
-      "Otherwise, the input file update mode will be applied.");
-
-    SetDocLimitations("Shapefiles are supported, but the SQLite format is only supported in update mode.");
-    SetDocSeeAlso("TrainVectorClassifier");
-    AddDocTag(Tags::Learning);
-
-    AddParameter(ParameterType_InputVectorData, "in", "Name of the input vector data");
-    SetParameterDescription("in","The input vector data file to classify.");
-
-    AddParameter(ParameterType_InputFilename, "instat", "Statistics file");
-    SetParameterDescription("instat", "A XML file containing mean and standard deviation to center"
-      "and reduce samples before classification, produced by ComputeImagesStatistics application.");
-    MandatoryOff("instat");
-
-    AddParameter(ParameterType_InputFilename, "model", "Model file");
-    SetParameterDescription("model", "Model file produced by TrainVectorClassifier application.");
-
-    AddParameter(ParameterType_String,"cfield","Field class");
-    SetParameterDescription("cfield","Field containing the predicted class."
-      "Only geometries with this field available will be taken into account.\n"
-      "The field is added either in the input file (if 'out' off) or in the output file.\n"
-      "Caution, the 'cfield' must not exist in the input file if you are updating the file.");
-    SetParameterString("cfield","predicted");
-
-    AddParameter(ParameterType_ListView, "feat", "Field names to be calculated");
-    SetParameterDescription("feat","List of field names in the input vector data used as features for training. "
-      "Put the same field names as the TrainVectorClassifier application.");
-
-    AddParameter(ParameterType_Bool, "confmap",  "Confidence map");
-    SetParameterDescription( "confmap", "Confidence map of the produced classification. The confidence index depends on the model: \n\n"
-      "* LibSVM: difference between the two highest probabilities (needs a model with probability estimates, so that classes probabilities can be computed for each sample)\n"
-      "* Boost: sum of votes\n"
-      "* DecisionTree: (not supported)\n"
-      "* KNearestNeighbors: number of neighbors with the same label\n"
-      "* NeuralNetwork: difference between the two highest responses\n"
-      "* NormalBayes: (not supported)\n"
-      "* RandomForest: Confidence (proportion of votes for the majority class). Margin (normalized difference of the votes of the 2 majority classes) is not available for now.\n"
-      "* SVM: distance to margin (only works for 2-class models)\n");
-
-    AddParameter(ParameterType_OutputFilename, "out", "Output vector data file containing class labels");
-    SetParameterDescription("out","Output vector data file storing sample values (OGR format)."
-      "If not given, the input vector data file is updated.");
-    MandatoryOff("out");
-
-    // Doc example parameter settings
-    SetDocExampleParameterValue("in", "vectorData.shp");
-    SetDocExampleParameterValue("instat", "meanVar.xml");
-    SetDocExampleParameterValue("model", "svmModel.svm");
-    SetDocExampleParameterValue("out", "vectorDataLabeledVector.shp");
-    SetDocExampleParameterValue("feat", "perimeter  area  width");
-    SetDocExampleParameterValue("cfield", "predicted");
-
-    SetOfficialDocLink();
-  }
-
-  void DoUpdateParameters() override
-  {
-    if ( HasValue("in") )
-    {
-      std::string shapefile = GetParameterString("in");
-
-      otb::ogr::DataSource::Pointer ogrDS;
-
-      OGRSpatialReference oSRS("");
-      std::vector<std::string> options;
-
-      ogrDS = otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Read);
-      otb::ogr::Layer layer = ogrDS->GetLayer(0);
-      OGRFeatureDefn &layerDefn = layer.GetLayerDefn();
-
-      ClearChoices("feat");
-
-      for(int iField=0; iField< layerDefn.GetFieldCount(); iField++)
-      {
-        std::string item = layerDefn.GetFieldDefn(iField)->GetNameRef();
-        std::string key(item);
-        key.erase( std::remove_if(key.begin(),key.end(),IsNotAlphaNum), key.end());
-        std::transform(key.begin(), key.end(), key.begin(), tolower);
-
-        OGRFieldType fieldType = layerDefn.GetFieldDefn(iField)->GetType();
-        if(fieldType == OFTInteger ||  fieldType == OFTInteger64 || fieldType == OFTReal)
-          {
-          std::string tmpKey="feat."+key;
-          AddChoice(tmpKey,item);
-          }
-      }
-    }
-  }
-
-  void DoExecute() override
-  {
-    clock_t tic = clock();
-
-    std::string shapefile = GetParameterString("in");
-
-    otb::ogr::DataSource::Pointer source = otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Read);
-    otb::ogr::Layer layer = source->GetLayer(0);
-
-    ListSampleType::Pointer input = ListSampleType::New();
-
-    const int nbFeatures = GetSelectedItems("feat").size();
-    input->SetMeasurementVectorSize(nbFeatures);
-  
-    otb::ogr::Layer::const_iterator it = layer.cbegin();
-    otb::ogr::Layer::const_iterator itEnd = layer.cend();
-    for( ; it!=itEnd ; ++it)
-      {
-      MeasurementType mv;
-      mv.SetSize(nbFeatures);
-      for(int idx=0; idx < nbFeatures; ++idx)
-        {
-        // Beware that itemIndex differs from ogr layer field index
-        unsigned int itemIndex = GetSelectedItems("feat")[idx];
-        std::string fieldName = GetChoiceNames( "feat" )[itemIndex];
-        switch ((*it)[fieldName].GetType())
-        {
-        case OFTInteger:
-          mv[idx] = static_cast<ValueType>((*it)[fieldName].GetValue<int>());
-          break;
-        case OFTInteger64:
-          mv[idx] = static_cast<ValueType>((*it)[fieldName].GetValue<int>());
-          break;
-        case OFTReal:
-          mv[idx] = static_cast<ValueType>((*it)[fieldName].GetValue<double>());
-          break;
-        default:
-          itkExceptionMacro(<< "incorrect field type: " << (*it)[fieldName].GetType() << ".");
-        }
-        
-        
-        }
-      input->PushBack(mv);
-      }
-
-    // Statistics for shift/scale
-    MeasurementType meanMeasurementVector;
-    MeasurementType stddevMeasurementVector;
-    if (HasValue("instat") && IsParameterEnabled("instat"))
-      {
-      StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
-      std::string XMLfile = GetParameterString("instat");
-      statisticsReader->SetFileName(XMLfile);
-      meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
-      stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
-      }
-    else
-      {
-      meanMeasurementVector.SetSize(nbFeatures);
-      meanMeasurementVector.Fill(0.);
-      stddevMeasurementVector.SetSize(nbFeatures);
-      stddevMeasurementVector.Fill(1.);
-      }
-
-    ShiftScaleFilterType::Pointer trainingShiftScaleFilter = ShiftScaleFilterType::New();
-    trainingShiftScaleFilter->SetInput(input);
-    trainingShiftScaleFilter->SetShifts(meanMeasurementVector);
-    trainingShiftScaleFilter->SetScales(stddevMeasurementVector);
-    trainingShiftScaleFilter->Update();
-    otbAppLogINFO("mean used: " << meanMeasurementVector);
-    otbAppLogINFO("standard deviation used: " << stddevMeasurementVector);
-
-    otbAppLogINFO("Loading model");
-    m_Model = MachineLearningModelFactoryType::CreateMachineLearningModel(GetParameterString("model"),
-                                                MachineLearningModelFactoryType::ReadMode);
-
-    if (m_Model.IsNull())
-      {
-      otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type");
-      }
-
-    m_Model->Load(GetParameterString("model"));
-    otbAppLogINFO("Model loaded");
-
-    ListSampleType::Pointer listSample = trainingShiftScaleFilter->GetOutput();
-
-    ConfidenceListSampleType::Pointer quality;
-
-    bool computeConfidenceMap(GetParameterInt("confmap") && m_Model->HasConfidenceIndex() 
-                              && !m_Model->GetRegressionMode());
-
-    if (!m_Model->HasConfidenceIndex() && GetParameterInt("confmap"))
-      {
-      otbAppLogWARNING("Confidence map requested but the classifier doesn't support it!");
-      }
-
-    LabelListSampleType::Pointer target;
-    if (computeConfidenceMap)
-      {
-      quality = ConfidenceListSampleType::New();
-      target = m_Model->PredictBatch(listSample, quality);
-      }
-      else
-      {
-      target = m_Model->PredictBatch(listSample);
-      }
-
-    ogr::DataSource::Pointer output;
-    ogr::DataSource::Pointer buffer = ogr::DataSource::New();
-    bool updateMode = false;
-    if (IsParameterEnabled("out") && HasValue("out"))
-      {
-      // Create new OGRDataSource
-      output = ogr::DataSource::New(GetParameterString("out"), ogr::DataSource::Modes::Overwrite);
-      otb::ogr::Layer newLayer = output->CreateLayer(
-        GetParameterString("out"),
-        const_cast<OGRSpatialReference*>(layer.GetSpatialRef()),
-        layer.GetGeomType());
-      // Copy existing fields
-      OGRFeatureDefn &inLayerDefn = layer.GetLayerDefn();
-      for (int k=0 ; k<inLayerDefn.GetFieldCount() ; k++)
-        {
-        OGRFieldDefn fieldDefn(inLayerDefn.GetFieldDefn(k));
-        newLayer.CreateField(fieldDefn);
-        }
-      }
-    else
-      {
-      // Update mode
-      updateMode = true;
-      otbAppLogINFO("Update input vector data.");
-      // fill temporary buffer for the transfer
-      otb::ogr::Layer inputLayer = layer;
-      layer = buffer->CopyLayer(inputLayer, std::string("Buffer"));
-      // close input data source
-      source->Clear();
-      // Re-open input data source in update mode
-      output = otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Update_LayerUpdate);
-      }
-
-    otb::ogr::Layer outLayer = output->GetLayer(0);
-
-    OGRErr errStart = outLayer.ogr().StartTransaction();
-    if (errStart != OGRERR_NONE)
-      {
-      itkExceptionMacro(<< "Unable to start transaction for OGR layer " << outLayer.ogr().GetName() << ".");
-      }
-
-    // Add the field of prediction in the output layer if field not exist
-    OGRFeatureDefn &layerDefn = layer.GetLayerDefn();
-    int idx = layerDefn.GetFieldIndex(GetParameterString("cfield").c_str());
-    if (idx >= 0)
-      {
-      if (layerDefn.GetFieldDefn(idx)->GetType() != OFTInteger)
-        itkExceptionMacro("Field name "<< GetParameterString("cfield") << " already exists with a different type!");
-      }
-    else
-      {
-      OGRFieldDefn predictedField(GetParameterString("cfield").c_str(), OFTInteger);
-      ogr::FieldDefn predictedFieldDef(predictedField);
-      outLayer.CreateField(predictedFieldDef);
-      }
-
-    // Add confidence field in the output layer
-    std::string confFieldName("confidence");
-    if (computeConfidenceMap)
-      {
-      idx = layerDefn.GetFieldIndex(confFieldName.c_str());
-      if (idx >= 0)
-        {
-        if (layerDefn.GetFieldDefn(idx)->GetType() != OFTReal)
-          itkExceptionMacro("Field name "<< confFieldName << " already exists with a different type!");
-        }
-      else
-        {
-        OGRFieldDefn confidenceField(confFieldName.c_str(), OFTReal);
-        confidenceField.SetWidth(confidenceField.GetWidth());
-        confidenceField.SetPrecision(confidenceField.GetPrecision());
-        ogr::FieldDefn confFieldDefn(confidenceField);
-        outLayer.CreateField(confFieldDefn);
-        }
-      }
-
-    // Fill output layer
-    unsigned int count=0;
-    std::string classfieldname = GetParameterString("cfield");
-    it = layer.cbegin();
-    itEnd = layer.cend();
-    for( ; it!=itEnd ; ++it, ++count)
-      {
-      ogr::Feature dstFeature(outLayer.GetLayerDefn());
-      dstFeature.SetFrom( *it , TRUE);
-      dstFeature.SetFID(it->GetFID());
-      switch (dstFeature[classfieldname].GetType())
-        {
-        case OFTInteger:
-          dstFeature[classfieldname].SetValue<int>(target->GetMeasurementVector(count)[0]);
-          break;
-        case OFTInteger64:
-          dstFeature[classfieldname].SetValue<int>(target->GetMeasurementVector(count)[0]);
-          break;
-        case OFTReal:
-          dstFeature[classfieldname].SetValue<double>(target->GetMeasurementVector(count)[0]);
-          break;
-        case OFTString:
-          dstFeature[classfieldname].SetValue<std::string>(std::to_string(target->GetMeasurementVector(count)[0]));
-          break;
-        default:
-          itkExceptionMacro(<< "incorrect field type: " << dstFeature[classfieldname].GetType() << ".");
-        }
-      if (computeConfidenceMap)
-        dstFeature[confFieldName].SetValue<double>(quality->GetMeasurementVector(count)[0]);
-      if (updateMode)
-        {
-        outLayer.SetFeature(dstFeature);
-        }
-      else
-        {
-        outLayer.CreateFeature(dstFeature);
-        }
-      }
-
-    if(outLayer.ogr().TestCapability("Transactions"))
-      {
-      const OGRErr errCommitX = outLayer.ogr().CommitTransaction();
-      if (errCommitX != OGRERR_NONE)
-        {
-        itkExceptionMacro(<< "Unable to commit transaction for OGR layer " << outLayer.ogr().GetName() << ".");
-        }
-      }
-
-    output->SyncToDisk();
-
-    clock_t toc = clock();
-    otbAppLogINFO( "Elapsed: "<< ((double)(toc - tic) / CLOCKS_PER_SEC)<<" seconds.");
-
-  }
-
-  ModelPointerType m_Model;
-};
+typedef VectorPrediction<false, float, unsigned int> VectorClassifier;
 
 }
 }
diff --git a/Modules/Applications/AppClassification/include/otbVectorPrediction.h b/Modules/Applications/AppClassification/include/otbVectorPrediction.h
new file mode 100644
index 0000000000..a138572574
--- /dev/null
+++ b/Modules/Applications/AppClassification/include/otbVectorPrediction.h
@@ -0,0 +1,444 @@
+/*
+ * Copyright (C) 2005-2019 Centre National d'Etudes Spatiales (CNES)
+ *
+ * This file is part of Orfeo Toolbox
+ *
+ *     https://www.orfeo-toolbox.org/
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef otbVectorPrediction_h
+#define otbVectorPrediction_h
+
+#include "otbWrapperApplication.h"
+#include "otbWrapperApplicationFactory.h"
+
+#include "otbOGRDataSourceWrapper.h"
+#include "otbOGRFeatureWrapper.h"
+
+#include "itkVariableLengthVector.h"
+#include "otbStatisticsXMLFileReader.h"
+
+#include "itkListSample.h"
+#include "otbShiftScaleSampleListFilter.h"
+
+#include "otbMachineLearningModelFactory.h"
+
+#include "otbMachineLearningModel.h"
+
+#include <time.h>
+
+namespace otb
+{
+namespace Wrapper
+{
+
+template <bool RegressionMode, class ValueType, class LabelType>
+class VectorPrediction : public Application
+{
+public:
+  /** Standard class typedefs. */
+  typedef VectorPrediction              Self;
+  typedef Application                   Superclass;
+  typedef itk::SmartPointer<Self>       Pointer;
+  typedef itk::SmartPointer<const Self> ConstPointer;
+
+  /** Standard macro */
+  itkNewMacro(Self);
+
+  itkTypeMacro(Self, Application)
+
+  /** Filters typedef */
+  //typedef float                                         ValueType;
+  //typedef unsigned int                                  LabelType;
+  typedef itk::FixedArray<LabelType,1>                  LabelSampleType;
+  typedef itk::Statistics::ListSample<LabelSampleType>  LabelListSampleType;
+
+  typedef otb::MachineLearningModel<ValueType,LabelType>          MachineLearningModelType;
+  typedef otb::MachineLearningModelFactory<ValueType, LabelType>  MachineLearningModelFactoryType;
+  typedef typename MachineLearningModelType::Pointer                       ModelPointerType;
+  typedef typename MachineLearningModelType::ConfidenceListSampleType      ConfidenceListSampleType;
+
+  /** Statistics Filters typedef */
+  typedef itk::VariableLengthVector<ValueType>                    MeasurementType;
+  typedef otb::StatisticsXMLFileReader<MeasurementType>           StatisticsReader;
+
+  typedef itk::VariableLengthVector<ValueType>                    InputSampleType;
+  typedef itk::Statistics::ListSample<InputSampleType>            ListSampleType;
+  typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType;
+
+  ~VectorPrediction() override
+    {
+    MachineLearningModelFactoryType::CleanFactories();
+    }
+
+private:
+  /** Utility function to negate std::isalnum */
+  static bool IsNotAlphaNum(char c)
+  {
+  return !std::isalnum(c);
+  }
+  
+  void DoInit() override
+  {
+    SetName("VectorClassifier");
+    SetDescription("Performs a classification of the input vector data according to a model file.");
+
+    SetDocAuthors("OTB-Team");
+    SetDocLongDescription("This application performs a vector data classification "
+      "based on a model file produced by the TrainVectorClassifier application."
+      "Features of the vector data output will contain the class labels decided by the classifier "
+      "(maximal class label = 65535). \n"
+      "There are two modes: \n"
+        "1) Update mode: add of the 'cfield' field containing the predicted class in the input file. \n"
+        "2) Write mode: copies the existing fields of the input file to the output file "
+           " and add the 'cfield' field containing the predicted class. \n"
+      "If you have declared the output file, the write mode applies. "
+      "Otherwise, the input file update mode will be applied.");
+
+    SetDocLimitations("Shapefiles are supported, but the SQLite format is only supported in update mode.");
+    SetDocSeeAlso("TrainVectorClassifier");
+    AddDocTag(Tags::Learning);
+
+    AddParameter(ParameterType_InputVectorData, "in", "Name of the input vector data");
+    SetParameterDescription("in","The input vector data file to classify.");
+
+    AddParameter(ParameterType_InputFilename, "instat", "Statistics file");
+    SetParameterDescription("instat", "A XML file containing mean and standard deviation to center"
+      "and reduce samples before classification, produced by ComputeImagesStatistics application.");
+    MandatoryOff("instat");
+
+    AddParameter(ParameterType_InputFilename, "model", "Model file");
+    SetParameterDescription("model", "Model file produced by TrainVectorClassifier application.");
+
+    AddParameter(ParameterType_String,"cfield","Field class");
+    SetParameterDescription("cfield","Field containing the predicted class."
+      "Only geometries with this field available will be taken into account.\n"
+      "The field is added either in the input file (if 'out' off) or in the output file.\n"
+      "Caution, the 'cfield' must not exist in the input file if you are updating the file.");
+    SetParameterString("cfield","predicted");
+
+    AddParameter(ParameterType_ListView, "feat", "Field names to be calculated");
+    SetParameterDescription("feat","List of field names in the input vector data used as features for training. "
+      "Put the same field names as the TrainVectorClassifier application.");
+
+    AddParameter(ParameterType_Bool, "confmap",  "Confidence map");
+    SetParameterDescription( "confmap", "Confidence map of the produced classification. The confidence index depends on the model: \n\n"
+      "* LibSVM: difference between the two highest probabilities (needs a model with probability estimates, so that classes probabilities can be computed for each sample)\n"
+      "* Boost: sum of votes\n"
+      "* DecisionTree: (not supported)\n"
+      "* KNearestNeighbors: number of neighbors with the same label\n"
+      "* NeuralNetwork: difference between the two highest responses\n"
+      "* NormalBayes: (not supported)\n"
+      "* RandomForest: Confidence (proportion of votes for the majority class). Margin (normalized difference of the votes of the 2 majority classes) is not available for now.\n"
+      "* SVM: distance to margin (only works for 2-class models)\n");
+
+    AddParameter(ParameterType_OutputFilename, "out", "Output vector data file containing class labels");
+    SetParameterDescription("out","Output vector data file storing sample values (OGR format)."
+      "If not given, the input vector data file is updated.");
+    MandatoryOff("out");
+
+    // Doc example parameter settings
+    SetDocExampleParameterValue("in", "vectorData.shp");
+    SetDocExampleParameterValue("instat", "meanVar.xml");
+    SetDocExampleParameterValue("model", "svmModel.svm");
+    SetDocExampleParameterValue("out", "vectorDataLabeledVector.shp");
+    SetDocExampleParameterValue("feat", "perimeter  area  width");
+    SetDocExampleParameterValue("cfield", "predicted");
+
+    SetOfficialDocLink();
+  }
+
+  void DoUpdateParameters() override
+  {
+    if ( HasValue("in") )
+    {
+      std::string shapefile = GetParameterString("in");
+
+      otb::ogr::DataSource::Pointer ogrDS;
+
+      OGRSpatialReference oSRS("");
+      std::vector<std::string> options;
+
+      ogrDS = otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Read);
+      otb::ogr::Layer layer = ogrDS->GetLayer(0);
+      OGRFeatureDefn &layerDefn = layer.GetLayerDefn();
+
+      ClearChoices("feat");
+
+      for(int iField=0; iField< layerDefn.GetFieldCount(); iField++)
+      {
+        std::string item = layerDefn.GetFieldDefn(iField)->GetNameRef();
+        std::string key(item);
+        key.erase( std::remove_if(key.begin(),key.end(),IsNotAlphaNum), key.end());
+        std::transform(key.begin(), key.end(), key.begin(), tolower);
+
+        OGRFieldType fieldType = layerDefn.GetFieldDefn(iField)->GetType();
+        if(fieldType == OFTInteger ||  fieldType == OFTInteger64 || fieldType == OFTReal)
+          {
+          std::string tmpKey="feat."+key;
+          AddChoice(tmpKey,item);
+          }
+      }
+    }
+  }
+
+  void DoExecute() override
+  {
+    clock_t tic = clock();
+
+    std::string shapefile = GetParameterString("in");
+
+    otb::ogr::DataSource::Pointer source = otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Read);
+    otb::ogr::Layer layer = source->GetLayer(0);
+
+    typename ListSampleType::Pointer input = ListSampleType::New();
+
+    const int nbFeatures = GetSelectedItems("feat").size();
+    input->SetMeasurementVectorSize(nbFeatures);
+  
+    otb::ogr::Layer::const_iterator it = layer.cbegin();
+    otb::ogr::Layer::const_iterator itEnd = layer.cend();
+    for( ; it!=itEnd ; ++it)
+      {
+      MeasurementType mv;
+      mv.SetSize(nbFeatures);
+      for(int idx=0; idx < nbFeatures; ++idx)
+        {
+        // Beware that itemIndex differs from ogr layer field index
+        unsigned int itemIndex = GetSelectedItems("feat")[idx];
+        std::string fieldName = GetChoiceNames( "feat" )[itemIndex];
+        switch ((*it)[fieldName].GetType())
+        {
+        case OFTInteger:
+          mv[idx] = static_cast<ValueType>((*it)[fieldName].GetValue<int>());
+          break;
+        case OFTInteger64:
+          mv[idx] = static_cast<ValueType>((*it)[fieldName].GetValue<int>());
+          break;
+        case OFTReal:
+          mv[idx] = static_cast<ValueType>((*it)[fieldName].GetValue<double>());
+          break;
+        default:
+          itkExceptionMacro(<< "incorrect field type: " << (*it)[fieldName].GetType() << ".");
+        }
+        
+        
+        }
+      input->PushBack(mv);
+      }
+
+    // Statistics for shift/scale
+    MeasurementType meanMeasurementVector;
+    MeasurementType stddevMeasurementVector;
+    if (HasValue("instat") && IsParameterEnabled("instat"))
+      {
+      typename StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
+      std::string XMLfile = GetParameterString("instat");
+      statisticsReader->SetFileName(XMLfile);
+      meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
+      stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
+      }
+    else
+      {
+      meanMeasurementVector.SetSize(nbFeatures);
+      meanMeasurementVector.Fill(0.);
+      stddevMeasurementVector.SetSize(nbFeatures);
+      stddevMeasurementVector.Fill(1.);
+      }
+
+    typename ShiftScaleFilterType::Pointer trainingShiftScaleFilter = ShiftScaleFilterType::New();
+    trainingShiftScaleFilter->SetInput(input);
+    trainingShiftScaleFilter->SetShifts(meanMeasurementVector);
+    trainingShiftScaleFilter->SetScales(stddevMeasurementVector);
+    trainingShiftScaleFilter->Update();
+    otbAppLogINFO("mean used: " << meanMeasurementVector);
+    otbAppLogINFO("standard deviation used: " << stddevMeasurementVector);
+
+    otbAppLogINFO("Loading model");
+    m_Model = MachineLearningModelFactoryType::CreateMachineLearningModel(GetParameterString("model"),
+                                                MachineLearningModelFactoryType::ReadMode);
+
+    if (m_Model.IsNull())
+      {
+      otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type");
+      }
+
+    m_Model->Load(GetParameterString("model"));
+    otbAppLogINFO("Model loaded");
+
+    typename ListSampleType::Pointer listSample = trainingShiftScaleFilter->GetOutput();
+
+    typename ConfidenceListSampleType::Pointer quality;
+
+    bool computeConfidenceMap(GetParameterInt("confmap") && m_Model->HasConfidenceIndex() 
+                              && !m_Model->GetRegressionMode());
+
+    if (!m_Model->HasConfidenceIndex() && GetParameterInt("confmap"))
+      {
+      otbAppLogWARNING("Confidence map requested but the classifier doesn't support it!");
+      }
+
+    typename LabelListSampleType::Pointer target;
+    if (computeConfidenceMap)
+      {
+      quality = ConfidenceListSampleType::New();
+      target = m_Model->PredictBatch(listSample, quality);
+      }
+      else
+      {
+      target = m_Model->PredictBatch(listSample);
+      }
+
+    ogr::DataSource::Pointer output;
+    ogr::DataSource::Pointer buffer = ogr::DataSource::New();
+    bool updateMode = false;
+    if (IsParameterEnabled("out") && HasValue("out"))
+      {
+      // Create new OGRDataSource
+      output = ogr::DataSource::New(GetParameterString("out"), ogr::DataSource::Modes::Overwrite);
+      otb::ogr::Layer newLayer = output->CreateLayer(
+        GetParameterString("out"),
+        const_cast<OGRSpatialReference*>(layer.GetSpatialRef()),
+        layer.GetGeomType());
+      // Copy existing fields
+      OGRFeatureDefn &inLayerDefn = layer.GetLayerDefn();
+      for (int k=0 ; k<inLayerDefn.GetFieldCount() ; k++)
+        {
+        OGRFieldDefn fieldDefn(inLayerDefn.GetFieldDefn(k));
+        newLayer.CreateField(fieldDefn);
+        }
+      }
+    else
+      {
+      // Update mode
+      updateMode = true;
+      otbAppLogINFO("Update input vector data.");
+      // fill temporary buffer for the transfer
+      otb::ogr::Layer inputLayer = layer;
+      layer = buffer->CopyLayer(inputLayer, std::string("Buffer"));
+      // close input data source
+      source->Clear();
+      // Re-open input data source in update mode
+      output = otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Update_LayerUpdate);
+      }
+
+    otb::ogr::Layer outLayer = output->GetLayer(0);
+
+    OGRErr errStart = outLayer.ogr().StartTransaction();
+    if (errStart != OGRERR_NONE)
+      {
+      itkExceptionMacro(<< "Unable to start transaction for OGR layer " << outLayer.ogr().GetName() << ".");
+      }
+
+    // Add the field of prediction in the output layer if field not exist
+    OGRFeatureDefn &layerDefn = layer.GetLayerDefn();
+    int idx = layerDefn.GetFieldIndex(GetParameterString("cfield").c_str());
+    if (idx >= 0)
+      {
+      if (layerDefn.GetFieldDefn(idx)->GetType() != OFTInteger)
+        itkExceptionMacro("Field name "<< GetParameterString("cfield") << " already exists with a different type!");
+      }
+    else
+      {
+      OGRFieldDefn predictedField(GetParameterString("cfield").c_str(), OFTInteger);
+      ogr::FieldDefn predictedFieldDef(predictedField);
+      outLayer.CreateField(predictedFieldDef);
+      }
+
+    // Add confidence field in the output layer
+    std::string confFieldName("confidence");
+    if (computeConfidenceMap)
+      {
+      idx = layerDefn.GetFieldIndex(confFieldName.c_str());
+      if (idx >= 0)
+        {
+        if (layerDefn.GetFieldDefn(idx)->GetType() != OFTReal)
+          itkExceptionMacro("Field name "<< confFieldName << " already exists with a different type!");
+        }
+      else
+        {
+        OGRFieldDefn confidenceField(confFieldName.c_str(), OFTReal);
+        confidenceField.SetWidth(confidenceField.GetWidth());
+        confidenceField.SetPrecision(confidenceField.GetPrecision());
+        ogr::FieldDefn confFieldDefn(confidenceField);
+        outLayer.CreateField(confFieldDefn);
+        }
+      }
+
+    // Fill output layer
+    unsigned int count=0;
+    std::string classfieldname = GetParameterString("cfield");
+    it = layer.cbegin();
+    itEnd = layer.cend();
+    for( ; it!=itEnd ; ++it, ++count)
+      {
+      ogr::Feature dstFeature(outLayer.GetLayerDefn());
+      dstFeature.SetFrom( *it , TRUE);
+      dstFeature.SetFID(it->GetFID());
+      switch (dstFeature[classfieldname].GetType())
+        {
+        case OFTInteger:
+          dstFeature[classfieldname].SetValue<int>(target->GetMeasurementVector(count)[0]);
+          break;
+        case OFTInteger64:
+          dstFeature[classfieldname].SetValue<int>(target->GetMeasurementVector(count)[0]);
+          break;
+        case OFTReal:
+          dstFeature[classfieldname].SetValue<double>(target->GetMeasurementVector(count)[0]);
+          break;
+        case OFTString:
+          dstFeature[classfieldname].SetValue<std::string>(std::to_string(target->GetMeasurementVector(count)[0]));
+          break;
+        default:
+          itkExceptionMacro(<< "incorrect field type: " << dstFeature[classfieldname].GetType() << ".");
+        }
+      if (computeConfidenceMap)
+        dstFeature[confFieldName].SetValue<double>(quality->GetMeasurementVector(count)[0]);
+      if (updateMode)
+        {
+        outLayer.SetFeature(dstFeature);
+        }
+      else
+        {
+        outLayer.CreateFeature(dstFeature);
+        }
+      }
+
+    if(outLayer.ogr().TestCapability("Transactions"))
+      {
+      const OGRErr errCommitX = outLayer.ogr().CommitTransaction();
+      if (errCommitX != OGRERR_NONE)
+        {
+        itkExceptionMacro(<< "Unable to commit transaction for OGR layer " << outLayer.ogr().GetName() << ".");
+        }
+      }
+
+    output->SyncToDisk();
+
+    clock_t toc = clock();
+    otbAppLogINFO( "Elapsed: "<< ((double)(toc - tic) / CLOCKS_PER_SEC)<<" seconds.");
+
+  }
+
+  ModelPointerType m_Model;
+};
+
+typedef VectorPrediction<false, float, unsigned int> VectorClassifier;
+typedef VectorPrediction<true, float, float>         VectorRegression;
+
+}
+}
+
+#endif
-- 
GitLab