Skip to content
Snippets Groups Projects
Commit d4a174c0 authored by Jordi Inglada's avatar Jordi Inglada
Browse files

ENH: implement sample augmentation as a filter

parent b0a4f4a5
No related branches found
No related tags found
No related merge requests found
......@@ -21,7 +21,7 @@
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"
#include "otbOGRDataSourceWrapper.h"
#include "otbSampleAugmentation.h"
#include "otbSampleAugmentationFilter.h"
namespace otb
{
......@@ -44,9 +44,9 @@ public:
itkTypeMacro(SampleAugmentation, otb::Application);
/** Filters typedef */
using SampleType = sampleAugmentation::SampleType;
using SampleVectorType = sampleAugmentation::SampleVectorType;
using FilterType = otb::SampleAugmentationFilter;
using SampleType = FilterType::SampleType;
using SampleVectorType = FilterType::SampleVectorType;
private:
SampleAugmentation() {}
......@@ -220,143 +220,49 @@ private:
GetSelectedItems( "exclude" ));
for(const auto& ef : excludedFeatures)
otbAppLogINFO("Excluding feature " << ef << '\n');
auto inSamples = extractSamples(vectors, this->GetParameterInt("layer"),
fieldName,
this->GetParameterInt("label"),
excludedFeatures);
int seed = std::time(nullptr);
if(IsParameterEnabled("seed")) seed = this->GetParameterInt("seed");
SampleVectorType newSamples;
FilterType::Pointer filter = FilterType::New();
filter->SetInput(vectors);
filter->SetLayer(this->GetParameterInt("layer"));
filter->SetNumberOfSamples(this->GetParameterInt("samples"));
filter->SetOutputSamples(output);
filter->SetClassFieldName(fieldName);
filter->SetLabel(this->GetParameterInt("label"));
filter->SetExcludedFeatures(excludedFeatures);
filter->SetSeed(seed);
switch (this->GetParameterInt("strategy"))
{
// replicate
case 0:
{
otbAppLogINFO("Augmentation strategy : replicate");
sampleAugmentation::replicateSamples(inSamples, this->GetParameterInt("samples"),
newSamples);
filter->SetStrategy(FilterType::Strategy::Replicate);
}
break;
break;
// jitter
case 1:
{
otbAppLogINFO("Augmentation strategy : jitter");
sampleAugmentation::jitterSamples(inSamples, this->GetParameterInt("samples"),
newSamples,
this->GetParameterFloat("strategy.jitter.stdfactor"),
seed);
filter->SetStrategy(FilterType::Strategy::Jitter);
filter->SetStdFactor(this->GetParameterFloat("stdfactor"));
}
break;
case 2:
{
otbAppLogINFO("Augmentation strategy : smote");
sampleAugmentation::smote(inSamples, this->GetParameterInt("samples"),
newSamples,
this->GetParameterInt("strategy.smote.neighbors"),
seed);
filter->SetStrategy(FilterType::Strategy::Smote);
filter->SetSmoteNeighbors(this->GetParameterInt("neighbors"));
}
break;
}
writeSamples(vectors, output, newSamples, this->GetParameterInt("layer"),
fieldName,
this->GetParameterInt("label"),
excludedFeatures);
filter->Update();
output->SyncToDisk();
}
/** Extracts the samples of a single class from the vector data to a
* vector and excludes some unwanted features.
*/
SampleVectorType extractSamples(const ogr::DataSource::Pointer vectors,
size_t layerName,
const std::string& classField, const int label,
const std::vector<std::string>& excludedFeatures = {})
{
ogr::Layer layer = vectors->GetLayer(layerName);
ogr::Feature feature = layer.ogr().GetNextFeature();
if(feature.addr() == 0)
{
otbAppLogFATAL("Layer " << layerName << " of input sample file is empty.\n");
}
int cFieldIndex = feature.ogr().GetFieldIndex( classField.c_str() );
if( cFieldIndex < 0 )
{
otbAppLogFATAL( "The field name for class label (" << classField
<< ") has not been found in the vector file " );
}
auto numberOfFields = feature.ogr().GetFieldCount();
auto excludedIds = getExcludedFeaturesIds(excludedFeatures, layer);
otbAppLogINFO("The vector file contains " << numberOfFields << " fields.\n");
SampleVectorType samples;
bool goesOn{feature.addr() != 0};
while( goesOn )
{
// Retrieve all the features for each field in the ogr layer.
if(feature.ogr().GetFieldAsInteger(classField.c_str()) == label)
{
SampleType mv;
for(auto idx=0; idx<numberOfFields; ++idx)
{
if(excludedIds.find(idx) == excludedIds.cend() &&
isNumericField(feature, idx))
mv.push_back(feature.ogr().GetFieldAsDouble(idx));
}
samples.push_back(mv);
}
feature = layer.ogr().GetNextFeature();
goesOn = feature.addr() != 0;
}
return samples;
}
void writeSamples(const ogr::DataSource::Pointer& vectors,
ogr::DataSource::Pointer& output,
const SampleVectorType& samples,
const size_t layerName,
const std::string& classField, int label,
const std::vector<std::string>& excludedFeatures = {})
{
auto inputLayer = vectors->GetLayer(layerName);
auto excludedIds = getExcludedFeaturesIds(excludedFeatures, inputLayer);
OGRSpatialReference * oSRS = nullptr;
if (inputLayer.GetSpatialRef())
{
oSRS = inputLayer.GetSpatialRef()->Clone();
}
OGRFeatureDefn &layerDefn = inputLayer.GetLayerDefn();
auto outputLayer = output->CreateLayer(inputLayer.GetName(), oSRS,
inputLayer.GetGeomType());
for (int k=0 ; k < layerDefn.GetFieldCount() ; k++)
{
OGRFieldDefn originDefn(layerDefn.GetFieldDefn(k));
ogr::FieldDefn fieldDefn(originDefn);
outputLayer.CreateField(fieldDefn);
}
auto featureCount = outputLayer.GetFeatureCount(false);
auto templateFeature = selectTemplateFeature(inputLayer, classField, label);
for(const auto& sample : samples)
{
ogr::Feature dstFeature(outputLayer.GetLayerDefn());
dstFeature.SetFrom( templateFeature, TRUE );
dstFeature.SetFID(++featureCount);
auto sampleFieldCounter = 0;
for (int k=0 ; k < layerDefn.GetFieldCount() ; k++)
{
if(excludedIds.find(k) == excludedIds.cend() &&
isNumericField(dstFeature, k))
{
dstFeature.ogr().SetField(k, sample[sampleFieldCounter++]);
}
}
outputLayer.CreateFeature( dstFeature );
}
}
std::vector<std::string> GetExcludedFeatures(const std::vector<std::string>& fieldNames,
const std::vector<int>& selectedIdx)
......@@ -369,45 +275,8 @@ private:
}
return result;
}
ogr::Feature selectTemplateFeature(const ogr::Layer& inputLayer,
const std::string& classField, int label)
{
auto featureIt = inputLayer.begin();
bool goesOn{(*featureIt).addr() != 0};
while( goesOn )
{
if((*featureIt).ogr().GetFieldAsInteger(classField.c_str()) == label)
{
return *featureIt;
}
++featureIt;
}
return *(inputLayer.begin());
}
std::set<size_t> getExcludedFeaturesIds(const std::vector<std::string>& excludedFeatures,
const ogr::Layer& inputLayer)
{
auto feature = *(inputLayer).begin();
std::set<size_t> excludedIds;
if( excludedFeatures.size() != 0)
{
for(const auto& fieldName : excludedFeatures)
{
auto idx = feature.ogr().GetFieldIndex( fieldName.c_str() );
excludedIds.insert(idx);
}
}
return excludedIds;
}
bool isNumericField(const ogr::Feature& feature,
const int idx)
{
OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(idx)->GetType();
return (fieldType == OFTInteger
|| ogr::version_proxy::IsOFTInteger64( fieldType )
|| fieldType == OFTReal);
}
};
};
} // end of namespace Wrapper
} // end of namespace otb
......
......@@ -199,7 +199,7 @@ void smote(const SampleVectorType& inSamples,
}
}
}
}
}//end namespaces sampleAugmentation
}//end namespace otb
#endif
/*
* Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
*
* This file is part of Orfeo Toolbox
*
* https://www.orfeo-toolbox.org/
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef otbSampleAugmentationFilter_h
#define otbSampleAugmentationFilter_h
#include "itkProcessObject.h"
#include "otbOGRDataSourceWrapper.h"
#include "otbSampleAugmentation.h"
namespace otb
{
/** \class SampleAugmentationFilter
This class
*/
class ITK_EXPORT SampleAugmentationFilter :
public itk::ProcessObject
{
public:
/** typedef for the classes standards. */
typedef SampleAugmentationFilter Self;
typedef itk::ProcessObject Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Method for management of the object factory. */
itkNewMacro(Self);
/** Return the name of the class. */
itkTypeMacro(SampleAugmentationFilter, ProcessObject);
typedef ogr::DataSource OGRDataSourceType;
typedef typename OGRDataSourceType::Pointer OGRDataSourcePointerType;
typedef ogr::Layer OGRLayerType;
typedef itk::ProcessObject::DataObjectPointerArraySizeType DataObjectPointerArraySizeType;
using SampleType = sampleAugmentation::SampleType;
using SampleVectorType = sampleAugmentation::SampleVectorType;
enum class Strategy { Replicate, Jitter, Smote };
/** Set/Get the input OGRDataSource of this process object. */
using Superclass::SetInput;
virtual void SetInput(const OGRDataSourceType* ds);
const OGRDataSourceType* GetInput(unsigned int idx);
virtual void SetOutputSamples(ogr::DataSource* data);
/** Set the Field Name in which labels will be written. (default is "class")
* A field "ClassFieldName" of type integer is created in the output memory layer.
*/
itkSetMacro(ClassFieldName, std::string);
/**
* Return the Field name in which labels have been written.
*/
itkGetMacro(ClassFieldName, std::string);
itkSetMacro(Layer, size_t);
itkGetMacro(Layer, size_t);
itkSetMacro(Label, int);
itkGetMacro(Label, int);
void SetStrategy(Strategy s)
{
m_Strategy = s;
}
Strategy GetStrategy() const
{
return m_Strategy;
}
itkSetMacro(NumberOfSamples, int);
itkGetMacro(NumberOfSamples, int);
void SetExcludedFeatures(const std::vector<std::string>& ef)
{
m_ExcludedFeatures = ef;
}
std::vector<std::string> GetExcludedFeatures() const
{
return m_ExcludedFeatures;
}
itkSetMacro(StdFactor, double);
itkGetMacro(StdFactor, double);
itkSetMacro(SmoteNeighbors, size_t);
itkGetMacro(SmoteNeighbors, size_t);
itkSetMacro(Seed, int);
itkGetMacro(Seed, int);
/**
* Get the output \c ogr::DataSource which is a "memory" datasource.
*/
const OGRDataSourceType * GetOutput();
protected:
SampleAugmentationFilter();
~SampleAugmentationFilter() ITK_OVERRIDE {}
/** Generate Data method*/
void GenerateData() ITK_OVERRIDE;
/** DataObject pointer */
typedef itk::DataObject::Pointer DataObjectPointer;
DataObjectPointer MakeOutput(DataObjectPointerArraySizeType idx) ITK_OVERRIDE;
using Superclass::MakeOutput;
SampleVectorType extractSamples(const ogr::DataSource::Pointer vectors,
size_t layerName,
const std::string& classField, const int label,
const std::vector<std::string>& excludedFeatures = {});
void sampleToOGRFeatures(const ogr::DataSource::Pointer& vectors,
ogr::DataSource* output,
const SampleVectorType& samples,
const size_t layerName,
const std::string& classField, int label,
const std::vector<std::string>& excludedFeatures = {});
std::set<size_t> getExcludedFeaturesIds(const std::vector<std::string>& excludedFeatures,
const ogr::Layer& inputLayer);
bool isNumericField(const ogr::Feature& feature, const int idx);
ogr::Feature selectTemplateFeature(const ogr::Layer& inputLayer,
const std::string& classField, int label);
private:
SampleAugmentationFilter(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
std::string m_ClassFieldName;
size_t m_Layer;
int m_Label;
std::vector<std::string> m_ExcludedFeatures;
Strategy m_Strategy;
int m_NumberOfSamples;
double m_StdFactor;
size_t m_SmoteNeighbors;
int m_Seed;
};
} // end namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbSampleAugmentationFilter.txx"
#endif
#endif
/*
* Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
*
* This file is part of Orfeo Toolbox
*
* https://www.orfeo-toolbox.org/
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef otbSampleAugmentationFilter_txx
#define otbSampleAugmentationFilter_txx
#include "otbSampleAugmentationFilter.h"
#include "stdint.h" //needed for uintptr_t
namespace otb
{
SampleAugmentationFilter
::SampleAugmentationFilter() : m_ClassFieldName("class")
{
this->SetNumberOfRequiredInputs(1);
this->SetNumberOfRequiredOutputs(1);
this->ProcessObject::SetNthOutput(0, this->MakeOutput(0) );
}
typename SampleAugmentationFilter::DataObjectPointer
SampleAugmentationFilter
::MakeOutput(DataObjectPointerArraySizeType itkNotUsed(idx))
{
return static_cast< DataObjectPointer >(OGRDataSourceType::New().GetPointer());
}
const typename SampleAugmentationFilter::OGRDataSourceType *
SampleAugmentationFilter
::GetOutput()
{
return static_cast< const OGRDataSourceType *>(
this->ProcessObject::GetOutput(0));
}
void
SampleAugmentationFilter
::SetInput(const otb::ogr::DataSource* ds)
{
this->Superclass::SetNthInput(0, const_cast<otb::ogr::DataSource *>(ds));
}
const typename SampleAugmentationFilter::OGRDataSourceType *
SampleAugmentationFilter
::GetInput(unsigned int idx)
{
return static_cast<const OGRDataSourceType *>
(this->itk::ProcessObject::GetInput(idx));
}
void
SampleAugmentationFilter
::SetOutputSamples(ogr::DataSource* data)
{
this->SetNthOutput(0,data);
}
void
SampleAugmentationFilter
::GenerateData(void)
{
OGRDataSourcePointerType inputDS = dynamic_cast<OGRDataSourceType*>(this->itk::ProcessObject::GetInput(0));
auto outputDS = static_cast<ogr::DataSource *>(this->itk::ProcessObject::GetOutput(1));
auto inSamples = this->extractSamples(inputDS, m_Layer,
m_ClassFieldName,
m_Label,
m_ExcludedFeatures);
SampleVectorType newSamples;
switch (m_Strategy)
{
case Strategy::Replicate:
{
sampleAugmentation::replicateSamples(inSamples, m_NumberOfSamples,
newSamples);
}
break;
case Strategy::Jitter:
{
sampleAugmentation::jitterSamples(inSamples, m_NumberOfSamples,
newSamples,
m_StdFactor,
m_Seed);
}
break;
case Strategy::Smote:
{
sampleAugmentation::smote(inSamples, m_NumberOfSamples,
newSamples,
m_SmoteNeighbors,
m_Seed);
}
break;
}
this->sampleToOGRFeatures(inputDS, outputDS, newSamples, m_Layer,
m_ClassFieldName,
m_Label,
m_ExcludedFeatures);
// this->SetNthOutput(0,outputDS);
}
/** Extracts the samples of a single class from the vector data to a
* vector and excludes some unwanted features.
*/
SampleAugmentationFilter::SampleVectorType
SampleAugmentationFilter
::extractSamples(const ogr::DataSource::Pointer vectors,
size_t layerName,
const std::string& classField, const int label,
const std::vector<std::string>& excludedFeatures)
{
ogr::Layer layer = vectors->GetLayer(layerName);
ogr::Feature feature = layer.ogr().GetNextFeature();
if(feature.addr() == 0)
{
itkExceptionMacro("Layer " << layerName << " of input sample file is empty.\n");
}
int cFieldIndex = feature.ogr().GetFieldIndex( classField.c_str() );
if( cFieldIndex < 0 )
{
itkExceptionMacro( "The field name for class label (" << classField
<< ") has not been found in the vector file " );
}
auto numberOfFields = feature.ogr().GetFieldCount();
auto excludedIds = this->getExcludedFeaturesIds(excludedFeatures, layer);
SampleVectorType samples;
bool goesOn{feature.addr() != 0};
while( goesOn )
{
// Retrieve all the features for each field in the ogr layer.
if(feature.ogr().GetFieldAsInteger(classField.c_str()) == label)
{
SampleType mv;
for(auto idx=0; idx<numberOfFields; ++idx)
{
if(excludedIds.find(idx) == excludedIds.cend() &&
this->isNumericField(feature, idx))
mv.push_back(feature.ogr().GetFieldAsDouble(idx));
}
samples.push_back(mv);
}
feature = layer.ogr().GetNextFeature();
goesOn = feature.addr() != 0;
}
return samples;
}
void
SampleAugmentationFilter
::sampleToOGRFeatures(const ogr::DataSource::Pointer& vectors,
ogr::DataSource* output,
const SampleAugmentationFilter::SampleVectorType& samples,
const size_t layerName,
const std::string& classField, int label,
const std::vector<std::string>& excludedFeatures)
{
auto inputLayer = vectors->GetLayer(layerName);
auto excludedIds = this->getExcludedFeaturesIds(excludedFeatures, inputLayer);
OGRSpatialReference * oSRS = nullptr;
if (inputLayer.GetSpatialRef())
{
oSRS = inputLayer.GetSpatialRef()->Clone();
}
OGRFeatureDefn &layerDefn = inputLayer.GetLayerDefn();
auto outputLayer = output->CreateLayer(inputLayer.GetName(), oSRS,
inputLayer.GetGeomType());
for (int k=0 ; k < layerDefn.GetFieldCount() ; k++)
{
OGRFieldDefn originDefn(layerDefn.GetFieldDefn(k));
ogr::FieldDefn fieldDefn(originDefn);
outputLayer.CreateField(fieldDefn);
}
auto featureCount = outputLayer.GetFeatureCount(false);
auto templateFeature = this->selectTemplateFeature(inputLayer, classField, label);
for(const auto& sample : samples)
{
ogr::Feature dstFeature(outputLayer.GetLayerDefn());
dstFeature.SetFrom( templateFeature, TRUE );
dstFeature.SetFID(++featureCount);
auto sampleFieldCounter = 0;
for (int k=0 ; k < layerDefn.GetFieldCount() ; k++)
{
if(excludedIds.find(k) == excludedIds.cend() &&
this->isNumericField(dstFeature, k))
{
dstFeature.ogr().SetField(k, sample[sampleFieldCounter++]);
}
}
outputLayer.CreateFeature( dstFeature );
}
}
std::set<size_t>
SampleAugmentationFilter
::getExcludedFeaturesIds(const std::vector<std::string>& excludedFeatures,
const ogr::Layer& inputLayer)
{
auto feature = *(inputLayer).begin();
std::set<size_t> excludedIds;
if( excludedFeatures.size() != 0)
{
for(const auto& fieldName : excludedFeatures)
{
auto idx = feature.ogr().GetFieldIndex( fieldName.c_str() );
excludedIds.insert(idx);
}
}
return excludedIds;
}
bool
SampleAugmentationFilter
::isNumericField(const ogr::Feature& feature,
const int idx)
{
OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(idx)->GetType();
return (fieldType == OFTInteger
|| ogr::version_proxy::IsOFTInteger64( fieldType )
|| fieldType == OFTReal);
}
ogr::Feature
SampleAugmentationFilter
::selectTemplateFeature(const ogr::Layer& inputLayer,
const std::string& classField, int label)
{
auto featureIt = inputLayer.begin();
bool goesOn{(*featureIt).addr() != 0};
while( goesOn )
{
if((*featureIt).ogr().GetFieldAsInteger(classField.c_str()) == label)
{
return *featureIt;
}
++featureIt;
}
return *(inputLayer.begin());
}
} // end namespace otb
#endif
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment