diff --git a/Modules/Applications/AppClassification/include/otbVectorPrediction.h b/Modules/Applications/AppClassification/include/otbVectorPrediction.h index d895324665f20035a8361e3b443b8a0acc43a598..65fe591977d3ac690f860a6c3d74dc4fe9ec7b1f 100644 --- a/Modules/Applications/AppClassification/include/otbVectorPrediction.h +++ b/Modules/Applications/AppClassification/include/otbVectorPrediction.h @@ -98,7 +98,37 @@ private: /** Method returning whether the confidence map should be computed, depending on the regression mode and input parameters */ bool shouldComputeConfidenceMap() const; + /** Method returning the input list sample from the input layer */ + typename ListSampleType::Pointer ReadInputListSample(otb::ogr::Layer const& layer); + + /** Normalize a list sample using the statistic file given */ + typename ListSampleType::Pointer NormalizeListSample(ListSampleType::Pointer input); + + /** Create the output DataSource, in update mode the input layer is buffered and the input + * data source is re opened in update mode. */ + otb::ogr::DataSource::Pointer CreateOutputDataSource(otb::ogr::DataSource::Pointer source, + otb::ogr::Layer & layer, + bool updateMode); + + /** Add a prediction field in the output layer if it does not exist. + * If computeConfidenceMap evaluates to true a confidence field will be + * added. */ + void AddPredictionField(otb::ogr::Layer & outLayer, + otb::ogr::Layer const& layer, + bool computeConfidenceMap); + + /** Fill the output layer with the predicted values and optionnaly the confidence */ + void FillOutputLayer(otb::ogr::Layer & outLayer, + otb::ogr::Layer const& layer, + typename LabelListSampleType::Pointer target, + typename ConfidenceListSampleType::Pointer quality, + bool updateMode, + bool computeConfidenceMap); + ModelPointerType m_Model; + + /** Name used for the confidence field */ + std::string confFieldName = "confidence"; }; } } diff --git a/Modules/Applications/AppClassification/include/otbVectorPrediction.hxx b/Modules/Applications/AppClassification/include/otbVectorPrediction.hxx index 20d8d357aa814130afe66117d4dadd58d7a7137d..54b8d889d1eb1747c1edc547a2d404b1f1157705 100644 --- a/Modules/Applications/AppClassification/include/otbVectorPrediction.hxx +++ b/Modules/Applications/AppClassification/include/otbVectorPrediction.hxx @@ -73,13 +73,10 @@ void VectorPrediction<RegressionMode>::DoUpdateParameters() } template <bool RegressionMode> -void VectorPrediction<RegressionMode>::DoExecute() +typename VectorPrediction<RegressionMode>::ListSampleType::Pointer +VectorPrediction<RegressionMode> +::ReadInputListSample(otb::ogr::Layer const& layer) { - auto shapefileName = GetParameterString("in"); - - auto source = otb::ogr::DataSource::New(shapefileName, otb::ogr::DataSource::Modes::Read); - auto layer = source->GetLayer(0); - typename ListSampleType::Pointer input = ListSampleType::New(); const int nbFeatures = GetSelectedItems("feat").size(); @@ -110,7 +107,16 @@ void VectorPrediction<RegressionMode>::DoExecute() } input->PushBack(mv); } + return input; +} +template <bool RegressionMode> +typename VectorPrediction<RegressionMode>::ListSampleType::Pointer +VectorPrediction<RegressionMode> +::NormalizeListSample(ListSampleType::Pointer input) +{ + const int nbFeatures = GetSelectedItems("feat").size(); + // Statistics for shift/scale MeasurementType meanMeasurementVector; MeasurementType stddevMeasurementVector; @@ -139,39 +145,31 @@ void VectorPrediction<RegressionMode>::DoExecute() otbAppLogINFO("standard deviation used: " << stddevMeasurementVector); otbAppLogINFO("Loading model"); - m_Model = MachineLearningModelFactoryType::CreateMachineLearningModel(GetParameterString("model"), MachineLearningModelFactoryType::ReadMode); - - if (m_Model.IsNull()) - { - otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type"); - } - - m_Model->SetRegressionMode(RegressionMode); - - m_Model->Load(GetParameterString("model")); - otbAppLogINFO("Model loaded"); - - ListSampleType::Pointer listSample = trainingShiftScaleFilter->GetOutput(); - - typename ConfidenceListSampleType::Pointer quality; + + return trainingShiftScaleFilter->GetOutput(); +} - bool computeConfidenceMap = shouldComputeConfidenceMap(); - typename LabelListSampleType::Pointer target; - if (computeConfidenceMap) +template <bool RegressionMode> +otb::ogr::DataSource::Pointer +VectorPrediction<RegressionMode> +::CreateOutputDataSource(otb::ogr::DataSource::Pointer source, otb::ogr::Layer & layer, bool updateMode) +{ + ogr::DataSource::Pointer output; + ogr::DataSource::Pointer buffer = ogr::DataSource::New(); + if (updateMode) { - quality = ConfidenceListSampleType::New(); - target = m_Model->PredictBatch(listSample, quality); + // Update mode + otbAppLogINFO("Update input vector data."); + // fill temporary buffer for the transfer + otb::ogr::Layer inputLayer = layer; + layer = buffer->CopyLayer(inputLayer, std::string("Buffer")); + // close input data source + source->Clear(); + // Re-open input data source in update mode + output = otb::ogr::DataSource::New(GetParameterString("in"), otb::ogr::DataSource::Modes::Update_LayerUpdate); } else - { - target = m_Model->PredictBatch(listSample); - } - - ogr::DataSource::Pointer output; - ogr::DataSource::Pointer buffer = ogr::DataSource::New(); - bool updateMode = false; - if (IsParameterEnabled("out") && HasValue("out")) { // Create new OGRDataSource output = ogr::DataSource::New(GetParameterString("out"), ogr::DataSource::Modes::Overwrite); @@ -184,32 +182,18 @@ void VectorPrediction<RegressionMode>::DoExecute() newLayer.CreateField(fieldDefn); } } - else - { - // Update mode - updateMode = true; - otbAppLogINFO("Update input vector data."); - // fill temporary buffer for the transfer - otb::ogr::Layer inputLayer = layer; - layer = buffer->CopyLayer(inputLayer, std::string("Buffer")); - // close input data source - source->Clear(); - // Re-open input data source in update mode - output = otb::ogr::DataSource::New(shapefileName, otb::ogr::DataSource::Modes::Update_LayerUpdate); - } - otb::ogr::Layer outLayer = output->GetLayer(0); + return output; +} - OGRErr errStart = outLayer.ogr().StartTransaction(); - if (errStart != OGRERR_NONE) - { - itkExceptionMacro(<< "Unable to start transaction for OGR layer " << outLayer.ogr().GetName() << "."); - } +template <bool RegressionMode> +void +VectorPrediction<RegressionMode> +::AddPredictionField(otb::ogr::Layer & outLayer, otb::ogr::Layer const& layer, bool computeConfidenceMap) +{ OGRFeatureDefn& layerDefn = layer.GetLayerDefn(); - // Add the field of prediction in the output layer if field not exist - const OGRFieldType labelType = RegressionMode ? OFTReal : OFTInteger; int idx = layerDefn.GetFieldIndex(GetParameterString("cfield").c_str()); @@ -226,7 +210,6 @@ void VectorPrediction<RegressionMode>::DoExecute() } // Add confidence field in the output layer - std::string confFieldName("confidence"); if (computeConfidenceMap) { idx = layerDefn.GetFieldIndex(confFieldName.c_str()); @@ -244,8 +227,14 @@ void VectorPrediction<RegressionMode>::DoExecute() outLayer.CreateField(confFieldDefn); } } +} - // Fill output layer +template <bool RegressionMode> +void +VectorPrediction<RegressionMode> +::FillOutputLayer(otb::ogr::Layer & outLayer, otb::ogr::Layer const& layer, typename LabelListSampleType::Pointer target, + typename ConfidenceListSampleType::Pointer quality, bool updateMode, bool computeConfidenceMap) +{ unsigned int count = 0; std::string classfieldname = GetParameterString("cfield"); for (auto const& feature : layer) @@ -281,6 +270,62 @@ void VectorPrediction<RegressionMode>::DoExecute() } count++; } +} + +template <bool RegressionMode> +void VectorPrediction<RegressionMode>::DoExecute() +{ + m_Model = MachineLearningModelFactoryType::CreateMachineLearningModel(GetParameterString("model"), MachineLearningModelFactoryType::ReadMode); + + if (m_Model.IsNull()) + { + otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type"); + } + + m_Model->SetRegressionMode(RegressionMode); + + m_Model->Load(GetParameterString("model")); + otbAppLogINFO("Model loaded"); + + auto shapefileName = GetParameterString("in"); + + auto source = otb::ogr::DataSource::New(shapefileName, otb::ogr::DataSource::Modes::Read); + auto layer = source->GetLayer(0); + + auto input = ReadInputListSample(layer); + + ListSampleType::Pointer listSample = NormalizeListSample(input); + + typename LabelListSampleType::Pointer target; + + // The quality listSample containing confidence values is defined here, but is only used when + // computeConfidenceMap evaluates to true. This listSample is also used in FillOutputLayer(...) + const bool computeConfidenceMap = shouldComputeConfidenceMap(); + typename ConfidenceListSampleType::Pointer quality; + + if (computeConfidenceMap) + { + quality = ConfidenceListSampleType::New(); + target = m_Model->PredictBatch(listSample, quality); + } + else + { + target = m_Model->PredictBatch(listSample); + } + + const bool updateMode = !(IsParameterEnabled("out") && HasValue("out")); + + auto output = CreateOutputDataSource(source, layer, updateMode); + otb::ogr::Layer outLayer = output->GetLayer(0); + + OGRErr errStart = outLayer.ogr().StartTransaction(); + if (errStart != OGRERR_NONE) + { + itkExceptionMacro(<< "Unable to start transaction for OGR layer " << outLayer.ogr().GetName() << "."); + } + + AddPredictionField(outLayer, layer, computeConfidenceMap); + FillOutputLayer(outLayer, layer, target, quality, updateMode, computeConfidenceMap); if (outLayer.ogr().TestCapability("Transactions")) {