Commit ec1e76ee authored by Cédric Traizet's avatar Cédric Traizet

Merge branch 'field_get_value_assert' into 'develop'

Type assertion in OGRFieldWrapper GetValue()

See merge request !515
parents 5f04ba3c 24ea0818
Pipeline #2121 passed with stages
in 77 minutes and 13 seconds
......@@ -215,8 +215,22 @@ private:
// Beware that itemIndex differs from ogr layer field index
unsigned int itemIndex = GetSelectedItems("feat")[idx];
std::string fieldName = GetChoiceNames( "feat" )[itemIndex];
switch ((*it)[fieldName].GetType())
{
case OFTInteger:
mv[idx] = static_cast<ValueType>((*it)[fieldName].GetValue<int>());
break;
case OFTInteger64:
mv[idx] = static_cast<ValueType>((*it)[fieldName].GetValue<int>());
break;
case OFTReal:
mv[idx] = static_cast<ValueType>((*it)[fieldName].GetValue<double>());
break;
default:
itkExceptionMacro(<< "incorrect field type: " << (*it)[fieldName].GetType() << ".");
}
mv[idx] = static_cast<ValueType>((*it)[fieldName].GetValue<double>());
}
input->PushBack(mv);
}
......@@ -369,7 +383,23 @@ private:
ogr::Feature dstFeature(outLayer.GetLayerDefn());
dstFeature.SetFrom( *it , TRUE);
dstFeature.SetFID(it->GetFID());
dstFeature[classfieldname].SetValue<int>(target->GetMeasurementVector(count)[0]);
switch (dstFeature[classfieldname].GetType())
{
case OFTInteger:
dstFeature[classfieldname].SetValue<int>(target->GetMeasurementVector(count)[0]);
break;
case OFTInteger64:
dstFeature[classfieldname].SetValue<int>(target->GetMeasurementVector(count)[0]);
break;
case OFTReal:
dstFeature[classfieldname].SetValue<double>(target->GetMeasurementVector(count)[0]);
break;
case OFTString:
dstFeature[classfieldname].SetValue<std::string>(std::to_string(target->GetMeasurementVector(count)[0]));
break;
default:
itkExceptionMacro(<< "incorrect field type: " << dstFeature[classfieldname].GetType() << ".");
}
if (computeConfidenceMap)
dstFeature[confFieldName].SetValue<double>(quality->GetMeasurementVector(count)[0]);
if (updateMode)
......
......@@ -65,7 +65,7 @@ public:
typedef typename Superclass::SampleType SampleType;
typedef typename Superclass::ListSampleType ListSampleType;
typedef typename Superclass::TargetListSampleType TargetListSampleType;
typedef double ValueType;
typedef itk::VariableLengthVector <ValueType> MeasurementType;
......@@ -127,7 +127,10 @@ protected:
{
m_SelectedCFieldIdx = selectedCFieldIdx;
// Handle only one class field name, if several are provided only the first one is used.
m_SelectedCFieldName = selectedCFieldIdx.empty() ? cFieldNames.front() : cFieldNames[selectedCFieldIdx.front()];
if (selectedCFieldIdx.empty())
m_SelectedCFieldName.clear();
else
m_SelectedCFieldName = cFieldNames[selectedCFieldIdx.front()];
}
};
......@@ -185,12 +188,6 @@ protected:
void DoInit() override;
void DoUpdateParameters() override;
void DoExecute() override;
private:
/**
* Get the field of the input feature corresponding to the input field
*/
inline TOutputValue GetFeatureField(const ogr::Feature& feature, int field);
};
}
......
......@@ -237,23 +237,6 @@ TrainVectorBase<TInputValue, TOutputValue>
return measurement;
}
// Template specialization for the integer case (i.e.classification), to avoid a cast from double to integer
template <>
inline int
TrainVectorBase<float, int>
::GetFeatureField(const ogr::Feature & feature, int fieldIndex)
{
return(feature[fieldIndex].GetValue<int>());
}
template <class TInputValue, class TOutputValue>
inline TOutputValue
TrainVectorBase<TInputValue, TOutputValue>
::GetFeatureField(const ogr::Feature & feature, int fieldIndex)
{
return(feature[fieldIndex].GetValue<double>());
}
template <class TInputValue, class TOutputValue>
typename TrainVectorBase<TInputValue, TOutputValue>::SamplesWithLabel
TrainVectorBase<TInputValue, TOutputValue>
......@@ -310,12 +293,45 @@ TrainVectorBase<TInputValue, TOutputValue>
MeasurementType mv;
mv.SetSize( m_FeaturesInfo.m_NbFeatures );
for( unsigned int idx = 0; idx < m_FeaturesInfo.m_NbFeatures; ++idx )
mv[idx] = feature[featureFieldIndex[idx]].GetValue<double>();
{
switch (feature[featureFieldIndex[idx]].GetType())
{
case OFTInteger:
mv[idx] = static_cast<ValueType>(feature[featureFieldIndex[idx]].GetValue<int>());
break;
case OFTInteger64:
mv[idx] = static_cast<ValueType>(feature[featureFieldIndex[idx]].GetValue<int>());
break;
case OFTReal:
mv[idx] = static_cast<ValueType>(feature[featureFieldIndex[idx]].GetValue<double>());
break;
default:
itkExceptionMacro(<< "incorrect field type: " << feature[featureFieldIndex[idx]].GetType() << ".");
}
}
input->PushBack( mv );
if(cFieldIndex>=0 && ogr::Field(feature,cFieldIndex).HasBeenSet())
target->PushBack(GetFeatureField(feature,cFieldIndex));
{
switch (feature[cFieldIndex].GetType())
{
case OFTInteger:
target->PushBack(static_cast<ValueType>(feature[cFieldIndex].GetValue<int>()));
break;
case OFTInteger64:
target->PushBack(static_cast<ValueType>(feature[cFieldIndex].GetValue<int>()));
break;
case OFTReal:
target->PushBack(static_cast<ValueType>(feature[cFieldIndex].GetValue<double>()));
break;
case OFTString:
target->PushBack(static_cast<ValueType>(std::stod(feature[cFieldIndex].GetValue<std::string>())));
break;
default:
itkExceptionMacro(<< "incorrect field type: " << feature[featureFieldIndex[cFieldIndex]].GetType() << ".");
}
}
else
target->PushBack( 0. );
......
......@@ -129,7 +129,9 @@ private:
otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Read);
otb::ogr::Layer layer = source->GetLayer(0);
ListSampleType::Pointer input = ListSampleType::New();
const int nbFeatures = GetParameterStringList("feat").size();
const auto inputIndexes = GetParameterStringList("feat");
const int nbFeatures = inputIndexes.size();
input->SetMeasurementVectorSize(nbFeatures);
otb::ogr::Layer::const_iterator it = layer.cbegin();
......@@ -140,7 +142,20 @@ private:
mv.SetSize(nbFeatures);
for(int idx=0; idx < nbFeatures; ++idx)
{
mv[idx] = (*it)[GetParameterStringList("feat")[idx]].GetValue<double>();
switch ((*it)[inputIndexes[idx]].GetType())
{
case OFTInteger:
mv[idx] = static_cast<ValueType>((*it)[inputIndexes[idx]].GetValue<int>());
break;
case OFTInteger64:
mv[idx] = static_cast<ValueType>((*it)[inputIndexes[idx]].GetValue<int>());
break;
case OFTReal:
mv[idx] = static_cast<ValueType>((*it)[inputIndexes[idx]].GetValue<double>());
break;
default:
itkExceptionMacro(<< "incorrect field type: " << (*it)[inputIndexes[idx]].GetType() << ".");
}
}
input->PushBack(mv);
}
......
......@@ -221,7 +221,20 @@ private:
for(int idx=0; idx < nbFeatures; ++idx)
{
mv[idx] = static_cast<float>( (*it)[inputIndexes[idx]].GetValue<double>() );
switch ((*it)[inputIndexes[idx]].GetType())
{
case OFTInteger:
mv[idx] = static_cast<ValueType>((*it)[inputIndexes[idx]].GetValue<int>());
break;
case OFTInteger64:
mv[idx] = static_cast<ValueType>((*it)[inputIndexes[idx]].GetValue<int>());
break;
case OFTReal:
mv[idx] = static_cast<ValueType>((*it)[inputIndexes[idx]].GetValue<double>());
break;
default:
itkExceptionMacro(<< "incorrect field type: " << (*it)[inputIndexes[idx]].GetType() << ".");
}
}
input->PushBack(mv);
}
......@@ -399,7 +412,23 @@ private:
for (std::size_t i=0; i<outFields.size(); ++i)
{
dstFeature[outFields[i]].SetValue<double>(target->GetMeasurementVector(count)[i]);
switch (dstFeature[outFields[i]].GetType())
{
case OFTInteger:
dstFeature[outFields[i]].SetValue<int>(target->GetMeasurementVector(count)[0]);
break;
case OFTInteger64:
dstFeature[outFields[i]].SetValue<int>(target->GetMeasurementVector(count)[0]);
break;
case OFTReal:
dstFeature[outFields[i]].SetValue<double>(target->GetMeasurementVector(count)[0]);
break;
case OFTString:
dstFeature[outFields[i]].SetValue<std::string>(std::to_string(target->GetMeasurementVector(count)[0]));
break;
default:
itkExceptionMacro(<< "incorrect field type: " << dstFeature[outFields[i]].GetType() << ".");
}
}
if (updateMode)
{
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment