diff --git a/Modules/Applications/AppClassification/app/otbVectorClassifier.cxx b/Modules/Applications/AppClassification/app/otbVectorClassifier.cxx index d6f5e25e42147a4e2436e5e9f918798f4abb0081..da553cfae01fcec339b601d76882afd383d4c006 100644 --- a/Modules/Applications/AppClassification/app/otbVectorClassifier.cxx +++ b/Modules/Applications/AppClassification/app/otbVectorClassifier.cxx @@ -25,7 +25,26 @@ namespace otb namespace Wrapper { -typedef VectorPrediction<false, float, unsigned int> VectorClassifier; +using VectorClassifier = VectorPrediction<false, float, unsigned int>; + +template<> +void +VectorClassifier +::CreatePredictionField(OGRFeatureDefn & layerDefn, otb::ogr::Layer & outLayer) +{ + int idx = layerDefn.GetFieldIndex(GetParameterString("cfield").c_str()); + if (idx >= 0) + { + if (layerDefn.GetFieldDefn(idx)->GetType() != OFTInteger) + itkExceptionMacro("Field name "<< GetParameterString("cfield") << " already exists with a different type!"); + } + else + { + OGRFieldDefn predictedField(GetParameterString("cfield").c_str(), OFTInteger); + ogr::FieldDefn predictedFieldDef(predictedField); + outLayer.CreateField(predictedFieldDef); + } +} } } diff --git a/Modules/Applications/AppClassification/app/otbVectorRegression.cxx b/Modules/Applications/AppClassification/app/otbVectorRegression.cxx index 00670adfe2fe182aa47aba07c7e9a4379e0bc18b..bba74a3152f809b98a0f66bc6b1ecf1d72769054 100644 --- a/Modules/Applications/AppClassification/app/otbVectorRegression.cxx +++ b/Modules/Applications/AppClassification/app/otbVectorRegression.cxx @@ -25,7 +25,26 @@ namespace otb namespace Wrapper { -typedef VectorPrediction<true, float, float> VectorRegression; +using VectorRegression = VectorPrediction<true, float, float>; + +template<> +void +VectorRegression +::CreatePredictionField(OGRFeatureDefn & layerDefn, otb::ogr::Layer & outLayer) +{ + int idx = layerDefn.GetFieldIndex(GetParameterString("cfield").c_str()); + if (idx >= 0) + { + if (layerDefn.GetFieldDefn(idx)->GetType() != OFTReal) + itkExceptionMacro("Field name "<< GetParameterString("cfield") << " already exists with a different type!"); + } + else + { + OGRFieldDefn predictedField(GetParameterString("cfield").c_str(), OFTReal); + ogr::FieldDefn predictedFieldDef(predictedField); + outLayer.CreateField(predictedFieldDef); + } +} } } diff --git a/Modules/Applications/AppClassification/include/otbVectorPrediction.h b/Modules/Applications/AppClassification/include/otbVectorPrediction.h index a138572574fbc61021b21cef11de8d81482091e1..dfa537be48591fef95fc8bfa6e7cb9f2575df07e 100644 --- a/Modules/Applications/AppClassification/include/otbVectorPrediction.h +++ b/Modules/Applications/AppClassification/include/otbVectorPrediction.h @@ -194,6 +194,10 @@ private: } } + /** Create the prediction field in the output layer, this template method should be specialized + * to create the right type of field (e.g. OGRInteger or OGRReal) */ + void CreatePredictionField(OGRFeatureDefn & layerDefn, otb::ogr::Layer & outLayer); + void DoExecute() override { clock_t tic = clock(); @@ -275,6 +279,8 @@ private: otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type"); } + m_Model->SetRegressionMode(RegressionMode); + m_Model->Load(GetParameterString("model")); otbAppLogINFO("Model loaded"); @@ -342,26 +348,16 @@ private: itkExceptionMacro(<< "Unable to start transaction for OGR layer " << outLayer.ogr().GetName() << "."); } - // Add the field of prediction in the output layer if field not exist OGRFeatureDefn &layerDefn = layer.GetLayerDefn(); - int idx = layerDefn.GetFieldIndex(GetParameterString("cfield").c_str()); - if (idx >= 0) - { - if (layerDefn.GetFieldDefn(idx)->GetType() != OFTInteger) - itkExceptionMacro("Field name "<< GetParameterString("cfield") << " already exists with a different type!"); - } - else - { - OGRFieldDefn predictedField(GetParameterString("cfield").c_str(), OFTInteger); - ogr::FieldDefn predictedFieldDef(predictedField); - outLayer.CreateField(predictedFieldDef); - } + + // Add the field of prediction in the output layer if field not exist + CreatePredictionField(layerDefn, outLayer); // Add confidence field in the output layer std::string confFieldName("confidence"); if (computeConfidenceMap) { - idx = layerDefn.GetFieldIndex(confFieldName.c_str()); + int idx = layerDefn.GetFieldIndex(confFieldName.c_str()); if (idx >= 0) { if (layerDefn.GetFieldDefn(idx)->GetType() != OFTReal) @@ -435,9 +431,6 @@ private: ModelPointerType m_Model; }; -typedef VectorPrediction<false, float, unsigned int> VectorClassifier; -typedef VectorPrediction<true, float, float> VectorRegression; - } }