Skip to content
Snippets Groups Projects
Commit 26405bbf authored by Cédric Traizet's avatar Cédric Traizet
Browse files

ENH: split the DoExecute() method in several parts

parent f5c47b73
No related branches found
No related tags found
No related merge requests found
......@@ -98,7 +98,37 @@ private:
/** Method returning whether the confidence map should be computed, depending on the regression mode and input parameters */
bool shouldComputeConfidenceMap() const;
/** Method returning the input list sample from the input layer */
typename ListSampleType::Pointer ReadInputListSample(otb::ogr::Layer const& layer);
/** Normalize a list sample using the statistic file given */
typename ListSampleType::Pointer NormalizeListSample(ListSampleType::Pointer input);
/** Create the output DataSource, in update mode the input layer is buffered and the input
* data source is re opened in update mode. */
otb::ogr::DataSource::Pointer CreateOutputDataSource(otb::ogr::DataSource::Pointer source,
otb::ogr::Layer & layer,
bool updateMode);
/** Add a prediction field in the output layer if it does not exist.
* If computeConfidenceMap evaluates to true a confidence field will be
* added. */
void AddPredictionField(otb::ogr::Layer & outLayer,
otb::ogr::Layer const& layer,
bool computeConfidenceMap);
/** Fill the output layer with the predicted values and optionnaly the confidence */
void FillOutputLayer(otb::ogr::Layer & outLayer,
otb::ogr::Layer const& layer,
typename LabelListSampleType::Pointer target,
typename ConfidenceListSampleType::Pointer quality,
bool updateMode,
bool computeConfidenceMap);
ModelPointerType m_Model;
/** Name used for the confidence field */
std::string confFieldName = "confidence";
};
}
}
......
......@@ -73,13 +73,10 @@ void VectorPrediction<RegressionMode>::DoUpdateParameters()
}
template <bool RegressionMode>
void VectorPrediction<RegressionMode>::DoExecute()
typename VectorPrediction<RegressionMode>::ListSampleType::Pointer
VectorPrediction<RegressionMode>
::ReadInputListSample(otb::ogr::Layer const& layer)
{
auto shapefileName = GetParameterString("in");
auto source = otb::ogr::DataSource::New(shapefileName, otb::ogr::DataSource::Modes::Read);
auto layer = source->GetLayer(0);
typename ListSampleType::Pointer input = ListSampleType::New();
const int nbFeatures = GetSelectedItems("feat").size();
......@@ -110,7 +107,16 @@ void VectorPrediction<RegressionMode>::DoExecute()
}
input->PushBack(mv);
}
return input;
}
template <bool RegressionMode>
typename VectorPrediction<RegressionMode>::ListSampleType::Pointer
VectorPrediction<RegressionMode>
::NormalizeListSample(ListSampleType::Pointer input)
{
const int nbFeatures = GetSelectedItems("feat").size();
// Statistics for shift/scale
MeasurementType meanMeasurementVector;
MeasurementType stddevMeasurementVector;
......@@ -139,39 +145,31 @@ void VectorPrediction<RegressionMode>::DoExecute()
otbAppLogINFO("standard deviation used: " << stddevMeasurementVector);
otbAppLogINFO("Loading model");
m_Model = MachineLearningModelFactoryType::CreateMachineLearningModel(GetParameterString("model"), MachineLearningModelFactoryType::ReadMode);
if (m_Model.IsNull())
{
otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type");
}
m_Model->SetRegressionMode(RegressionMode);
m_Model->Load(GetParameterString("model"));
otbAppLogINFO("Model loaded");
ListSampleType::Pointer listSample = trainingShiftScaleFilter->GetOutput();
typename ConfidenceListSampleType::Pointer quality;
return trainingShiftScaleFilter->GetOutput();
}
bool computeConfidenceMap = shouldComputeConfidenceMap();
typename LabelListSampleType::Pointer target;
if (computeConfidenceMap)
template <bool RegressionMode>
otb::ogr::DataSource::Pointer
VectorPrediction<RegressionMode>
::CreateOutputDataSource(otb::ogr::DataSource::Pointer source, otb::ogr::Layer & layer, bool updateMode)
{
ogr::DataSource::Pointer output;
ogr::DataSource::Pointer buffer = ogr::DataSource::New();
if (updateMode)
{
quality = ConfidenceListSampleType::New();
target = m_Model->PredictBatch(listSample, quality);
// Update mode
otbAppLogINFO("Update input vector data.");
// fill temporary buffer for the transfer
otb::ogr::Layer inputLayer = layer;
layer = buffer->CopyLayer(inputLayer, std::string("Buffer"));
// close input data source
source->Clear();
// Re-open input data source in update mode
output = otb::ogr::DataSource::New(GetParameterString("in"), otb::ogr::DataSource::Modes::Update_LayerUpdate);
}
else
{
target = m_Model->PredictBatch(listSample);
}
ogr::DataSource::Pointer output;
ogr::DataSource::Pointer buffer = ogr::DataSource::New();
bool updateMode = false;
if (IsParameterEnabled("out") && HasValue("out"))
{
// Create new OGRDataSource
output = ogr::DataSource::New(GetParameterString("out"), ogr::DataSource::Modes::Overwrite);
......@@ -184,32 +182,18 @@ void VectorPrediction<RegressionMode>::DoExecute()
newLayer.CreateField(fieldDefn);
}
}
else
{
// Update mode
updateMode = true;
otbAppLogINFO("Update input vector data.");
// fill temporary buffer for the transfer
otb::ogr::Layer inputLayer = layer;
layer = buffer->CopyLayer(inputLayer, std::string("Buffer"));
// close input data source
source->Clear();
// Re-open input data source in update mode
output = otb::ogr::DataSource::New(shapefileName, otb::ogr::DataSource::Modes::Update_LayerUpdate);
}
otb::ogr::Layer outLayer = output->GetLayer(0);
return output;
}
OGRErr errStart = outLayer.ogr().StartTransaction();
if (errStart != OGRERR_NONE)
{
itkExceptionMacro(<< "Unable to start transaction for OGR layer " << outLayer.ogr().GetName() << ".");
}
template <bool RegressionMode>
void
VectorPrediction<RegressionMode>
::AddPredictionField(otb::ogr::Layer & outLayer, otb::ogr::Layer const& layer, bool computeConfidenceMap)
{
OGRFeatureDefn& layerDefn = layer.GetLayerDefn();
// Add the field of prediction in the output layer if field not exist
const OGRFieldType labelType = RegressionMode ? OFTReal : OFTInteger;
int idx = layerDefn.GetFieldIndex(GetParameterString("cfield").c_str());
......@@ -226,7 +210,6 @@ void VectorPrediction<RegressionMode>::DoExecute()
}
// Add confidence field in the output layer
std::string confFieldName("confidence");
if (computeConfidenceMap)
{
idx = layerDefn.GetFieldIndex(confFieldName.c_str());
......@@ -244,8 +227,14 @@ void VectorPrediction<RegressionMode>::DoExecute()
outLayer.CreateField(confFieldDefn);
}
}
}
// Fill output layer
template <bool RegressionMode>
void
VectorPrediction<RegressionMode>
::FillOutputLayer(otb::ogr::Layer & outLayer, otb::ogr::Layer const& layer, typename LabelListSampleType::Pointer target,
typename ConfidenceListSampleType::Pointer quality, bool updateMode, bool computeConfidenceMap)
{
unsigned int count = 0;
std::string classfieldname = GetParameterString("cfield");
for (auto const& feature : layer)
......@@ -281,6 +270,62 @@ void VectorPrediction<RegressionMode>::DoExecute()
}
count++;
}
}
template <bool RegressionMode>
void VectorPrediction<RegressionMode>::DoExecute()
{
m_Model = MachineLearningModelFactoryType::CreateMachineLearningModel(GetParameterString("model"), MachineLearningModelFactoryType::ReadMode);
if (m_Model.IsNull())
{
otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type");
}
m_Model->SetRegressionMode(RegressionMode);
m_Model->Load(GetParameterString("model"));
otbAppLogINFO("Model loaded");
auto shapefileName = GetParameterString("in");
auto source = otb::ogr::DataSource::New(shapefileName, otb::ogr::DataSource::Modes::Read);
auto layer = source->GetLayer(0);
auto input = ReadInputListSample(layer);
ListSampleType::Pointer listSample = NormalizeListSample(input);
typename LabelListSampleType::Pointer target;
// The quality listSample containing confidence values is defined here, but is only used when
// computeConfidenceMap evaluates to true. This listSample is also used in FillOutputLayer(...)
const bool computeConfidenceMap = shouldComputeConfidenceMap();
typename ConfidenceListSampleType::Pointer quality;
if (computeConfidenceMap)
{
quality = ConfidenceListSampleType::New();
target = m_Model->PredictBatch(listSample, quality);
}
else
{
target = m_Model->PredictBatch(listSample);
}
const bool updateMode = !(IsParameterEnabled("out") && HasValue("out"));
auto output = CreateOutputDataSource(source, layer, updateMode);
otb::ogr::Layer outLayer = output->GetLayer(0);
OGRErr errStart = outLayer.ogr().StartTransaction();
if (errStart != OGRERR_NONE)
{
itkExceptionMacro(<< "Unable to start transaction for OGR layer " << outLayer.ogr().GetName() << ".");
}
AddPredictionField(outLayer, layer, computeConfidenceMap);
FillOutputLayer(outLayer, layer, target, quality, updateMode, computeConfidenceMap);
if (outLayer.ogr().TestCapability("Transactions"))
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment