From 3369f28ec53757e9d7f5e406791e792e08861246 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?C=C3=A9dric=20Traizet?= <cedric.traizet@c-s.fr>
Date: Tue, 30 Jul 2019 14:13:13 +0200
Subject: [PATCH] ENH: use template parameter RegressionMode on the ML model

---
 .../app/otbVectorClassifier.cxx               | 21 ++++++++++++++-
 .../app/otbVectorRegression.cxx               | 21 ++++++++++++++-
 .../include/otbVectorPrediction.h             | 27 +++++++------------
 3 files changed, 50 insertions(+), 19 deletions(-)

diff --git a/Modules/Applications/AppClassification/app/otbVectorClassifier.cxx b/Modules/Applications/AppClassification/app/otbVectorClassifier.cxx
index d6f5e25e42..da553cfae0 100644
--- a/Modules/Applications/AppClassification/app/otbVectorClassifier.cxx
+++ b/Modules/Applications/AppClassification/app/otbVectorClassifier.cxx
@@ -25,7 +25,26 @@ namespace otb
 namespace Wrapper
 {
 
-typedef VectorPrediction<false, float, unsigned int> VectorClassifier;
+using VectorClassifier = VectorPrediction<false, float, unsigned int>;
+
+template<>
+void
+VectorClassifier
+::CreatePredictionField(OGRFeatureDefn & layerDefn, otb::ogr::Layer & outLayer)
+{
+  int idx = layerDefn.GetFieldIndex(GetParameterString("cfield").c_str());
+  if (idx >= 0)
+  {
+    if (layerDefn.GetFieldDefn(idx)->GetType() != OFTInteger)
+      itkExceptionMacro("Field name "<< GetParameterString("cfield") << " already exists with a different type!");
+  }
+  else
+  {
+    OGRFieldDefn predictedField(GetParameterString("cfield").c_str(), OFTInteger);
+    ogr::FieldDefn predictedFieldDef(predictedField);
+    outLayer.CreateField(predictedFieldDef);
+  }
+}
 
 }
 }
diff --git a/Modules/Applications/AppClassification/app/otbVectorRegression.cxx b/Modules/Applications/AppClassification/app/otbVectorRegression.cxx
index 00670adfe2..bba74a3152 100644
--- a/Modules/Applications/AppClassification/app/otbVectorRegression.cxx
+++ b/Modules/Applications/AppClassification/app/otbVectorRegression.cxx
@@ -25,7 +25,26 @@ namespace otb
 namespace Wrapper
 {
 
-typedef VectorPrediction<true, float, float>         VectorRegression;
+using VectorRegression = VectorPrediction<true, float, float>;
+
+template<>
+void
+VectorRegression
+::CreatePredictionField(OGRFeatureDefn & layerDefn, otb::ogr::Layer & outLayer)
+{
+  int idx = layerDefn.GetFieldIndex(GetParameterString("cfield").c_str());
+  if (idx >= 0)
+  {
+    if (layerDefn.GetFieldDefn(idx)->GetType() != OFTReal)
+      itkExceptionMacro("Field name "<< GetParameterString("cfield") << " already exists with a different type!");
+  }
+  else
+  {
+    OGRFieldDefn predictedField(GetParameterString("cfield").c_str(), OFTReal);
+    ogr::FieldDefn predictedFieldDef(predictedField);
+    outLayer.CreateField(predictedFieldDef);
+  }
+}
 
 }
 }
diff --git a/Modules/Applications/AppClassification/include/otbVectorPrediction.h b/Modules/Applications/AppClassification/include/otbVectorPrediction.h
index a138572574..dfa537be48 100644
--- a/Modules/Applications/AppClassification/include/otbVectorPrediction.h
+++ b/Modules/Applications/AppClassification/include/otbVectorPrediction.h
@@ -194,6 +194,10 @@ private:
     }
   }
 
+  /** Create the prediction field in the output layer, this template method should be specialized
+   * to create the right type of field (e.g. OGRInteger or OGRReal) */ 
+  void CreatePredictionField(OGRFeatureDefn & layerDefn, otb::ogr::Layer & outLayer);
+
   void DoExecute() override
   {
     clock_t tic = clock();
@@ -275,6 +279,8 @@ private:
       otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type");
       }
 
+    m_Model->SetRegressionMode(RegressionMode);
+
     m_Model->Load(GetParameterString("model"));
     otbAppLogINFO("Model loaded");
 
@@ -342,26 +348,16 @@ private:
       itkExceptionMacro(<< "Unable to start transaction for OGR layer " << outLayer.ogr().GetName() << ".");
       }
 
-    // Add the field of prediction in the output layer if field not exist
     OGRFeatureDefn &layerDefn = layer.GetLayerDefn();
-    int idx = layerDefn.GetFieldIndex(GetParameterString("cfield").c_str());
-    if (idx >= 0)
-      {
-      if (layerDefn.GetFieldDefn(idx)->GetType() != OFTInteger)
-        itkExceptionMacro("Field name "<< GetParameterString("cfield") << " already exists with a different type!");
-      }
-    else
-      {
-      OGRFieldDefn predictedField(GetParameterString("cfield").c_str(), OFTInteger);
-      ogr::FieldDefn predictedFieldDef(predictedField);
-      outLayer.CreateField(predictedFieldDef);
-      }
+
+    // Add the field of prediction in the output layer if field not exist
+    CreatePredictionField(layerDefn, outLayer);
 
     // Add confidence field in the output layer
     std::string confFieldName("confidence");
     if (computeConfidenceMap)
       {
-      idx = layerDefn.GetFieldIndex(confFieldName.c_str());
+      int idx = layerDefn.GetFieldIndex(confFieldName.c_str());
       if (idx >= 0)
         {
         if (layerDefn.GetFieldDefn(idx)->GetType() != OFTReal)
@@ -435,9 +431,6 @@ private:
   ModelPointerType m_Model;
 };
 
-typedef VectorPrediction<false, float, unsigned int> VectorClassifier;
-typedef VectorPrediction<true, float, float>         VectorRegression;
-
 }
 }
 
-- 
GitLab