From 3369f28ec53757e9d7f5e406791e792e08861246 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Traizet?= <cedric.traizet@c-s.fr> Date: Tue, 30 Jul 2019 14:13:13 +0200 Subject: [PATCH] ENH: use template parameter RegressionMode on the ML model --- .../app/otbVectorClassifier.cxx | 21 ++++++++++++++- .../app/otbVectorRegression.cxx | 21 ++++++++++++++- .../include/otbVectorPrediction.h | 27 +++++++------------ 3 files changed, 50 insertions(+), 19 deletions(-) diff --git a/Modules/Applications/AppClassification/app/otbVectorClassifier.cxx b/Modules/Applications/AppClassification/app/otbVectorClassifier.cxx index d6f5e25e42..da553cfae0 100644 --- a/Modules/Applications/AppClassification/app/otbVectorClassifier.cxx +++ b/Modules/Applications/AppClassification/app/otbVectorClassifier.cxx @@ -25,7 +25,26 @@ namespace otb namespace Wrapper { -typedef VectorPrediction<false, float, unsigned int> VectorClassifier; +using VectorClassifier = VectorPrediction<false, float, unsigned int>; + +template<> +void +VectorClassifier +::CreatePredictionField(OGRFeatureDefn & layerDefn, otb::ogr::Layer & outLayer) +{ + int idx = layerDefn.GetFieldIndex(GetParameterString("cfield").c_str()); + if (idx >= 0) + { + if (layerDefn.GetFieldDefn(idx)->GetType() != OFTInteger) + itkExceptionMacro("Field name "<< GetParameterString("cfield") << " already exists with a different type!"); + } + else + { + OGRFieldDefn predictedField(GetParameterString("cfield").c_str(), OFTInteger); + ogr::FieldDefn predictedFieldDef(predictedField); + outLayer.CreateField(predictedFieldDef); + } +} } } diff --git a/Modules/Applications/AppClassification/app/otbVectorRegression.cxx b/Modules/Applications/AppClassification/app/otbVectorRegression.cxx index 00670adfe2..bba74a3152 100644 --- a/Modules/Applications/AppClassification/app/otbVectorRegression.cxx +++ b/Modules/Applications/AppClassification/app/otbVectorRegression.cxx @@ -25,7 +25,26 @@ namespace otb namespace Wrapper { -typedef VectorPrediction<true, float, float> VectorRegression; +using VectorRegression = VectorPrediction<true, float, float>; + +template<> +void +VectorRegression +::CreatePredictionField(OGRFeatureDefn & layerDefn, otb::ogr::Layer & outLayer) +{ + int idx = layerDefn.GetFieldIndex(GetParameterString("cfield").c_str()); + if (idx >= 0) + { + if (layerDefn.GetFieldDefn(idx)->GetType() != OFTReal) + itkExceptionMacro("Field name "<< GetParameterString("cfield") << " already exists with a different type!"); + } + else + { + OGRFieldDefn predictedField(GetParameterString("cfield").c_str(), OFTReal); + ogr::FieldDefn predictedFieldDef(predictedField); + outLayer.CreateField(predictedFieldDef); + } +} } } diff --git a/Modules/Applications/AppClassification/include/otbVectorPrediction.h b/Modules/Applications/AppClassification/include/otbVectorPrediction.h index a138572574..dfa537be48 100644 --- a/Modules/Applications/AppClassification/include/otbVectorPrediction.h +++ b/Modules/Applications/AppClassification/include/otbVectorPrediction.h @@ -194,6 +194,10 @@ private: } } + /** Create the prediction field in the output layer, this template method should be specialized + * to create the right type of field (e.g. OGRInteger or OGRReal) */ + void CreatePredictionField(OGRFeatureDefn & layerDefn, otb::ogr::Layer & outLayer); + void DoExecute() override { clock_t tic = clock(); @@ -275,6 +279,8 @@ private: otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type"); } + m_Model->SetRegressionMode(RegressionMode); + m_Model->Load(GetParameterString("model")); otbAppLogINFO("Model loaded"); @@ -342,26 +348,16 @@ private: itkExceptionMacro(<< "Unable to start transaction for OGR layer " << outLayer.ogr().GetName() << "."); } - // Add the field of prediction in the output layer if field not exist OGRFeatureDefn &layerDefn = layer.GetLayerDefn(); - int idx = layerDefn.GetFieldIndex(GetParameterString("cfield").c_str()); - if (idx >= 0) - { - if (layerDefn.GetFieldDefn(idx)->GetType() != OFTInteger) - itkExceptionMacro("Field name "<< GetParameterString("cfield") << " already exists with a different type!"); - } - else - { - OGRFieldDefn predictedField(GetParameterString("cfield").c_str(), OFTInteger); - ogr::FieldDefn predictedFieldDef(predictedField); - outLayer.CreateField(predictedFieldDef); - } + + // Add the field of prediction in the output layer if field not exist + CreatePredictionField(layerDefn, outLayer); // Add confidence field in the output layer std::string confFieldName("confidence"); if (computeConfidenceMap) { - idx = layerDefn.GetFieldIndex(confFieldName.c_str()); + int idx = layerDefn.GetFieldIndex(confFieldName.c_str()); if (idx >= 0) { if (layerDefn.GetFieldDefn(idx)->GetType() != OFTReal) @@ -435,9 +431,6 @@ private: ModelPointerType m_Model; }; -typedef VectorPrediction<false, float, unsigned int> VectorClassifier; -typedef VectorPrediction<true, float, float> VectorRegression; - } } -- GitLab