From d4a174c0b59ac06f076c79fa7799e6888a940362 Mon Sep 17 00:00:00 2001 From: Jordi Inglada <jordi.inglada@cesbio.cnes.fr> Date: Wed, 28 Feb 2018 18:29:59 +0100 Subject: [PATCH] ENH: implement sample augmentation as a filter --- .../app/otbSampleAugmentation.cxx | 181 ++---------- .../include/otbSampleAugmentation.h | 4 +- .../include/otbSampleAugmentationFilter.h | 168 +++++++++++ .../include/otbSampleAugmentationFilter.txx | 268 ++++++++++++++++++ 4 files changed, 463 insertions(+), 158 deletions(-) create mode 100644 Modules/Applications/AppClassification/include/otbSampleAugmentationFilter.h create mode 100644 Modules/Applications/AppClassification/include/otbSampleAugmentationFilter.txx diff --git a/Modules/Applications/AppClassification/app/otbSampleAugmentation.cxx b/Modules/Applications/AppClassification/app/otbSampleAugmentation.cxx index a12e3912b1..eed67c8fb8 100644 --- a/Modules/Applications/AppClassification/app/otbSampleAugmentation.cxx +++ b/Modules/Applications/AppClassification/app/otbSampleAugmentation.cxx @@ -21,7 +21,7 @@ #include "otbWrapperApplication.h" #include "otbWrapperApplicationFactory.h" #include "otbOGRDataSourceWrapper.h" -#include "otbSampleAugmentation.h" +#include "otbSampleAugmentationFilter.h" namespace otb { @@ -44,9 +44,9 @@ public: itkTypeMacro(SampleAugmentation, otb::Application); /** Filters typedef */ - using SampleType = sampleAugmentation::SampleType; - using SampleVectorType = sampleAugmentation::SampleVectorType; - + using FilterType = otb::SampleAugmentationFilter; + using SampleType = FilterType::SampleType; + using SampleVectorType = FilterType::SampleVectorType; private: SampleAugmentation() {} @@ -220,143 +220,49 @@ private: GetSelectedItems( "exclude" )); for(const auto& ef : excludedFeatures) otbAppLogINFO("Excluding feature " << ef << '\n'); - auto inSamples = extractSamples(vectors, this->GetParameterInt("layer"), - fieldName, - this->GetParameterInt("label"), - excludedFeatures); + int seed = std::time(nullptr); if(IsParameterEnabled("seed")) seed = this->GetParameterInt("seed"); - SampleVectorType newSamples; + + + FilterType::Pointer filter = FilterType::New(); + filter->SetInput(vectors); + filter->SetLayer(this->GetParameterInt("layer")); + filter->SetNumberOfSamples(this->GetParameterInt("samples")); + filter->SetOutputSamples(output); + filter->SetClassFieldName(fieldName); + filter->SetLabel(this->GetParameterInt("label")); + filter->SetExcludedFeatures(excludedFeatures); + filter->SetSeed(seed); switch (this->GetParameterInt("strategy")) { // replicate case 0: { otbAppLogINFO("Augmentation strategy : replicate"); - sampleAugmentation::replicateSamples(inSamples, this->GetParameterInt("samples"), - newSamples); + filter->SetStrategy(FilterType::Strategy::Replicate); } - break; + break; // jitter case 1: { otbAppLogINFO("Augmentation strategy : jitter"); - sampleAugmentation::jitterSamples(inSamples, this->GetParameterInt("samples"), - newSamples, - this->GetParameterFloat("strategy.jitter.stdfactor"), - seed); + filter->SetStrategy(FilterType::Strategy::Jitter); + filter->SetStdFactor(this->GetParameterFloat("stdfactor")); } break; case 2: { otbAppLogINFO("Augmentation strategy : smote"); - sampleAugmentation::smote(inSamples, this->GetParameterInt("samples"), - newSamples, - this->GetParameterInt("strategy.smote.neighbors"), - seed); + filter->SetStrategy(FilterType::Strategy::Smote); + filter->SetSmoteNeighbors(this->GetParameterInt("neighbors")); } break; } - writeSamples(vectors, output, newSamples, this->GetParameterInt("layer"), - fieldName, - this->GetParameterInt("label"), - excludedFeatures); + filter->Update(); output->SyncToDisk(); } -/** Extracts the samples of a single class from the vector data to a -* vector and excludes some unwanted features. -*/ - SampleVectorType extractSamples(const ogr::DataSource::Pointer vectors, - size_t layerName, - const std::string& classField, const int label, - const std::vector<std::string>& excludedFeatures = {}) - { - ogr::Layer layer = vectors->GetLayer(layerName); - ogr::Feature feature = layer.ogr().GetNextFeature(); - if(feature.addr() == 0) - { - otbAppLogFATAL("Layer " << layerName << " of input sample file is empty.\n"); - } - int cFieldIndex = feature.ogr().GetFieldIndex( classField.c_str() ); - if( cFieldIndex < 0 ) - { - otbAppLogFATAL( "The field name for class label (" << classField - << ") has not been found in the vector file " ); - } - - auto numberOfFields = feature.ogr().GetFieldCount(); - auto excludedIds = getExcludedFeaturesIds(excludedFeatures, layer); - otbAppLogINFO("The vector file contains " << numberOfFields << " fields.\n"); - SampleVectorType samples; - bool goesOn{feature.addr() != 0}; - while( goesOn ) - { - // Retrieve all the features for each field in the ogr layer. - if(feature.ogr().GetFieldAsInteger(classField.c_str()) == label) - { - - SampleType mv; - for(auto idx=0; idx<numberOfFields; ++idx) - { - if(excludedIds.find(idx) == excludedIds.cend() && - isNumericField(feature, idx)) - mv.push_back(feature.ogr().GetFieldAsDouble(idx)); - } - samples.push_back(mv); - } - feature = layer.ogr().GetNextFeature(); - goesOn = feature.addr() != 0; - } - return samples; - } - - void writeSamples(const ogr::DataSource::Pointer& vectors, - ogr::DataSource::Pointer& output, - const SampleVectorType& samples, - const size_t layerName, - const std::string& classField, int label, - const std::vector<std::string>& excludedFeatures = {}) - { - - auto inputLayer = vectors->GetLayer(layerName); - auto excludedIds = getExcludedFeaturesIds(excludedFeatures, inputLayer); - - OGRSpatialReference * oSRS = nullptr; - if (inputLayer.GetSpatialRef()) - { - oSRS = inputLayer.GetSpatialRef()->Clone(); - } - OGRFeatureDefn &layerDefn = inputLayer.GetLayerDefn(); - - auto outputLayer = output->CreateLayer(inputLayer.GetName(), oSRS, - inputLayer.GetGeomType()); - for (int k=0 ; k < layerDefn.GetFieldCount() ; k++) - { - OGRFieldDefn originDefn(layerDefn.GetFieldDefn(k)); - ogr::FieldDefn fieldDefn(originDefn); - outputLayer.CreateField(fieldDefn); - } - - auto featureCount = outputLayer.GetFeatureCount(false); - auto templateFeature = selectTemplateFeature(inputLayer, classField, label); - for(const auto& sample : samples) - { - ogr::Feature dstFeature(outputLayer.GetLayerDefn()); - dstFeature.SetFrom( templateFeature, TRUE ); - dstFeature.SetFID(++featureCount); - auto sampleFieldCounter = 0; - for (int k=0 ; k < layerDefn.GetFieldCount() ; k++) - { - if(excludedIds.find(k) == excludedIds.cend() && - isNumericField(dstFeature, k)) - { - dstFeature.ogr().SetField(k, sample[sampleFieldCounter++]); - } - } - outputLayer.CreateFeature( dstFeature ); - } - } std::vector<std::string> GetExcludedFeatures(const std::vector<std::string>& fieldNames, const std::vector<int>& selectedIdx) @@ -369,45 +275,8 @@ private: } return result; } - ogr::Feature selectTemplateFeature(const ogr::Layer& inputLayer, - const std::string& classField, int label) - { - auto featureIt = inputLayer.begin(); - bool goesOn{(*featureIt).addr() != 0}; - while( goesOn ) - { - if((*featureIt).ogr().GetFieldAsInteger(classField.c_str()) == label) - { - return *featureIt; - } - ++featureIt; - } - return *(inputLayer.begin()); - } - std::set<size_t> getExcludedFeaturesIds(const std::vector<std::string>& excludedFeatures, - const ogr::Layer& inputLayer) - { - auto feature = *(inputLayer).begin(); - std::set<size_t> excludedIds; - if( excludedFeatures.size() != 0) - { - for(const auto& fieldName : excludedFeatures) - { - auto idx = feature.ogr().GetFieldIndex( fieldName.c_str() ); - excludedIds.insert(idx); - } - } - return excludedIds; - } - bool isNumericField(const ogr::Feature& feature, - const int idx) - { - OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(idx)->GetType(); - return (fieldType == OFTInteger - || ogr::version_proxy::IsOFTInteger64( fieldType ) - || fieldType == OFTReal); - } - }; + +}; } // end of namespace Wrapper } // end of namespace otb diff --git a/Modules/Applications/AppClassification/include/otbSampleAugmentation.h b/Modules/Applications/AppClassification/include/otbSampleAugmentation.h index 43fd6657a0..432dbe8a26 100644 --- a/Modules/Applications/AppClassification/include/otbSampleAugmentation.h +++ b/Modules/Applications/AppClassification/include/otbSampleAugmentation.h @@ -199,7 +199,7 @@ void smote(const SampleVectorType& inSamples, } } -} -} +}//end namespaces sampleAugmentation +}//end namespace otb #endif diff --git a/Modules/Applications/AppClassification/include/otbSampleAugmentationFilter.h b/Modules/Applications/AppClassification/include/otbSampleAugmentationFilter.h new file mode 100644 index 0000000000..754e96ef33 --- /dev/null +++ b/Modules/Applications/AppClassification/include/otbSampleAugmentationFilter.h @@ -0,0 +1,168 @@ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef otbSampleAugmentationFilter_h +#define otbSampleAugmentationFilter_h + +#include "itkProcessObject.h" +#include "otbOGRDataSourceWrapper.h" +#include "otbSampleAugmentation.h" + +namespace otb +{ + +/** \class SampleAugmentationFilter +This class + */ + +class ITK_EXPORT SampleAugmentationFilter : + public itk::ProcessObject +{ +public: + + /** typedef for the classes standards. */ + typedef SampleAugmentationFilter Self; + typedef itk::ProcessObject Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + + /** Method for management of the object factory. */ + itkNewMacro(Self); + + /** Return the name of the class. */ + itkTypeMacro(SampleAugmentationFilter, ProcessObject); + + typedef ogr::DataSource OGRDataSourceType; + typedef typename OGRDataSourceType::Pointer OGRDataSourcePointerType; + typedef ogr::Layer OGRLayerType; + + typedef itk::ProcessObject::DataObjectPointerArraySizeType DataObjectPointerArraySizeType; + + using SampleType = sampleAugmentation::SampleType; + using SampleVectorType = sampleAugmentation::SampleVectorType; + + enum class Strategy { Replicate, Jitter, Smote }; + + /** Set/Get the input OGRDataSource of this process object. */ + using Superclass::SetInput; + virtual void SetInput(const OGRDataSourceType* ds); + const OGRDataSourceType* GetInput(unsigned int idx); + + virtual void SetOutputSamples(ogr::DataSource* data); + + /** Set the Field Name in which labels will be written. (default is "class") + * A field "ClassFieldName" of type integer is created in the output memory layer. + */ + itkSetMacro(ClassFieldName, std::string); + /** + * Return the Field name in which labels have been written. + */ + itkGetMacro(ClassFieldName, std::string); + + + itkSetMacro(Layer, size_t); + itkGetMacro(Layer, size_t); + itkSetMacro(Label, int); + itkGetMacro(Label, int); + void SetStrategy(Strategy s) + { + m_Strategy = s; + } + Strategy GetStrategy() const + { + return m_Strategy; + } + itkSetMacro(NumberOfSamples, int); + itkGetMacro(NumberOfSamples, int); + void SetExcludedFeatures(const std::vector<std::string>& ef) + { + m_ExcludedFeatures = ef; + } + std::vector<std::string> GetExcludedFeatures() const + { + return m_ExcludedFeatures; + } + itkSetMacro(StdFactor, double); + itkGetMacro(StdFactor, double); + itkSetMacro(SmoteNeighbors, size_t); + itkGetMacro(SmoteNeighbors, size_t); + itkSetMacro(Seed, int); + itkGetMacro(Seed, int); +/** + * Get the output \c ogr::DataSource which is a "memory" datasource. + */ + const OGRDataSourceType * GetOutput(); + +protected: + SampleAugmentationFilter(); + ~SampleAugmentationFilter() ITK_OVERRIDE {} + + /** Generate Data method*/ + void GenerateData() ITK_OVERRIDE; + + /** DataObject pointer */ + typedef itk::DataObject::Pointer DataObjectPointer; + + DataObjectPointer MakeOutput(DataObjectPointerArraySizeType idx) ITK_OVERRIDE; + using Superclass::MakeOutput; + + + SampleVectorType extractSamples(const ogr::DataSource::Pointer vectors, + size_t layerName, + const std::string& classField, const int label, + const std::vector<std::string>& excludedFeatures = {}); + + void sampleToOGRFeatures(const ogr::DataSource::Pointer& vectors, + ogr::DataSource* output, + const SampleVectorType& samples, + const size_t layerName, + const std::string& classField, int label, + const std::vector<std::string>& excludedFeatures = {}); + +std::set<size_t> getExcludedFeaturesIds(const std::vector<std::string>& excludedFeatures, + const ogr::Layer& inputLayer); +bool isNumericField(const ogr::Feature& feature, const int idx); + +ogr::Feature selectTemplateFeature(const ogr::Layer& inputLayer, + const std::string& classField, int label); +private: + SampleAugmentationFilter(const Self &); //purposely not implemented + void operator =(const Self&); //purposely not implemented + + std::string m_ClassFieldName; + size_t m_Layer; + int m_Label; + std::vector<std::string> m_ExcludedFeatures; + Strategy m_Strategy; + int m_NumberOfSamples; + double m_StdFactor; + size_t m_SmoteNeighbors; + int m_Seed; + +}; + + +} // end namespace otb + +#ifndef OTB_MANUAL_INSTANTIATION +#include "otbSampleAugmentationFilter.txx" +#endif + +#endif diff --git a/Modules/Applications/AppClassification/include/otbSampleAugmentationFilter.txx b/Modules/Applications/AppClassification/include/otbSampleAugmentationFilter.txx new file mode 100644 index 0000000000..41590a0109 --- /dev/null +++ b/Modules/Applications/AppClassification/include/otbSampleAugmentationFilter.txx @@ -0,0 +1,268 @@ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef otbSampleAugmentationFilter_txx +#define otbSampleAugmentationFilter_txx + +#include "otbSampleAugmentationFilter.h" +#include "stdint.h" //needed for uintptr_t + +namespace otb +{ + +SampleAugmentationFilter +::SampleAugmentationFilter() : m_ClassFieldName("class") +{ + this->SetNumberOfRequiredInputs(1); + this->SetNumberOfRequiredOutputs(1); + this->ProcessObject::SetNthOutput(0, this->MakeOutput(0) ); +} + + +typename SampleAugmentationFilter::DataObjectPointer +SampleAugmentationFilter +::MakeOutput(DataObjectPointerArraySizeType itkNotUsed(idx)) +{ + return static_cast< DataObjectPointer >(OGRDataSourceType::New().GetPointer()); +} + +const typename SampleAugmentationFilter::OGRDataSourceType * +SampleAugmentationFilter +::GetOutput() +{ + return static_cast< const OGRDataSourceType *>( + this->ProcessObject::GetOutput(0)); +} + +void +SampleAugmentationFilter +::SetInput(const otb::ogr::DataSource* ds) +{ + this->Superclass::SetNthInput(0, const_cast<otb::ogr::DataSource *>(ds)); +} + +const typename SampleAugmentationFilter::OGRDataSourceType * +SampleAugmentationFilter +::GetInput(unsigned int idx) +{ + return static_cast<const OGRDataSourceType *> + (this->itk::ProcessObject::GetInput(idx)); +} + +void +SampleAugmentationFilter +::SetOutputSamples(ogr::DataSource* data) +{ + this->SetNthOutput(0,data); +} + + +void +SampleAugmentationFilter +::GenerateData(void) +{ + + OGRDataSourcePointerType inputDS = dynamic_cast<OGRDataSourceType*>(this->itk::ProcessObject::GetInput(0)); + auto outputDS = static_cast<ogr::DataSource *>(this->itk::ProcessObject::GetOutput(1)); + auto inSamples = this->extractSamples(inputDS, m_Layer, + m_ClassFieldName, + m_Label, + m_ExcludedFeatures); + SampleVectorType newSamples; + switch (m_Strategy) + { + case Strategy::Replicate: + { + sampleAugmentation::replicateSamples(inSamples, m_NumberOfSamples, + newSamples); + } + break; + case Strategy::Jitter: + { + sampleAugmentation::jitterSamples(inSamples, m_NumberOfSamples, + newSamples, + m_StdFactor, + m_Seed); + } + break; + case Strategy::Smote: + { + sampleAugmentation::smote(inSamples, m_NumberOfSamples, + newSamples, + m_SmoteNeighbors, + m_Seed); + } + break; + } + this->sampleToOGRFeatures(inputDS, outputDS, newSamples, m_Layer, + m_ClassFieldName, + m_Label, + m_ExcludedFeatures); + + + // this->SetNthOutput(0,outputDS); +} + +/** Extracts the samples of a single class from the vector data to a +* vector and excludes some unwanted features. +*/ +SampleAugmentationFilter::SampleVectorType +SampleAugmentationFilter +::extractSamples(const ogr::DataSource::Pointer vectors, + size_t layerName, + const std::string& classField, const int label, + const std::vector<std::string>& excludedFeatures) +{ + ogr::Layer layer = vectors->GetLayer(layerName); + ogr::Feature feature = layer.ogr().GetNextFeature(); + if(feature.addr() == 0) + { + itkExceptionMacro("Layer " << layerName << " of input sample file is empty.\n"); + } + int cFieldIndex = feature.ogr().GetFieldIndex( classField.c_str() ); + if( cFieldIndex < 0 ) + { + itkExceptionMacro( "The field name for class label (" << classField + << ") has not been found in the vector file " ); + } + + auto numberOfFields = feature.ogr().GetFieldCount(); + auto excludedIds = this->getExcludedFeaturesIds(excludedFeatures, layer); + SampleVectorType samples; + bool goesOn{feature.addr() != 0}; + while( goesOn ) + { + // Retrieve all the features for each field in the ogr layer. + if(feature.ogr().GetFieldAsInteger(classField.c_str()) == label) + { + + SampleType mv; + for(auto idx=0; idx<numberOfFields; ++idx) + { + if(excludedIds.find(idx) == excludedIds.cend() && + this->isNumericField(feature, idx)) + mv.push_back(feature.ogr().GetFieldAsDouble(idx)); + } + samples.push_back(mv); + } + feature = layer.ogr().GetNextFeature(); + goesOn = feature.addr() != 0; + } + return samples; +} + +void +SampleAugmentationFilter +::sampleToOGRFeatures(const ogr::DataSource::Pointer& vectors, + ogr::DataSource* output, + const SampleAugmentationFilter::SampleVectorType& samples, + const size_t layerName, + const std::string& classField, int label, + const std::vector<std::string>& excludedFeatures) +{ + + auto inputLayer = vectors->GetLayer(layerName); + auto excludedIds = this->getExcludedFeaturesIds(excludedFeatures, inputLayer); + + OGRSpatialReference * oSRS = nullptr; + if (inputLayer.GetSpatialRef()) + { + oSRS = inputLayer.GetSpatialRef()->Clone(); + } + OGRFeatureDefn &layerDefn = inputLayer.GetLayerDefn(); + + auto outputLayer = output->CreateLayer(inputLayer.GetName(), oSRS, + inputLayer.GetGeomType()); + for (int k=0 ; k < layerDefn.GetFieldCount() ; k++) + { + OGRFieldDefn originDefn(layerDefn.GetFieldDefn(k)); + ogr::FieldDefn fieldDefn(originDefn); + outputLayer.CreateField(fieldDefn); + } + + auto featureCount = outputLayer.GetFeatureCount(false); + auto templateFeature = this->selectTemplateFeature(inputLayer, classField, label); + for(const auto& sample : samples) + { + ogr::Feature dstFeature(outputLayer.GetLayerDefn()); + dstFeature.SetFrom( templateFeature, TRUE ); + dstFeature.SetFID(++featureCount); + auto sampleFieldCounter = 0; + for (int k=0 ; k < layerDefn.GetFieldCount() ; k++) + { + if(excludedIds.find(k) == excludedIds.cend() && + this->isNumericField(dstFeature, k)) + { + dstFeature.ogr().SetField(k, sample[sampleFieldCounter++]); + } + } + outputLayer.CreateFeature( dstFeature ); + } +} + + std::set<size_t> + SampleAugmentationFilter + ::getExcludedFeaturesIds(const std::vector<std::string>& excludedFeatures, + const ogr::Layer& inputLayer) + { + auto feature = *(inputLayer).begin(); + std::set<size_t> excludedIds; + if( excludedFeatures.size() != 0) + { + for(const auto& fieldName : excludedFeatures) + { + auto idx = feature.ogr().GetFieldIndex( fieldName.c_str() ); + excludedIds.insert(idx); + } + } + return excludedIds; + } + + bool +SampleAugmentationFilter +::isNumericField(const ogr::Feature& feature, + const int idx) +{ + OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(idx)->GetType(); + return (fieldType == OFTInteger + || ogr::version_proxy::IsOFTInteger64( fieldType ) + || fieldType == OFTReal); +} + +ogr::Feature +SampleAugmentationFilter +::selectTemplateFeature(const ogr::Layer& inputLayer, + const std::string& classField, int label) +{ + auto featureIt = inputLayer.begin(); + bool goesOn{(*featureIt).addr() != 0}; + while( goesOn ) + { + if((*featureIt).ogr().GetFieldAsInteger(classField.c_str()) == label) + { + return *featureIt; + } + ++featureIt; + } + return *(inputLayer.begin()); +} +} // end namespace otb + +#endif -- GitLab