Commit 1c6bb187 authored by Julien Michel's avatar Julien Michel

Merge branch 'data-augmentation' into 'develop'

Data augmentation

See merge request !25
parents e54bbf65 48c5bbf2
......@@ -86,6 +86,7 @@ const ExtensionDriverAssociation k_ExtensionDriverMap[] =
{".GPX", "GPX"},
{".SQLITE", "SQLite"},
{".KML", "KML"},
{".CSV", "CSV"},
};
/**\ingroup GeometryInternals
* \brief Returns the OGR driver name associated to a filename.
......
......@@ -125,5 +125,10 @@ otb_create_application(
SOURCES otbVectorClassifier.cxx
LINK_LIBRARIES ${${otb-module}_LIBRARIES})
otb_create_application(
NAME SampleAugmentation
SOURCES otbSampleAugmentation.cxx
LINK_LIBRARIES ${${otb-module}_LIBRARIES})
# Mantis-1427 : temporary fix
add_dependencies(${otb-module}-all otbapp_ImageEnvelope)
......@@ -972,3 +972,39 @@ otb_test_application(
${OTBAPP_BASELINE_FILES}/apTvClMultiImageSamplingRate_out_3.csv
${TEMP}/apTvClMultiImageSamplingRate_out_3.csv
)
#------------ SampleAgmentation TESTS ----------------
otb_test_application(NAME apTvClSampleAugmentationReplicate
APP SampleAugmentation
OPTIONS -in ${INPUTDATA}/Classification/apTvClSampleExtractionOut.sqlite
-field class
-label 3
-samples 100
-out ${TEMP}/apTvClSampleAugmentationReplicate.sqlite
-exclude originfid
-strategy replicate
)
otb_test_application(NAME apTvClSampleAugmentationJitter
APP SampleAugmentation
OPTIONS -in ${INPUTDATA}/Classification/apTvClSampleExtractionOut.sqlite
-field class
-label 3
-samples 100
-out ${TEMP}/apTvClSampleAugmentationJitter.sqlite
-exclude originfid
-strategy jitter
-strategy.jitter.stdfactor 10
)
otb_test_application(NAME apTvClSampleAugmentationSmote
APP SampleAugmentation
OPTIONS -in ${INPUTDATA}/Classification/apTvClSampleExtractionOut.sqlite
-field class
-label 3
-samples 100
-out ${TEMP}/apTvClSampleAugmentationSmote.sqlite
-exclude originfid
-strategy smote
-strategy.smote.neighbors 5
)
/*
* Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
*
* This file is part of Orfeo Toolbox
*
* https://www.orfeo-toolbox.org/
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef otbSampleAugmentation_h
#define otbSampleAugmentation_h
#include <vector>
#include <algorithm>
#include <random>
#include <ctime>
#include <cassert>
#include <iostream>
namespace otb
{
namespace sampleAugmentation
{
using SampleType = std::vector<double>;
using SampleVectorType = std::vector<SampleType>;
/**
Estimate standard deviations of the components in one pass using
Welford's algorithm
*/
SampleType EstimateStds(const SampleVectorType& samples)
{
const auto nbSamples = samples.size();
const auto nbComponents = samples[0].size();
SampleType stds(nbComponents, 0.0);
SampleType means(nbComponents, 0.0);
for(size_t i=0; i<nbSamples; ++i)
{
auto norm_factor = 1.0/(i+1);
#pragma omp parallel for
for(size_t j=0; j<nbComponents; ++j)
{
const auto mu = means[j];
const auto x = samples[i][j];
auto muNew = mu+(x-mu)*norm_factor;
stds[j] += (x-mu)*(x-muNew);
means[j] = muNew;
}
}
#pragma omp parallel for
for(size_t j=0; j<nbComponents; ++j)
{
stds[j] = std::sqrt(stds[j]/nbSamples);
}
return stds;
}
/** Create new samples by replicating input samples. We loop through
* the input samples and add them to the new data set until nbSamples
* are added. The elements of newSamples are removed before proceeding.
*/
void ReplicateSamples(const SampleVectorType& inSamples,
const size_t nbSamples,
SampleVectorType& newSamples)
{
newSamples.resize(nbSamples);
size_t imod{0};
#pragma omp parallel for
for(size_t i=0; i<nbSamples; ++i)
{
if (imod == inSamples.size()) imod = 0;
newSamples[i] = inSamples[imod++];
}
}
/** Create new samples by adding noise to existing samples. Gaussian
* noise is added to randomly selected samples. The standard deviation
* of the noise added to each component is the same as the one of the
* input variables divided by stdFactor (defaults to 10). The
* elements of newSamples are removed before proceeding.
*/
void JitterSamples(const SampleVectorType& inSamples,
const size_t nbSamples,
SampleVectorType& newSamples,
float stdFactor=10,
const int seed = std::time(nullptr))
{
newSamples.resize(nbSamples);
const auto nbComponents = inSamples[0].size();
std::random_device rd;
std::mt19937 gen(rd());
// The input samples are selected randomly with replacement
std::srand(seed);
// We use one gaussian distribution per component since they may
// have different stds
auto stds = EstimateStds(inSamples);
std::vector<std::normal_distribution<double>> gaussDis(nbComponents);
#pragma omp parallel for
for(size_t i=0; i<nbComponents; ++i)
gaussDis[i] = std::normal_distribution<double>{0.0, stds[i]/stdFactor};
for(size_t i=0; i<nbSamples; ++i)
{
newSamples[i] = inSamples[std::rand()%inSamples.size()];
#pragma omp parallel for
for(size_t j=0; j<nbComponents; ++j)
newSamples[i][j] += gaussDis[j](gen);
}
}
struct NeighborType
{
size_t index;
double distance;
};
struct NeighborSorter
{
constexpr bool operator ()(const NeighborType& a, const NeighborType& b) const
{
return b.distance > a.distance;
}
};
double ComputeSquareDistance(const SampleType& x, const SampleType& y)
{
assert(x.size()==y.size());
double dist{0};
for(size_t i=0; i<x.size(); ++i)
{
dist += (x[i]-y[i])*(x[i]-y[i]);
}
return dist/(x.size()*x.size());
}
using NNIndicesType = std::vector<NeighborType>;
using NNVectorType = std::vector<NNIndicesType>;
/** Returns the indices of the nearest neighbors for each input sample
*/
void FindKNNIndices(const SampleVectorType& inSamples,
const size_t nbNeighbors,
NNVectorType& nnVector)
{
const auto nbSamples = inSamples.size();
nnVector.resize(nbSamples);
#pragma omp parallel for
for(size_t sampleIdx=0; sampleIdx<nbSamples; ++sampleIdx)
{
NNIndicesType nns;
for(size_t neighborIdx=0; neighborIdx<nbSamples; ++neighborIdx)
{
if(sampleIdx!=neighborIdx)
nns.push_back({neighborIdx, ComputeSquareDistance(inSamples[sampleIdx],
inSamples[neighborIdx])});
}
std::partial_sort(nns.begin(), nns.begin()+nbNeighbors, nns.end(), NeighborSorter{});
nns.resize(nbNeighbors);
nnVector[sampleIdx] = std::move(nns);
}
}
/** Generate the new sample in the line linking s1 and s2
*/
SampleType SmoteCombine(const SampleType& s1, const SampleType& s2, double position)
{
auto result = s1;
for(size_t i=0; i<s1.size(); ++i)
result[i] = s1[i]+(s2[i]-s1[i])*position;
return result;
}
/** Create new samples using the SMOTE algorithm
Chawla, N. V., Bowyer, K. W., Hall, L. O., & Kegelmeyer, W. P., Smote:
synthetic minority over-sampling technique, Journal of artificial
intelligence research, 16(), 321–357 (2002).
http://dx.doi.org/10.1613/jair.953
*/
void Smote(const SampleVectorType& inSamples,
const size_t nbSamples,
SampleVectorType& newSamples,
const int nbNeighbors,
const int seed = std::time(nullptr))
{
newSamples.resize(nbSamples);
NNVectorType nnVector;
FindKNNIndices(inSamples, nbNeighbors, nnVector);
// The input samples are selected randomly with replacement
std::srand(seed);
#pragma omp parallel for
for(size_t i=0; i<nbSamples; ++i)
{
const auto sampleIdx = std::rand()%(inSamples.size());
const auto sample = inSamples[sampleIdx];
const auto neighborIdx = nnVector[sampleIdx][std::rand()%nbNeighbors].index;
const auto neighbor = inSamples[neighborIdx];
newSamples[i] = SmoteCombine(sample, neighbor, std::rand()/double{RAND_MAX});
}
}
}//end namespaces sampleAugmentation
}//end namespace otb
#endif
/*
* Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
*
* This file is part of Orfeo Toolbox
*
* https://www.orfeo-toolbox.org/
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef otbSampleAugmentationFilter_h
#define otbSampleAugmentationFilter_h
#include "itkProcessObject.h"
#include "otbOGRDataSourceWrapper.h"
#include "otbSampleAugmentation.h"
namespace otb
{
/**
* \class SampleAugmentationFilter
*
* \brief Filter to generate synthetic samples from existing ones
*
* This class generates synthetic samples from existing ones either by
* replication, jitter (adding gaussian noise to the features of
* existing samples) or SMOTE (linear combination of pairs
* neighbouring samples of the same class.
*
* \ingroup OTBSampling
*/
class ITK_EXPORT SampleAugmentationFilter :
public itk::ProcessObject
{
public:
/** typedef for the classes standards. */
typedef SampleAugmentationFilter Self;
typedef itk::ProcessObject Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Method for management of the object factory. */
itkNewMacro(Self);
/** Return the name of the class. */
itkTypeMacro(SampleAugmentationFilter, ProcessObject);
typedef ogr::DataSource OGRDataSourceType;
typedef typename OGRDataSourceType::Pointer OGRDataSourcePointerType;
typedef ogr::Layer OGRLayerType;
typedef itk::ProcessObject::DataObjectPointerArraySizeType DataObjectPointerArraySizeType;
using SampleType = sampleAugmentation::SampleType;
using SampleVectorType = sampleAugmentation::SampleVectorType;
enum class Strategy { Replicate, Jitter, Smote };
/** Set/Get the input OGRDataSource of this process object. */
using Superclass::SetInput;
virtual void SetInput(const OGRDataSourceType* ds);
const OGRDataSourceType* GetInput(unsigned int idx);
virtual void SetOutputSamples(ogr::DataSource* data);
/** Set the Field Name in which labels will be written. (default is "class")
* A field "ClassFieldName" of type integer is created in the output memory layer.
*/
itkSetMacro(ClassFieldName, std::string);
/**
* Return the Field name in which labels have been written.
*/
itkGetMacro(ClassFieldName, std::string);
itkSetMacro(Layer, size_t);
itkGetMacro(Layer, size_t);
itkSetMacro(Label, int);
itkGetMacro(Label, int);
void SetStrategy(Strategy s)
{
m_Strategy = s;
}
Strategy GetStrategy() const
{
return m_Strategy;
}
itkSetMacro(NumberOfSamples, int);
itkGetMacro(NumberOfSamples, int);
void SetExcludedFields(const std::vector<std::string>& ef)
{
m_ExcludedFields = ef;
}
std::vector<std::string> GetExcludedFields() const
{
return m_ExcludedFields;
}
itkSetMacro(StdFactor, double);
itkGetMacro(StdFactor, double);
itkSetMacro(SmoteNeighbors, size_t);
itkGetMacro(SmoteNeighbors, size_t);
itkSetMacro(Seed, int);
itkGetMacro(Seed, int);
/**
* Get the output \c ogr::DataSource which is a "memory" datasource.
*/
const OGRDataSourceType * GetOutput();
protected:
SampleAugmentationFilter();
~SampleAugmentationFilter() ITK_OVERRIDE {}
/** Generate Data method*/
void GenerateData() ITK_OVERRIDE;
/** DataObject pointer */
typedef itk::DataObject::Pointer DataObjectPointer;
DataObjectPointer MakeOutput(DataObjectPointerArraySizeType idx) ITK_OVERRIDE;
using Superclass::MakeOutput;
SampleVectorType ExtractSamples(const ogr::DataSource::Pointer vectors,
size_t layerName,
const std::string& classField, const int label,
const std::vector<std::string>& excludedFields = {});
void SampleToOGRFeatures(const ogr::DataSource::Pointer& vectors,
ogr::DataSource* output,
const SampleVectorType& samples,
const size_t layerName,
const std::string& classField, int label,
const std::vector<std::string>& excludedFields = {});
std::set<size_t> GetExcludedFieldsIds(const std::vector<std::string>& excludedFields,
const ogr::Layer& inputLayer);
bool IsNumericField(const ogr::Feature& feature, const int idx);
ogr::Feature SelectTemplateFeature(const ogr::Layer& inputLayer,
const std::string& classField, int label);
private:
SampleAugmentationFilter(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
std::string m_ClassFieldName;
size_t m_Layer;
int m_Label;
std::vector<std::string> m_ExcludedFields;
Strategy m_Strategy;
int m_NumberOfSamples;
double m_StdFactor;
size_t m_SmoteNeighbors;
int m_Seed;
};
} // end namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbSampleAugmentationFilter.txx"
#endif
#endif
/*
* Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
*
* This file is part of Orfeo Toolbox
*
* https://www.orfeo-toolbox.org/
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef otbSampleAugmentationFilter_txx
#define otbSampleAugmentationFilter_txx
#include "otbSampleAugmentationFilter.h"
#include "stdint.h" //needed for uintptr_t
namespace otb
{
SampleAugmentationFilter
::SampleAugmentationFilter() : m_ClassFieldName{"class"}, m_Layer{0}, m_Label{1},
m_Strategy{SampleAugmentationFilter::Strategy::Replicate},
m_NumberOfSamples{100}, m_StdFactor{10.0},
m_SmoteNeighbors{5}, m_Seed{0}
{
this->SetNumberOfRequiredInputs(1);
this->SetNumberOfRequiredOutputs(1);
this->ProcessObject::SetNthOutput(0, this->MakeOutput(0) );
}
typename SampleAugmentationFilter::DataObjectPointer
SampleAugmentationFilter
::MakeOutput(DataObjectPointerArraySizeType itkNotUsed(idx))
{
return static_cast< DataObjectPointer >(OGRDataSourceType::New().GetPointer());
}
const typename SampleAugmentationFilter::OGRDataSourceType *
SampleAugmentationFilter
::GetOutput()
{
return static_cast< const OGRDataSourceType *>(
this->ProcessObject::GetOutput(0));
}
void
SampleAugmentationFilter
::SetInput(const otb::ogr::DataSource* ds)
{
this->Superclass::SetNthInput(0, const_cast<otb::ogr::DataSource *>(ds));
}
const typename SampleAugmentationFilter::OGRDataSourceType *
SampleAugmentationFilter
::GetInput(unsigned int idx)
{
return static_cast<const OGRDataSourceType *>
(this->itk::ProcessObject::GetInput(idx));
}
void
SampleAugmentationFilter
::SetOutputSamples(ogr::DataSource* data)
{
this->SetNthOutput(0,data);
}
void
SampleAugmentationFilter
::GenerateData(void)
{
OGRDataSourcePointerType inputDS = dynamic_cast<OGRDataSourceType*>(this->itk::ProcessObject::GetInput(0));
auto outputDS = static_cast<ogr::DataSource *>(this->itk::ProcessObject::GetOutput(0));
auto inSamples = this->ExtractSamples(inputDS, m_Layer,
m_ClassFieldName,
m_Label,
m_ExcludedFields);
SampleVectorType newSamples;
switch (m_Strategy)
{
case Strategy::Replicate:
{
sampleAugmentation::ReplicateSamples(inSamples, m_NumberOfSamples,
newSamples);
}
break;
case Strategy::Jitter:
{
sampleAugmentation::JitterSamples(inSamples, m_NumberOfSamples,
newSamples,
m_StdFactor,
m_Seed);
}
break;
case Strategy::Smote:
{
sampleAugmentation::Smote(inSamples, m_NumberOfSamples,
newSamples,
m_SmoteNeighbors,
m_Seed);
}
break;
}
this->SampleToOGRFeatures(inputDS, outputDS, newSamples, m_Layer,
m_ClassFieldName,
m_Label,
m_ExcludedFields);
// this->SetNthOutput(0,outputDS);
}
/** Extracts the samples of a single class from the vector data to a
* vector and excludes some unwanted features.
*/
SampleAugmentationFilter::SampleVectorType
SampleAugmentationFilter
::ExtractSamples(const ogr::DataSource::Pointer vectors,
size_t layerName,
const std::string& classField, const int label,
const std::vector<std::string>& excludedFields)
{
ogr::Layer layer = vectors->GetLayer(layerName);
auto featureIt = layer.begin();
if(featureIt==layer.end())
{
itkExceptionMacro("Layer " << layerName << " of input sample file is empty.\n");
}
int cFieldIndex = (*featureIt).ogr().GetFieldIndex( classField.c_str() );
if( cFieldIndex < 0 )
{
itkExceptionMacro( "The field name for class label (" << classField
<< ") has not been found in the vector file " );
}
auto numberOfFields = (*featureIt).ogr().GetFieldCount();
auto excludedIds = this->GetExcludedFieldsIds(excludedFields, layer);
SampleVectorType samples;
int sampleCount{0};
while( featureIt!=layer.end() )
{
// Retrieve all the features for each field in the ogr layer.
if((*featureIt).ogr().GetFieldAsInteger(classField.c_str()) == label)
{
SampleType mv;
for(auto idx=0; idx<numberOfFields; ++idx)
{
if(excludedIds.find(idx) == excludedIds.cend() &&
this->IsNumericField((*featureIt), idx))
mv.push_back((*featureIt).ogr().GetFieldAsDouble(idx));
}
samples.push_back(mv);
++sampleCount;
}
++featureIt;
}
if(sampleCount==0)
{
itkExceptionMacro("Could not find any samples in layer " << layerName <<
" with label " << label << '\n');
}
return samples;
}
void
SampleAugmentationFilter
::SampleToOGRFeatures(const ogr::DataSource::Pointer& vectors,
ogr::DataSource* output,
const SampleAugmentationFilter::SampleVectorType& samples,
const size_t layerName,
const std::string& classField, int label,
const std::vector<std::string>& excludedFields)
{
auto inputLayer = vectors->GetLayer(layerName);
auto excludedIds = this->GetExcludedFieldsIds(excludedFields, inputLayer);
OGRSpatialReference * oSRS = nullptr;
if (inputLayer.GetSpatialRef())
{
oSRS = inputLayer.GetSpatialRef()->Clone();
}
OGRFeatureDefn &layerDefn = inputLayer.GetLayerDefn();
auto outputLayer = output->CreateLayer(inputLayer.GetName(), oSRS,
inputLayer.GetGeomType());
for (int k=0 ; k < layerDefn.GetFieldCount() ; k++)
{
OGRFieldDefn originDefn(layerDefn.GetFieldDefn(k));
ogr::FieldDefn fieldDefn(originDefn);
outputLayer.CreateField(fieldDefn);
}
auto featureCount = outputLayer.GetFeatureCount(false);
auto templateFeature = this->SelectTemplateFeature(inputLayer, classField, label);
for(const auto& sample : samples)
{
ogr::Feature dstFeature(outputLayer.GetLayerDefn());
dstFeature.SetFrom( templateFeature, TRUE );
dstFeature.SetFID(++featureCount);
auto sampleFieldCounter = 0;
for (int k=0 ; k < layerDefn.GetFieldCount() ; k++)
{
if(excludedIds.find(k) == excludedIds.cend() &&
this->IsNumericField(dstFeature, k))
{
dstFeature.ogr().SetField(k, sample[sampleFieldCounter++]);
}
}
outputLayer.CreateFeature( dstFeature );
}
}
std::set<size_t>
SampleAugmentationFilter
::GetExcludedFieldsIds(const std::vector<std::string>& excludedFields,
const ogr::Layer& inputLayer)
{
auto feature = *(inputLayer).begin();
std::set<size_t> excludedIds;
if( excludedFields.size() != 0)
{
for(const auto& fieldName : excludedFields)
{
auto idx = feature.ogr().GetFieldIndex( fieldName.c_str() );
excludedIds.insert(idx);