From b0a4f4a57ba46d0b18d84acf21bd4cbdecf3ad94 Mon Sep 17 00:00:00 2001 From: Jordi Inglada <jordi.inglada@cesbio.cnes.fr> Date: Tue, 27 Feb 2018 17:35:35 +0100 Subject: [PATCH] ENH: add options for the different strategies --- .../app/otbSampleAugmentation.cxx | 75 +++++++++++++++++-- .../include/otbSampleAugmentation.h | 6 +- .../AppClassification/test/CMakeLists.txt | 29 ++++++- 3 files changed, 101 insertions(+), 9 deletions(-) diff --git a/Modules/Applications/AppClassification/app/otbSampleAugmentation.cxx b/Modules/Applications/AppClassification/app/otbSampleAugmentation.cxx index 36fc02f8bf..a12e3912b1 100644 --- a/Modules/Applications/AppClassification/app/otbSampleAugmentation.cxx +++ b/Modules/Applications/AppClassification/app/otbSampleAugmentation.cxx @@ -99,6 +99,42 @@ private: SetParameterDescription("exclude", "List of field names in the input vector data that will not be generated in the output file."); + AddParameter(ParameterType_Choice, "strategy", "Augmentation strategy"); + + AddChoice("strategy.replicate","Replicate input samples"); + SetParameterDescription("strategy.replicate","The new samples are generated " + "by replicating input samples which are randomly " + "selected with replacement."); + + AddChoice("strategy.jitter","Jitter input samples"); + SetParameterDescription("strategy.jitter","The new samples are generated " + "by adding gaussian noise to input samples which are " + "randomly selected with replacement."); + AddParameter(ParameterType_Float, "strategy.jitter.stdfactor", + "Factor for dividing the standard deviation of each feature"); + SetParameterDescription("strategy.jitter.stdfactor", + "The noise added to the input samples will have the " + "standard deviation of the input features divided " + "by the value of this parameter. "); + SetDefaultParameterFloat("strategy.jitter.stdfactor",10000); + + AddChoice("strategy.smote","Smote input samples"); + SetParameterDescription("strategy.smote","The new samples are generated " + "by using the SMOTE algorithm (http://dx.doi.org/10.1613/jair.953) " + "on input samples which are " + "randomly selected with replacement."); + AddParameter(ParameterType_Int, "strategy.smote.neighbors", + "Number of nearest neighbors."); + SetParameterDescription("strategy.smote.neighbors", + "Number of nearest neighbors to be used in the " + "SMOTE algorithm"); + SetDefaultParameterFloat("strategy.smote.neighbors", 5); + + AddParameter(ParameterType_Int, "seed", + "Random seed."); + SetParameterDescription("seed", + "Seed for the random number generator."); + MandatoryOff("seed"); // Doc example parameter settings SetDocExampleParameterValue("in", "samples.sqlite"); @@ -107,6 +143,8 @@ private: SetDocExampleParameterValue("samples", "100"); SetDocExampleParameterValue("out","augmented_samples.sqlite"); SetDocExampleParameterValue( "exclude", "OGC_FID name class originfid" ); + SetDocExampleParameterValue("strategy", "smote"); + SetDocExampleParameterValue("strategy.smote.neighbors", "5"); SetOfficialDocLink(); } @@ -186,12 +224,39 @@ private: fieldName, this->GetParameterInt("label"), excludedFeatures); + int seed = std::time(nullptr); + if(IsParameterEnabled("seed")) seed = this->GetParameterInt("seed"); SampleVectorType newSamples; - // sampleAugmentation::replicateSamples(inSamples, this->GetParameterInt("samples"), - // newSamples); - sampleAugmentation::smote(inSamples, this->GetParameterInt("samples"), - newSamples, - 4); + switch (this->GetParameterInt("strategy")) + { + // replicate + case 0: + { + otbAppLogINFO("Augmentation strategy : replicate"); + sampleAugmentation::replicateSamples(inSamples, this->GetParameterInt("samples"), + newSamples); + } + break; + // jitter + case 1: + { + otbAppLogINFO("Augmentation strategy : jitter"); + sampleAugmentation::jitterSamples(inSamples, this->GetParameterInt("samples"), + newSamples, + this->GetParameterFloat("strategy.jitter.stdfactor"), + seed); + } + break; + case 2: + { + otbAppLogINFO("Augmentation strategy : smote"); + sampleAugmentation::smote(inSamples, this->GetParameterInt("samples"), + newSamples, + this->GetParameterInt("strategy.smote.neighbors"), + seed); + } + break; + } writeSamples(vectors, output, newSamples, this->GetParameterInt("layer"), fieldName, this->GetParameterInt("label"), diff --git a/Modules/Applications/AppClassification/include/otbSampleAugmentation.h b/Modules/Applications/AppClassification/include/otbSampleAugmentation.h index 27b9f07847..43fd6657a0 100644 --- a/Modules/Applications/AppClassification/include/otbSampleAugmentation.h +++ b/Modules/Applications/AppClassification/include/otbSampleAugmentation.h @@ -58,7 +58,9 @@ SampleType estimateStds(SampleVectorType samples) } } for(auto std : stds) + { std = std::sqrt(std/nbSamples); + } return stds; } @@ -86,7 +88,7 @@ void replicateSamples(const SampleVectorType& inSamples, void jitterSamples(const SampleVectorType& inSamples, const size_t nbSamples, SampleVectorType& newSamples, - float stdFactor=1.0, + float stdFactor=10000, const int seed = std::time(nullptr)) { newSamples.resize(nbSamples); @@ -100,7 +102,7 @@ void jitterSamples(const SampleVectorType& inSamples, auto stds = estimateStds(inSamples); std::vector<std::normal_distribution<double>> gaussDis; for(size_t i=0; i<nbComponents; ++i) - gaussDis.emplace_back(std::normal_distribution<double>{0.0, stds[i]*stdFactor}); + gaussDis.emplace_back(std::normal_distribution<double>{0.0, stds[i]/stdFactor}); for(size_t i=0; i<nbSamples; ++i) { newSamples[i] = inSamples[std::rand()%nbSamples]; diff --git a/Modules/Applications/AppClassification/test/CMakeLists.txt b/Modules/Applications/AppClassification/test/CMakeLists.txt index c0ec37fea3..c134b6b355 100644 --- a/Modules/Applications/AppClassification/test/CMakeLists.txt +++ b/Modules/Applications/AppClassification/test/CMakeLists.txt @@ -974,12 +974,37 @@ otb_test_application( ) #------------ SampleAgmentation TESTS ---------------- -otb_test_application(NAME apTvClSampleAugmentation +otb_test_application(NAME apTvClSampleAugmentationReplicate APP SampleAugmentation OPTIONS -in ${INPUTDATA}/Classification/apTvClSampleExtractionOut.sqlite -field class -label 3 -samples 100 - -out ${TEMP}/apTvClSampleAugmentation.sqlite + -out ${TEMP}/apTvClSampleAugmentationReplicate.sqlite -exclude originfid + -strategy replicate + ) + +otb_test_application(NAME apTvClSampleAugmentationJitter + APP SampleAugmentation + OPTIONS -in ${INPUTDATA}/Classification/apTvClSampleExtractionOut.sqlite + -field class + -label 3 + -samples 100 + -out ${TEMP}/apTvClSampleAugmentationJitter.sqlite + -exclude originfid + -strategy jitter + -strategy.jitter.stdfactor 10000 + ) + +otb_test_application(NAME apTvClSampleAugmentationSmote + APP SampleAugmentation + OPTIONS -in ${INPUTDATA}/Classification/apTvClSampleExtractionOut.sqlite + -field class + -label 3 + -samples 100 + -out ${TEMP}/apTvClSampleAugmentationSmote.sqlite + -exclude originfid + -strategy smote + -strategy.smote.neighbors 5 ) -- GitLab