Skip to content
Snippets Groups Projects
Commit 32b26f4c authored by Guillaume Pasero's avatar Guillaume Pasero
Browse files

Merge branch '1989-refac-readinputlistsample' into 'develop'

Precompute field's index in VectorClassifier

See merge request orfeotoolbox/otb!645
parents def9f46a 603e3fc1
Branches
Tags
No related merge requests found
...@@ -247,6 +247,11 @@ public: ...@@ -247,6 +247,11 @@ public:
* \sa \c OGRFeature::GetFieldDefnRef() * \sa \c OGRFeature::GetFieldDefnRef()
*/ */
FieldDefn GetFieldDefn(std::string const& name) const; FieldDefn GetFieldDefn(std::string const& name) const;
/** Searches the index of a field given a name.
* \invariant <tt>m_Feature != 0</tt>
* \throw itk::ExceptionObject if no field named \c name exists.
*/
int GetFieldIndex(std::string const& name) const;
//@} //@}
/**\name Geometries /**\name Geometries
...@@ -351,12 +356,6 @@ private: ...@@ -351,12 +356,6 @@ private:
void UncheckedSetGeometry(OGRGeometry const* geometry); void UncheckedSetGeometry(OGRGeometry const* geometry);
//@} //@}
/** Searches the index of a field given a name.
* \invariant <tt>m_Feature != 0</tt>
* \throw itk::ExceptionObject if no field named \c name exists.
*/
int GetFieldIndex(std::string const& name) const;
/** /**
* Checks whether the internal \c OGRFeature is non null. * Checks whether the internal \c OGRFeature is non null.
* Fires an assertion otherwise. * Fires an assertion otherwise.
......
...@@ -72,24 +72,36 @@ void VectorPrediction<RegressionMode>::DoUpdateParameters() ...@@ -72,24 +72,36 @@ void VectorPrediction<RegressionMode>::DoUpdateParameters()
} }
} }
template <bool RegressionMode> template <bool RegressionMode>
typename VectorPrediction<RegressionMode>::ListSampleType::Pointer VectorPrediction<RegressionMode>::ReadInputListSample(otb::ogr::Layer const& layer) typename VectorPrediction<RegressionMode>::ListSampleType::Pointer
VectorPrediction<RegressionMode>
::ReadInputListSample(otb::ogr::Layer const& layer)
{ {
typename ListSampleType::Pointer input = ListSampleType::New(); typename ListSampleType::Pointer input = ListSampleType::New();
const int nbFeatures = GetSelectedItems("feat").size(); const auto nbFeatures = GetSelectedItems("feat").size();
input->SetMeasurementVectorSize(nbFeatures); input->SetMeasurementVectorSize(nbFeatures);
std::vector<int> featureFieldIndex(nbFeatures, -1);
ogr::Layer::const_iterator it_feat = layer.cbegin();
for (unsigned int i = 0; i < nbFeatures; i++)
{
try
{
featureFieldIndex[i] = (*it_feat).GetFieldIndex(GetChoiceNames("feat")[GetSelectedItems("feat")[i]]);
}
catch(...)
{
otbAppLogFATAL("The field name for feature " << GetChoiceNames("feat")[GetSelectedItems("feat")[i]] << " has not been found" << std::endl);
}
}
for (auto const& feature : layer) for (auto const& feature : layer)
{ {
MeasurementType mv(nbFeatures); MeasurementType mv(nbFeatures);
for (int idx = 0; idx < nbFeatures; ++idx) for (unsigned int idx = 0; idx < nbFeatures; ++idx)
{ {
// Beware that itemIndex differs from ogr layer field index auto field = feature[featureFieldIndex[idx]];
unsigned int itemIndex = GetSelectedItems("feat")[idx];
std::string fieldName = GetChoiceNames("feat")[itemIndex];
auto field = feature[fieldName];
switch (field.GetType()) switch (field.GetType())
{ {
case OFTInteger: case OFTInteger:
...@@ -108,6 +120,7 @@ typename VectorPrediction<RegressionMode>::ListSampleType::Pointer VectorPredict ...@@ -108,6 +120,7 @@ typename VectorPrediction<RegressionMode>::ListSampleType::Pointer VectorPredict
return input; return input;
} }
template <bool RegressionMode> template <bool RegressionMode>
typename VectorPrediction<RegressionMode>::ListSampleType::Pointer VectorPrediction<RegressionMode>::NormalizeListSample(ListSampleType::Pointer input) typename VectorPrediction<RegressionMode>::ListSampleType::Pointer VectorPrediction<RegressionMode>::NormalizeListSample(ListSampleType::Pointer input)
{ {
...@@ -278,13 +291,11 @@ void VectorPrediction<RegressionMode>::DoExecute() ...@@ -278,13 +291,11 @@ void VectorPrediction<RegressionMode>::DoExecute()
auto shapefileName = GetParameterString("in"); auto shapefileName = GetParameterString("in");
auto source = otb::ogr::DataSource::New(shapefileName, otb::ogr::DataSource::Modes::Read); ogr::DataSource::Pointer source = ogr::DataSource::New(shapefileName, ogr::DataSource::Modes::Read);
auto layer = source->GetLayer(0); auto layer = source->GetLayer(0);
auto input = ReadInputListSample(layer); auto input = ReadInputListSample(layer);
ListSampleType::Pointer listSample = NormalizeListSample(input); ListSampleType::Pointer listSample = NormalizeListSample(input);
typename LabelListSampleType::Pointer target; typename LabelListSampleType::Pointer target;
// The quality listSample containing confidence values is defined here, but is only used when // The quality listSample containing confidence values is defined here, but is only used when
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment