Skip to content
Snippets Groups Projects
Commit 4cef0569 authored by Cédric Traizet's avatar Cédric Traizet
Browse files

ENH: code review

parent 50bca421
No related branches found
No related tags found
No related merge requests found
......@@ -93,7 +93,7 @@ void VectorRegression::DoInitSpecialization()
SetOfficialDocLink();
}
// Confidence map computation is not support for regression.
// Confidence map computation is not supported for regression.
template <>
bool VectorRegression::shouldComputeConfidenceMap() const
{
......
......@@ -49,36 +49,36 @@ class VectorPrediction : public Application
{
public:
/** Standard class typedefs. */
typedef VectorPrediction Self;
typedef Application Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
using Self = VectorPrediction;
using Superclass = Application;
using Pointer = itk::SmartPointer<Self>;
using ConstPointer = itk::SmartPointer<const Self>;
/** Standard macro */
itkNewMacro(Self);
itkTypeMacro(Self, Application)
/** Filters typedef */
typedef float ValueType;
/** Filters typedef */
using ValueType = float;
// Label type is float for regression and unsigned int for classification
typedef typename std::conditional<RegressionMode, float, unsigned int>::type LabelType;
using LabelType = typename std::conditional<RegressionMode, float, unsigned int>::type;
typedef itk::FixedArray<LabelType, 1> LabelSampleType;
typedef itk::Statistics::ListSample<LabelSampleType> LabelListSampleType;
using LabelSampleType = itk::FixedArray<LabelType, 1>;
using LabelListSampleType = itk::Statistics::ListSample<LabelSampleType>;
typedef otb::MachineLearningModel<ValueType, LabelType> MachineLearningModelType;
typedef otb::MachineLearningModelFactory<ValueType, LabelType> MachineLearningModelFactoryType;
typedef typename MachineLearningModelType::Pointer ModelPointerType;
typedef typename MachineLearningModelType::ConfidenceListSampleType ConfidenceListSampleType;
using MachineLearningModelType = otb::MachineLearningModel<ValueType, LabelType>;
using MachineLearningModelFactoryType = otb::MachineLearningModelFactory<ValueType, LabelType>;
using ModelPointerType = typename MachineLearningModelType::Pointer;
using ConfidenceListSampleType = typename MachineLearningModelType::ConfidenceListSampleType;
/** Statistics Filters typedef */
typedef itk::VariableLengthVector<ValueType> MeasurementType;
typedef otb::StatisticsXMLFileReader<MeasurementType> StatisticsReader;
using MeasurementType = itk::VariableLengthVector<ValueType>;
using StatisticsReader = otb::StatisticsXMLFileReader<MeasurementType>;
typedef itk::VariableLengthVector<ValueType> InputSampleType;
typedef itk::Statistics::ListSample<InputSampleType> ListSampleType;
typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType;
using InputSampleType = itk::VariableLengthVector<ValueType>;
using ListSampleType = itk::Statistics::ListSample<InputSampleType>;
using ShiftScaleFilterType = otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType>;
~VectorPrediction() override
{
......
......@@ -47,24 +47,23 @@ void VectorPrediction<RegressionMode>::DoUpdateParameters()
{
if (HasValue("in"))
{
std::string shapefile = GetParameterString("in");
auto shapefileName = GetParameterString("in");
otb::ogr::DataSource::Pointer ogrDS;
ogrDS = otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Read);
otb::ogr::Layer layer = ogrDS->GetLayer(0);
auto ogrDS = otb::ogr::DataSource::New(shapefileName, otb::ogr::DataSource::Modes::Read);
auto layer = ogrDS->GetLayer(0);
OGRFeatureDefn& layerDefn = layer.GetLayerDefn();
ClearChoices("feat");
for (int iField = 0; iField < layerDefn.GetFieldCount(); iField++)
{
std::string item = layerDefn.GetFieldDefn(iField)->GetNameRef();
auto fieldDefn = layerDefn.GetFieldDefn(iField);
std::string item = fieldDefn->GetNameRef();
std::string key(item);
key.erase(std::remove_if(key.begin(), key.end(), IsNotAlphaNum), key.end());
std::transform(key.begin(), key.end(), key.begin(), tolower);
OGRFieldType fieldType = layerDefn.GetFieldDefn(iField)->GetType();
auto fieldType = fieldDefn->GetType();
if (fieldType == OFTInteger || fieldType == OFTInteger64 || fieldType == OFTReal)
{
std::string tmpKey = "feat." + key;
......@@ -77,21 +76,17 @@ void VectorPrediction<RegressionMode>::DoUpdateParameters()
template <bool RegressionMode>
void VectorPrediction<RegressionMode>::DoExecute()
{
clock_t tic = clock();
auto shapefileName = GetParameterString("in");
std::string shapefile = GetParameterString("in");
otb::ogr::DataSource::Pointer source = otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Read);
otb::ogr::Layer layer = source->GetLayer(0);
auto source = otb::ogr::DataSource::New(shapefileName, otb::ogr::DataSource::Modes::Read);
auto layer = source->GetLayer(0);
typename ListSampleType::Pointer input = ListSampleType::New();
const int nbFeatures = GetSelectedItems("feat").size();
input->SetMeasurementVectorSize(nbFeatures);
otb::ogr::Layer::const_iterator it = layer.cbegin();
otb::ogr::Layer::const_iterator itEnd = layer.cend();
for (; it != itEnd; ++it)
for (auto const& feature : layer)
{
MeasurementType mv;
mv.SetSize(nbFeatures);
......@@ -100,19 +95,18 @@ void VectorPrediction<RegressionMode>::DoExecute()
// Beware that itemIndex differs from ogr layer field index
unsigned int itemIndex = GetSelectedItems("feat")[idx];
std::string fieldName = GetChoiceNames("feat")[itemIndex];
switch ((*it)[fieldName].GetType())
auto field = feature[fieldName];
switch (field.GetType())
{
case OFTInteger:
mv[idx] = static_cast<ValueType>((*it)[fieldName].GetValue<int>());
break;
case OFTInteger64:
mv[idx] = static_cast<ValueType>((*it)[fieldName].GetValue<int>());
case OFTInteger || OFTInteger64:
mv[idx] = static_cast<ValueType>(field.template GetValue<int>());
break;
case OFTReal:
mv[idx] = static_cast<ValueType>((*it)[fieldName].GetValue<double>());
mv[idx] = static_cast<ValueType>(field.template GetValue<double>());
break;
default:
itkExceptionMacro(<< "incorrect field type: " << (*it)[fieldName].GetType() << ".");
itkExceptionMacro(<< "incorrect field type: " << field.GetType() << ".");
}
}
input->PushBack(mv);
......@@ -202,7 +196,7 @@ void VectorPrediction<RegressionMode>::DoExecute()
// close input data source
source->Clear();
// Re-open input data source in update mode
output = otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Update_LayerUpdate);
output = otb::ogr::DataSource::New(shapefileName, otb::ogr::DataSource::Modes::Update_LayerUpdate);
}
otb::ogr::Layer outLayer = output->GetLayer(0);
......@@ -217,11 +211,7 @@ void VectorPrediction<RegressionMode>::DoExecute()
// Add the field of prediction in the output layer if field not exist
OGRFieldType labelType;
if (RegressionMode == true)
labelType = OFTReal;
else
labelType = OFTInteger;
const OGRFieldType labelType = RegressionMode ? OFTReal : OFTInteger;
int idx = layerDefn.GetFieldIndex(GetParameterString("cfield").c_str());
if (idx >= 0)
......@@ -259,29 +249,25 @@ void VectorPrediction<RegressionMode>::DoExecute()
// Fill output layer
unsigned int count = 0;
std::string classfieldname = GetParameterString("cfield");
it = layer.cbegin();
itEnd = layer.cend();
for (; it != itEnd; ++it, ++count)
for (auto const& feature : layer)
{
ogr::Feature dstFeature(outLayer.GetLayerDefn());
dstFeature.SetFrom(*it, TRUE);
dstFeature.SetFID(it->GetFID());
switch (dstFeature[classfieldname].GetType())
dstFeature.SetFrom(feature, TRUE);
dstFeature.SetFID(feature.GetFID());
auto field = dstFeature[classfieldname];
switch (field.GetType())
{
case OFTInteger:
dstFeature[classfieldname].SetValue<int>(target->GetMeasurementVector(count)[0]);
break;
case OFTInteger64:
dstFeature[classfieldname].SetValue<int>(target->GetMeasurementVector(count)[0]);
case OFTInteger || OFTInteger64:
field.template SetValue<int>(target->GetMeasurementVector(count)[0]);
break;
case OFTReal:
dstFeature[classfieldname].SetValue<double>(target->GetMeasurementVector(count)[0]);
field.template SetValue<double>(target->GetMeasurementVector(count)[0]);
break;
case OFTString:
dstFeature[classfieldname].SetValue<std::string>(std::to_string(target->GetMeasurementVector(count)[0]));
field.template SetValue<std::string>(std::to_string(target->GetMeasurementVector(count)[0]));
break;
default:
itkExceptionMacro(<< "incorrect field type: " << dstFeature[classfieldname].GetType() << ".");
itkExceptionMacro(<< "incorrect field type: " << field.GetType() << ".");
}
if (computeConfidenceMap)
dstFeature[confFieldName].SetValue<double>(quality->GetMeasurementVector(count)[0]);
......@@ -305,9 +291,6 @@ void VectorPrediction<RegressionMode>::DoExecute()
}
output->SyncToDisk();
clock_t toc = clock();
otbAppLogINFO("Elapsed: " << ((double)(toc - tic) / CLOCKS_PER_SEC) << " seconds.");
}
} // end namespace wrapper
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment