Skip to content
Snippets Groups Projects
Commit 825a0b3d authored by Jordi Inglada's avatar Jordi Inglada
Browse files

ENH: implementation of 3 augmentation algorithms

parent 3c437a4b
No related branches found
No related tags found
No related merge requests found
......@@ -21,6 +21,7 @@
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"
#include "otbOGRDataSourceWrapper.h"
#include "otbSampleAugmentation.h"
namespace otb
{
......@@ -43,8 +44,8 @@ public:
itkTypeMacro(SampleAugmentation, otb::Application);
/** Filters typedef */
using SampleType = std::vector<double>;
using SampleVectorType = std::vector<SampleType>;
using SampleType = sampleAugmentation::SampleType;
using SampleVectorType = sampleAugmentation::SampleVectorType;
private:
......@@ -120,7 +121,7 @@ private:
ogr::Layer layer = ogrDS->GetLayer(this->GetParameterInt("layer"));
ogr::Feature feature = layer.ogr().GetNextFeature();
ClearChoices( "exclude" );
ClearChoices("exclude");
ClearChoices("field");
for(int iField=0; iField<feature.ogr().GetFieldCount(); iField++)
......@@ -176,14 +177,21 @@ private:
std::vector<std::string> cFieldNames = GetChoiceNames("field");
std::string fieldName = cFieldNames[selectedCFieldIdx.front()];
std::vector<std::string> excludedFeatures = GetExcludedFeatures( GetChoiceNames( "exclude" ), GetSelectedItems( "exclude" ));
std::vector<std::string> excludedFeatures =
GetExcludedFeatures( GetChoiceNames( "exclude" ),
GetSelectedItems( "exclude" ));
for(const auto& ef : excludedFeatures)
std::cout << ef << " excluded\n";
otbAppLogINFO("Excluding feature " << ef << '\n');
auto inSamples = extractSamples(vectors, this->GetParameterInt("layer"),
fieldName,
this->GetParameterInt("label"),
excludedFeatures);
auto newSamples = augmentSamples(inSamples, this->GetParameterInt("samples"));
SampleVectorType newSamples;
// sampleAugmentation::replicateSamples(inSamples, this->GetParameterInt("samples"),
// newSamples);
sampleAugmentation::smote(inSamples, this->GetParameterInt("samples"),
newSamples,
4);
writeSamples(vectors, output, newSamples, this->GetParameterInt("layer"),
fieldName,
this->GetParameterInt("label"),
......@@ -194,8 +202,9 @@ private:
/** Extracts the samples of a single class from the vector data to a
* vector and excludes some unwanted features.
*/
SampleVectorType extractSamples(const ogr::DataSource::Pointer vectors, size_t layerName,
std::string classField, int label,
SampleVectorType extractSamples(const ogr::DataSource::Pointer vectors,
size_t layerName,
const std::string& classField, const int label,
const std::vector<std::string>& excludedFeatures = {})
{
ogr::Layer layer = vectors->GetLayer(layerName);
......@@ -212,15 +221,7 @@ private:
}
auto numberOfFields = feature.ogr().GetFieldCount();
std::set<size_t> excludedIds;
if( excludedFeatures.size() != 0)
{
for(const auto& fieldName : excludedFeatures)
{
auto idx = feature.ogr().GetFieldIndex( fieldName.c_str() );
excludedIds.insert(idx);
}
}
auto excludedIds = getExcludedFeaturesIds(excludedFeatures, layer);
otbAppLogINFO("The vector file contains " << numberOfFields << " fields.\n");
SampleVectorType samples;
bool goesOn{feature.addr() != 0};
......@@ -229,14 +230,12 @@ private:
// Retrieve all the features for each field in the ogr layer.
if(feature.ogr().GetFieldAsInteger(classField.c_str()) == label)
{
SampleType mv;
for(auto idx=0; idx<numberOfFields; ++idx)
{
OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(idx)->GetType();
if(excludedIds.find(idx) == excludedIds.cend() &&
(fieldType == OFTInteger
|| ogr::version_proxy::IsOFTInteger64( fieldType )
|| fieldType == OFTReal))
isNumericField(feature, idx))
mv.push_back(feature.ogr().GetFieldAsDouble(idx));
}
samples.push_back(mv);
......@@ -247,37 +246,16 @@ private:
return samples;
}
/** Builds nbSamples samples by cycling through inSamples in order
 * (sample i of the result is a copy of inSamples[i % inSamples.size()]).
 *
 * \param inSamples samples to replicate
 * \param nbSamples number of samples to generate
 * \return the generated samples; empty when inSamples is empty, since
 * there is nothing to replicate (the modulo below would otherwise
 * divide by zero).
 */
SampleVectorType augmentSamples(const SampleVectorType& inSamples,
                                const size_t nbSamples)
{
  SampleVectorType newSamples;
  if(inSamples.empty())
    {
    return newSamples;
    }
  newSamples.reserve(nbSamples);
  for(size_t i=0; i<nbSamples; ++i)
    {
    newSamples.push_back(inSamples[i%inSamples.size()]);
    }
  return newSamples;
}
void writeSamples(const ogr::DataSource::Pointer vectors,
ogr::DataSource::Pointer output,
void writeSamples(const ogr::DataSource::Pointer& vectors,
ogr::DataSource::Pointer& output,
const SampleVectorType& samples,
size_t layerName,
std::string classField, int label,
const size_t layerName,
const std::string& classField, int label,
const std::vector<std::string>& excludedFeatures = {})
{
auto inputLayer = vectors->GetLayer(layerName);
std::set<size_t> excludedIds;
if( excludedFeatures.size() != 0)
{
auto feature = *(inputLayer).begin();
for(const auto& fieldName : excludedFeatures)
{
auto idx = feature.ogr().GetFieldIndex( fieldName.c_str() );
excludedIds.insert(idx);
}
}
auto excludedIds = getExcludedFeaturesIds(excludedFeatures, inputLayer);
OGRSpatialReference * oSRS = nullptr;
if (inputLayer.GetSpatialRef())
......@@ -296,7 +274,7 @@ private:
}
auto featureCount = outputLayer.GetFeatureCount(false);
auto templateFeature = *(inputLayer).begin();
auto templateFeature = selectTemplateFeature(inputLayer, classField, label);
for(const auto& sample : samples)
{
ogr::Feature dstFeature(outputLayer.GetLayerDefn());
......@@ -305,27 +283,18 @@ private:
auto sampleFieldCounter = 0;
for (int k=0 ; k < layerDefn.GetFieldCount() ; k++)
{
OGRFieldType fieldType = dstFeature.ogr().GetFieldDefnRef(k)->GetType();
if(excludedIds.find(k) == excludedIds.cend() &&
(fieldType == OFTInteger
|| ogr::version_proxy::IsOFTInteger64( fieldType )
|| fieldType == OFTReal))
isNumericField(dstFeature, k))
{
dstFeature.ogr().SetField(k, sample[sampleFieldCounter++]);
}
}
// for (unsigned int i=0 ; i<nbBand ; ++i)
// {
// imgComp = static_cast<double>(itk::DefaultConvertPixelTraits<PixelType>::GetNthComponent(i,imgPixel));
// // Fill the output OGRDataSource
// dstFeature[m_SampleFieldNames[i]].SetValue(imgComp);
// }
outputLayer.CreateFeature( dstFeature );
}
}
std::vector<std::string> GetExcludedFeatures(std::vector <std::string> fieldNames,
std::vector<int> selectedIdx)
std::vector<std::string> GetExcludedFeatures(const std::vector<std::string>& fieldNames,
const std::vector<int>& selectedIdx)
{
auto nbFeatures = static_cast<unsigned int>(selectedIdx.size());
std::vector<std::string> result( nbFeatures );
......@@ -335,7 +304,45 @@ private:
}
return result;
}
};
/** Selects the feature used as a template when writing new samples.
 *
 * Returns the first feature of inputLayer whose integer value for
 * classField equals label. When no feature matches, the first feature
 * of the layer is returned instead.
 *
 * The layer is walked with an explicit end-of-range test so that the
 * search stops after the last feature and the fallback is reachable.
 *
 * \param inputLayer layer to search
 * \param classField name of the field holding the class label
 * \param label class label to look for
 */
ogr::Feature selectTemplateFeature(const ogr::Layer& inputLayer,
                                   const std::string& classField, int label)
{
  const auto lastFeature = inputLayer.end();
  for(auto featureIt = inputLayer.begin(); featureIt != lastFeature; ++featureIt)
    {
    if((*featureIt).ogr().GetFieldAsInteger(classField.c_str()) == label)
      {
      return *featureIt;
      }
    }
  // No feature carries the requested label: fall back to the first one.
  return *(inputLayer.begin());
}
/** Translates the names of the excluded fields into field indices.
 *
 * The indices are looked up on the first feature of inputLayer, which
 * is read whether or not excludedFeatures is empty.
 *
 * NOTE(review): GetFieldIndex presumably reports unknown names with a
 * negative value, which would be stored here converted to size_t --
 * confirm against the GDAL documentation.
 *
 * \param excludedFeatures names of the fields to exclude
 * \param inputLayer layer providing the field definitions
 * \return the set of indices of the excluded fields
 */
std::set<size_t> getExcludedFeaturesIds(const std::vector<std::string>& excludedFeatures,
                                        const ogr::Layer& inputLayer)
{
  auto firstFeature = *(inputLayer).begin();
  std::set<size_t> ids;
  // Iterating an empty list is a no-op, so no explicit size test is needed.
  for(const auto& excludedName : excludedFeatures)
    {
    auto fieldIndex = firstFeature.ogr().GetFieldIndex( excludedName.c_str() );
    ids.insert(fieldIndex);
    }
  return ids;
}
/** Tells whether field number idx of feature holds a numeric value,
 * i.e. whether its OGR type is OFTInteger, a 64 bit integer type (as
 * reported by ogr::version_proxy::IsOFTInteger64) or OFTReal.
 */
bool isNumericField(const ogr::Feature& feature,
                    const int idx)
{
  OGRFieldType kind = feature.ogr().GetFieldDefnRef(idx)->GetType();
  if(kind == OFTInteger)
    {
    return true;
    }
  if(ogr::version_proxy::IsOFTInteger64( kind ))
    {
    return true;
    }
  return kind == OFTReal;
}
};
} // end of namespace Wrapper
} // end of namespace otb
......
/*
* Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
*
* This file is part of Orfeo Toolbox
*
* https://www.orfeo-toolbox.org/
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef otbSampleAugmentation_h
#define otbSampleAugmentation_h
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <ctime>
#include <iostream>
#include <random>
#include <vector>
namespace otb
{

namespace sampleAugmentation
{
/** A sample is a vector of component values. */
using SampleType = std::vector<double>;
/** A set of samples. */
using SampleVectorType = std::vector<SampleType>;

/**
Estimate standard deviations of the components in one pass using
Welford's algorithm.

The number of components is taken from the first sample; every sample
is expected to have at least that many components. The returned values
are population standard deviations (the sum of squared deviations is
divided by the number of samples). An empty input yields an empty
result.
*/
inline SampleType estimateStds(const SampleVectorType& samples)
{
  if(samples.empty())
    {
    return {};
    }
  const auto nbSamples = samples.size();
  const auto nbComponents = samples[0].size();
  SampleType stds(nbComponents, 0.0);
  SampleType means(nbComponents, 0.0);
  for(size_t i=0; i<nbSamples; ++i)
    {
    assert(samples[i].size() >= nbComponents);
    for(size_t j=0; j<nbComponents; ++j)
      {
      const auto mu = means[j];
      const auto x = samples[i][j];
      const auto muNew = mu+(x-mu)/(i+1);
      // Accumulates the sum of squared deviations from the running mean
      stds[j] += (x-mu)*(x-muNew);
      means[j] = muNew;
      }
    }
  // Iterate by reference: the accumulators are turned into standard
  // deviations in place.
  for(auto& sigma : stds)
    {
    sigma = std::sqrt(sigma/nbSamples);
    }
  return stds;
}

/** Create new samples by replicating input samples. We loop through
 * the input samples and add them to the new data set until nbSamples
 * are added. The elements of newSamples are removed before proceeding.
 * If inSamples is empty, newSamples is left empty.
 */
inline void replicateSamples(const SampleVectorType& inSamples,
                             const size_t nbSamples,
                             SampleVectorType& newSamples)
{
  newSamples.clear();
  if(inSamples.empty())
    {
    return;
    }
  newSamples.reserve(nbSamples);
  for(size_t i=0; i<nbSamples; ++i)
    {
    newSamples.push_back(inSamples[i%inSamples.size()]);
    }
}

/** Create new samples by adding noise to existing samples. Gaussian
 * noise is added to randomly selected samples. The standard deviation
 * of the noise added to each component is the same as the one of the
 * input variables multiplied by stdFactor (defaults to 1). The
 * elements of newSamples are removed before proceeding.
 *
 * Both the selection of the input samples and the noise are drawn
 * from a single generator initialised with seed, so a given seed
 * always produces the same output. If inSamples is empty, newSamples
 * is left empty.
 */
inline void jitterSamples(const SampleVectorType& inSamples,
                          const size_t nbSamples,
                          SampleVectorType& newSamples,
                          float stdFactor=1.0,
                          const int seed = std::time(nullptr))
{
  newSamples.clear();
  if(inSamples.empty())
    {
    return;
    }
  newSamples.reserve(nbSamples);
  const auto nbComponents = inSamples[0].size();
  std::mt19937 gen(static_cast<std::mt19937::result_type>(seed));
  // The input samples are selected randomly with replacement. The
  // index is drawn among the input samples, not among the requested
  // number of output samples.
  std::uniform_int_distribution<size_t> pickSample(0, inSamples.size()-1);
  // The components may have different stds: a unit gaussian is drawn
  // and scaled per component, which also copes with constant
  // components (std of 0, no noise added).
  auto sigmas = estimateStds(inSamples);
  for(auto& sigma : sigmas)
    {
    sigma *= stdFactor;
    }
  std::normal_distribution<double> unitGauss{0.0, 1.0};
  for(size_t i=0; i<nbSamples; ++i)
    {
    auto sample = inSamples[pickSample(gen)];
    for(size_t j=0; j<nbComponents; ++j)
      {
      sample[j] += sigmas[j]*unitGauss(gen);
      }
    newSamples.push_back(sample);
    }
}

/** Index of a neighbor in the sample set and its distance to the
 * sample it is a neighbor of.
 */
struct NeighborType
{
  size_t index;
  double distance;
};

/** Orders neighbors by increasing distance. */
struct NeighborSorter
{
  constexpr bool operator ()(const NeighborType& a, const NeighborType& b) const
  {
    return b.distance > a.distance;
  }
};

/** Distance between two samples of the same size: the square root of
 * the sum of the squared component differences, each divided by the
 * squared number of components. Returns 0 for empty samples.
 */
inline double computeDistance(const SampleType& x, const SampleType& y)
{
  assert(x.size()==y.size());
  double dist{0};
  for(size_t i=0; i<x.size(); ++i)
    {
    dist += (x[i]-y[i])*(x[i]-y[i])/(x.size()*x.size());
    }
  return std::sqrt(dist);
}

using NNIndicesType = std::vector<NeighborType>;
using NNVectorType = std::vector<NNIndicesType>;

/** Returns the indices of the nearest neighbors for each input sample.
 *
 * nnVector[i] holds the neighbors of inSamples[i] sorted by increasing
 * distance; a sample is never its own neighbor. At most
 * inSamples.size()-1 neighbors exist, so nbNeighbors is clamped to
 * that value.
 */
inline void findKNNIndices(const SampleVectorType& inSamples,
                           const size_t nbNeighbors,
                           NNVectorType& nnVector)
{
  const auto nbSamples = inSamples.size();
  nnVector.clear();
  nnVector.resize(nbSamples);
  if(nbSamples == 0)
    {
    return;
    }
  // partial_sort requires its middle iterator to lie inside the range
  const auto nbKept = std::min(nbNeighbors, nbSamples-1);
  for(size_t sampleIdx=0; sampleIdx<nbSamples; ++sampleIdx)
    {
    NNIndicesType nns;
    nns.reserve(nbSamples-1);
    for(size_t neighborIdx=0; neighborIdx<nbSamples; ++neighborIdx)
      {
      if(sampleIdx!=neighborIdx)
        {
        nns.push_back({neighborIdx, computeDistance(inSamples[sampleIdx],
                                                    inSamples[neighborIdx])});
        }
      }
    const auto middle = nns.begin()+static_cast<NNIndicesType::difference_type>(nbKept);
    std::partial_sort(nns.begin(), middle, nns.end(), NeighborSorter{});
    nns.resize(nbKept);
    nnVector[sampleIdx].swap(nns);
    }
}

/** Generate the new sample in the line linking s1 and s2: position 0
 * gives s1 and position 1 gives s2. Both samples must have the same
 * size.
 */
inline SampleType smoteCombine(const SampleType& s1, const SampleType& s2,
                               double position)
{
  assert(s1.size()==s2.size());
  auto result = s1;
  for(size_t i=0; i<s1.size(); ++i)
    {
    result[i] = s1[i]+(s2[i]-s1[i])*position;
    }
  return result;
}

/** Create new samples using the SMOTE algorithm
Chawla, N. V., Bowyer, K. W., Hall, L. O., & Kegelmeyer, W. P., Smote:
synthetic minority over-sampling technique, Journal of artificial
intelligence research, 16(), 321–357 (2002).
http://dx.doi.org/10.1613/jair.953

Each new sample lies on the segment linking a randomly selected input
sample and one of its nbNeighbors nearest neighbors, also selected
randomly. The elements of newSamples are removed before proceeding.
The random draws come from a generator initialised with seed. If
inSamples is empty, newSamples is left empty; if no neighbor is
available (single input sample or nbNeighbors <= 0), the selected
samples are copied unchanged.
*/
inline void smote(const SampleVectorType& inSamples,
                  const size_t nbSamples,
                  SampleVectorType& newSamples,
                  const int nbNeighbors,
                  const int seed = std::time(nullptr))
{
  newSamples.clear();
  if(inSamples.empty())
    {
    return;
    }
  newSamples.reserve(nbSamples);
  NNVectorType nnVector;
  const size_t requestedNeighbors = nbNeighbors > 0 ? static_cast<size_t>(nbNeighbors) : 0;
  findKNNIndices(inSamples, requestedNeighbors, nnVector);
  // Every sample gets the same number of neighbors, which may be
  // smaller than the requested one.
  const auto nbKept = nnVector.front().size();
  std::mt19937 gen(static_cast<std::mt19937::result_type>(seed));
  // The input samples are selected randomly with replacement
  std::uniform_int_distribution<size_t> pickSample(0, inSamples.size()-1);
  std::uniform_int_distribution<size_t> pickNeighbor(0, nbKept > 0 ? nbKept-1 : 0);
  std::uniform_real_distribution<double> pickPosition(0.0, 1.0);
  for(size_t i=0; i<nbSamples; ++i)
    {
    const auto sampleIdx = pickSample(gen);
    const auto& sample = inSamples[sampleIdx];
    if(nbKept == 0)
      {
      newSamples.push_back(sample);
      continue;
      }
    const auto neighborIdx = nnVector[sampleIdx][pickNeighbor(gen)].index;
    const auto& neighbor = inSamples[neighborIdx];
    newSamples.push_back(smoteCombine(sample, neighbor, pickPosition(gen)));
    }
}

}
}
#endif
......@@ -981,5 +981,5 @@ otb_test_application(NAME apTvClSampleAugmentation
-label 3
-samples 100
-out ${TEMP}/apTvClSampleAugmentation.sqlite
# -excluded_features OGC_FID name class originfid
-exclude originfid
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment