diff --git a/Modules/Adapters/GdalAdapters/src/otbOGRDataSourceWrapper.cxx b/Modules/Adapters/GdalAdapters/src/otbOGRDataSourceWrapper.cxx index d5cfec4ef94e8efe1c21c105921974236dc7893a..b8b5d4e94096611d5c8d42378157eea8332fec3b 100644 --- a/Modules/Adapters/GdalAdapters/src/otbOGRDataSourceWrapper.cxx +++ b/Modules/Adapters/GdalAdapters/src/otbOGRDataSourceWrapper.cxx @@ -86,6 +86,7 @@ const ExtensionDriverAssociation k_ExtensionDriverMap[] = {".GPX", "GPX"}, {".SQLITE", "SQLite"}, {".KML", "KML"}, + {".CSV", "CSV"}, }; /**\ingroup GeometryInternals * \brief Returns the OGR driver name associated to a filename. diff --git a/Modules/Applications/AppClassification/app/CMakeLists.txt b/Modules/Applications/AppClassification/app/CMakeLists.txt index 3e1dbd85f5cd2ca409ddb9ef51389edf5e544976..d34c3842d0a6d9e39011b98df48af97bfb2ceb87 100644 --- a/Modules/Applications/AppClassification/app/CMakeLists.txt +++ b/Modules/Applications/AppClassification/app/CMakeLists.txt @@ -125,5 +125,10 @@ otb_create_application( SOURCES otbVectorClassifier.cxx LINK_LIBRARIES ${${otb-module}_LIBRARIES}) +otb_create_application( + NAME SampleAugmentation + SOURCES otbSampleAugmentation.cxx + LINK_LIBRARIES ${${otb-module}_LIBRARIES}) + # Mantis-1427 : temporary fix add_dependencies(${otb-module}-all otbapp_ImageEnvelope) diff --git a/Modules/Applications/AppClassification/app/otbSampleAugmentation.cxx b/Modules/Applications/AppClassification/app/otbSampleAugmentation.cxx new file mode 100644 index 0000000000000000000000000000000000000000..6bb7382856e619cce13f501acea90530cfd18f37 --- /dev/null +++ b/Modules/Applications/AppClassification/app/otbSampleAugmentation.cxx @@ -0,0 +1,270 @@ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "otbWrapperApplication.h" +#include "otbWrapperApplicationFactory.h" +#include "otbOGRDataSourceWrapper.h" +#include "otbSampleAugmentationFilter.h" + +namespace otb +{ +namespace Wrapper +{ + + +class SampleAugmentation : public Application +{ +public: + /** Standard class typedefs. */ + typedef SampleAugmentation Self; + typedef Application Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + + /** Standard macro */ + itkNewMacro(Self); + + itkTypeMacro(SampleAugmentation, otb::Application); + + /** Filters typedef */ + using FilterType = otb::SampleAugmentationFilter; + using SampleType = FilterType::SampleType; + using SampleVectorType = FilterType::SampleVectorType; + +private: + SampleAugmentation() {} + + void DoInit() + { + SetName("SampleAugmentation"); + SetDescription("Generates synthetic samples from a sample data file."); + + // Documentation + SetDocName("Sample Augmentation"); + SetDocLongDescription("The application takes a sample data file as " + "generated by the SampleExtraction application and " + "generates synthetic samples to increase the number of " + "available samples."); + SetDocLimitations("None"); + SetDocAuthors("OTB-Team"); + SetDocSeeAlso(" "); + + AddDocTag(Tags::Learning); + + AddParameter(ParameterType_InputFilename, "in", "Input samples"); + SetParameterDescription("in","Vector data file containing samples (OGR format)"); + + AddParameter(ParameterType_OutputFilename, "out", "Output samples"); + SetParameterDescription("out","Output vector 
data file storing new samples" + "(OGR format)."); + + AddParameter(ParameterType_ListView, "field", "Field Name"); + SetParameterDescription("field","Name of the field carrying the class name in the input vectors."); + SetListViewSingleSelectionMode("field",true); + + AddParameter(ParameterType_Int, "layer", "Layer Index"); + SetParameterDescription("layer", "Layer index to read in the input vector file."); + MandatoryOff("layer"); + SetDefaultParameterInt("layer",0); + + AddParameter(ParameterType_Int, "label", "Label of the class to be augmented"); + SetParameterDescription("label", "Label of the class of the input file for which " + "new samples will be generated."); + SetDefaultParameterInt("label",1); + + AddParameter(ParameterType_Int, "samples", "Number of generated samples"); + SetParameterDescription("samples", "Number of synthetic samples that will " + "be generated."); + SetDefaultParameterInt("samples",100); + + AddParameter(ParameterType_ListView, "exclude", "Field names for excluded features."); + SetParameterDescription("exclude", + "List of field names in the input vector data that will not be generated in the output file."); + + AddParameter(ParameterType_Choice, "strategy", "Augmentation strategy"); + + AddChoice("strategy.replicate","Replicate input samples"); + SetParameterDescription("strategy.replicate","The new samples are generated " + "by replicating input samples which are randomly " + "selected with replacement."); + + AddChoice("strategy.jitter","Jitter input samples"); + SetParameterDescription("strategy.jitter","The new samples are generated " + "by adding gaussian noise to input samples which are " + "randomly selected with replacement."); + AddParameter(ParameterType_Float, "strategy.jitter.stdfactor", + "Factor for dividing the standard deviation of each feature"); + SetParameterDescription("strategy.jitter.stdfactor", + "The noise added to the input samples will have the " + "standard deviation of the input features divided " + 
"by the value of this parameter. "); + SetDefaultParameterFloat("strategy.jitter.stdfactor",10); + + AddChoice("strategy.smote","Smote input samples"); + SetParameterDescription("strategy.smote","The new samples are generated " + "by using the SMOTE algorithm (http://dx.doi.org/10.1613/jair.953) " + "on input samples which are " + "randomly selected with replacement."); + AddParameter(ParameterType_Int, "strategy.smote.neighbors", + "Number of nearest neighbors."); + SetParameterDescription("strategy.smote.neighbors", + "Number of nearest neighbors to be used in the " + "SMOTE algorithm"); + SetDefaultParameterFloat("strategy.smote.neighbors", 5); + + AddRANDParameter("seed"); + MandatoryOff("seed"); + + // Doc example parameter settings + SetDocExampleParameterValue("in", "samples.sqlite"); + SetDocExampleParameterValue("field", "class"); + SetDocExampleParameterValue("label", "3"); + SetDocExampleParameterValue("samples", "100"); + SetDocExampleParameterValue("out","augmented_samples.sqlite"); + SetDocExampleParameterValue( "exclude", "OGC_FID name class originfid" ); + SetDocExampleParameterValue("strategy", "smote"); + SetDocExampleParameterValue("strategy.smote.neighbors", "5"); + + SetOfficialDocLink(); + } + + void DoUpdateParameters() + { + if ( HasValue("in") ) + { + std::string vectorFile = GetParameterString("in"); + ogr::DataSource::Pointer ogrDS = + ogr::DataSource::New(vectorFile, ogr::DataSource::Modes::Read); + ogr::Layer layer = ogrDS->GetLayer(this->GetParameterInt("layer")); + ogr::Feature feature = layer.ogr().GetNextFeature(); + + ClearChoices("exclude"); + ClearChoices("field"); + + for(int iField=0; iField<feature.ogr().GetFieldCount(); iField++) + { + std::string key, item = feature.ogr().GetFieldDefnRef(iField)->GetNameRef(); + key = item; + std::string::iterator end = std::remove_if(key.begin(),key.end(), + [](auto c){return !std::isalnum(c);}); + std::transform(key.begin(), end, key.begin(), tolower); + + OGRFieldType fieldType = 
feature.ogr().GetFieldDefnRef(iField)->GetType(); + + if(fieldType == OFTString || fieldType == OFTInteger || ogr::version_proxy::IsOFTInteger64(fieldType)) + { + std::string tmpKey="field."+key.substr(0, end - key.begin()); + AddChoice(tmpKey,item); + } + if( fieldType == OFTInteger || ogr::version_proxy::IsOFTInteger64( fieldType ) || fieldType == OFTReal ) + { + std::string tmpKey = "exclude." + key.substr( 0, static_cast<unsigned long>( end - key.begin() ) ); + AddChoice( tmpKey, item ); + } + } + } + } + + void DoExecute() + { + ogr::DataSource::Pointer vectors; + ogr::DataSource::Pointer output; + vectors = ogr::DataSource::New(this->GetParameterString("in")); + output = ogr::DataSource::New(this->GetParameterString("out"), + ogr::DataSource::Modes::Overwrite); + + // Retrieve the field name + std::vector<int> selectedCFieldIdx = GetSelectedItems("field"); + + if(selectedCFieldIdx.empty()) + { + otbAppLogFATAL(<<"No field has been selected for data labelling!"); + } + + std::vector<std::string> cFieldNames = GetChoiceNames("field"); + std::string fieldName = cFieldNames[selectedCFieldIdx.front()]; + + std::vector<std::string> excludedFields = + GetExcludedFields( GetChoiceNames( "exclude" ), + GetSelectedItems( "exclude" )); + for(const auto& ef : excludedFields) + otbAppLogINFO("Excluding feature " << ef << '\n'); + + int seed = std::time(nullptr); + if(IsParameterEnabled("seed")) seed = this->GetParameterInt("seed"); + + + FilterType::Pointer filter = FilterType::New(); + filter->SetInput(vectors); + filter->SetLayer(this->GetParameterInt("layer")); + filter->SetNumberOfSamples(this->GetParameterInt("samples")); + filter->SetOutputSamples(output); + filter->SetClassFieldName(fieldName); + filter->SetLabel(this->GetParameterInt("label")); + filter->SetExcludedFields(excludedFields); + filter->SetSeed(seed); + switch (this->GetParameterInt("strategy")) + { + // replicate + case 0: + { + otbAppLogINFO("Augmentation strategy : replicate"); + 
filter->SetStrategy(FilterType::Strategy::Replicate); + } + break; + // jitter + case 1: + { + otbAppLogINFO("Augmentation strategy : jitter"); + filter->SetStrategy(FilterType::Strategy::Jitter); + filter->SetStdFactor(this->GetParameterFloat("strategy.jitter.stdfactor")); + } + break; + case 2: + { + otbAppLogINFO("Augmentation strategy : smote"); + filter->SetStrategy(FilterType::Strategy::Smote); + filter->SetSmoteNeighbors(this->GetParameterInt("strategy.smote.neighbors")); + } + break; + } + filter->Update(); + output->SyncToDisk(); + } + + + std::vector<std::string> GetExcludedFields(const std::vector<std::string>& fieldNames, + const std::vector<int>& selectedIdx) + { + auto nbFeatures = static_cast<unsigned int>(selectedIdx.size()); + std::vector<std::string> result( nbFeatures ); + for( unsigned int i = 0; i < nbFeatures; ++i ) + { + result[i] = fieldNames[selectedIdx[i]]; + } + return result; + } + +}; + +} // end of namespace Wrapper +} // end of namespace otb + +OTB_APPLICATION_EXPORT(otb::Wrapper::SampleAugmentation) diff --git a/Modules/Applications/AppClassification/test/CMakeLists.txt b/Modules/Applications/AppClassification/test/CMakeLists.txt index fae1474266db959a03a6accd01b14b43f00f4ec0..e848e08ed8fca6c2f1bd840c6adc16816b716bc6 100644 --- a/Modules/Applications/AppClassification/test/CMakeLists.txt +++ b/Modules/Applications/AppClassification/test/CMakeLists.txt @@ -972,3 +972,39 @@ otb_test_application( ${OTBAPP_BASELINE_FILES}/apTvClMultiImageSamplingRate_out_3.csv ${TEMP}/apTvClMultiImageSamplingRate_out_3.csv ) + +#------------ SampleAgmentation TESTS ---------------- +otb_test_application(NAME apTvClSampleAugmentationReplicate + APP SampleAugmentation + OPTIONS -in ${INPUTDATA}/Classification/apTvClSampleExtractionOut.sqlite + -field class + -label 3 + -samples 100 + -out ${TEMP}/apTvClSampleAugmentationReplicate.sqlite + -exclude originfid + -strategy replicate + ) + +otb_test_application(NAME apTvClSampleAugmentationJitter + APP 
SampleAugmentation + OPTIONS -in ${INPUTDATA}/Classification/apTvClSampleExtractionOut.sqlite + -field class + -label 3 + -samples 100 + -out ${TEMP}/apTvClSampleAugmentationJitter.sqlite + -exclude originfid + -strategy jitter + -strategy.jitter.stdfactor 10 + ) + +otb_test_application(NAME apTvClSampleAugmentationSmote + APP SampleAugmentation + OPTIONS -in ${INPUTDATA}/Classification/apTvClSampleExtractionOut.sqlite + -field class + -label 3 + -samples 100 + -out ${TEMP}/apTvClSampleAugmentationSmote.sqlite + -exclude originfid + -strategy smote + -strategy.smote.neighbors 5 + ) diff --git a/Modules/Learning/Sampling/include/otbSampleAugmentation.h b/Modules/Learning/Sampling/include/otbSampleAugmentation.h new file mode 100644 index 0000000000000000000000000000000000000000..84e67ffdd413940e2646ff2f35d6423241ab5f3f --- /dev/null +++ b/Modules/Learning/Sampling/include/otbSampleAugmentation.h @@ -0,0 +1,217 @@ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
#ifndef otbSampleAugmentation_h
#define otbSampleAugmentation_h

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <random>
#include <vector>

namespace otb
{

namespace sampleAugmentation
{
using SampleType = std::vector<double>;
using SampleVectorType = std::vector<SampleType>;

// All the functions of this header are declared inline: they are
// defined in a header, so a non-inline definition would be emitted by
// every translation unit including it.

/**
Estimate standard deviations of the components in one pass using
Welford's algorithm. The number of components is taken from the first
sample. An empty sample set yields an empty result.
*/
inline SampleType EstimateStds(const SampleVectorType& samples)
{
  if(samples.empty())
    {
    return {};
    }
  const auto nbSamples = samples.size();
  const auto nbComponents = samples[0].size();
  SampleType stds(nbComponents, 0.0);
  SampleType means(nbComponents, 0.0);
  for(size_t i=0; i<nbSamples; ++i)
    {
    const auto norm_factor = 1.0/(i+1);
    // The components are independent from each other, so this loop can
    // be shared among threads.
#pragma omp parallel for
    for(size_t j=0; j<nbComponents; ++j)
      {
      const auto mu = means[j];
      const auto x = samples[i][j];
      const auto muNew = mu+(x-mu)*norm_factor;
      stds[j] += (x-mu)*(x-muNew);
      means[j] = muNew;
      }
    }
#pragma omp parallel for
  for(size_t j=0; j<nbComponents; ++j)
    {
    stds[j] = std::sqrt(stds[j]/nbSamples);
    }
  return stds;
}

/** Create new samples by replicating input samples. We loop through
* the input samples and add them to the new data set until nbSamples
* are added. The previous content of newSamples is overwritten. If
* inSamples is empty, newSamples is emptied.
*/
inline void ReplicateSamples(const SampleVectorType& inSamples,
                             const size_t nbSamples,
                             SampleVectorType& newSamples)
{
  if(inSamples.empty())
    {
    newSamples.clear();
    return;
    }
  newSamples.resize(nbSamples);
  const auto nbInputs = inSamples.size();
  // The source index is computed from the loop index instead of a
  // shared running counter, so that the iterations are independent and
  // the result is the same whether or not the loop runs in parallel.
#pragma omp parallel for
  for(size_t i=0; i<nbSamples; ++i)
    {
    newSamples[i] = inSamples[i%nbInputs];
    }
}

/** Create new samples by adding noise to existing samples. Gaussian
* noise is added to randomly selected samples. The standard deviation
* of the noise added to each component is the same as the one of the
* input variables divided by stdFactor (defaults to 10). Components
* for which this standard deviation is not a finite positive number
* (constant features, stdFactor equal to 0) are copied without noise.
* The previous content of newSamples is overwritten. If inSamples is
* empty, newSamples is emptied.
*
* The same seed gives the same output: it seeds both the selection of
* the input samples (through std::srand) and the noise generator.
*/
inline void JitterSamples(const SampleVectorType& inSamples,
                          const size_t nbSamples,
                          SampleVectorType& newSamples,
                          float stdFactor=10,
                          const int seed = std::time(nullptr))
{
  if(inSamples.empty())
    {
    newSamples.clear();
    return;
    }
  newSamples.resize(nbSamples);
  const auto nbComponents = inSamples[0].size();
  std::mt19937 gen(static_cast<std::mt19937::result_type>(seed));
  // The input samples are selected randomly with replacement
  std::srand(static_cast<unsigned int>(seed));
  // We use one gaussian distribution per component since they may
  // have different stds
  const auto stds = EstimateStds(inSamples);
  std::vector<double> sigmas(nbComponents, 0.0);
  std::vector<std::normal_distribution<double>> gaussDis(nbComponents);
  for(size_t j=0; j<nbComponents; ++j)
    {
    const double sigma = stds[j]/stdFactor;
    // std::normal_distribution requires a strictly positive stddev
    if(sigma > 0.0 && std::isfinite(sigma))
      {
      sigmas[j] = sigma;
      gaussDis[j] = std::normal_distribution<double>{0.0, sigma};
      }
    }

  // Both loops are sequential on purpose: the generators are shared
  // state and the output has to depend on the seed only.
  for(size_t i=0; i<nbSamples; ++i)
    {
    newSamples[i] = inSamples[std::rand()%inSamples.size()];
    for(size_t j=0; j<nbComponents; ++j)
      {
      if(sigmas[j] > 0.0)
        {
        newSamples[i][j] += gaussDis[j](gen);
        }
      }
    }
}


/** Index of a neighbor sample and its distance to the reference sample */
struct NeighborType
{
  size_t index;
  double distance;
};

/** Strict weak ordering of the neighbors by increasing distance */
struct NeighborSorter
{
  constexpr bool operator ()(const NeighborType& a, const NeighborType& b) const
  {
    return b.distance > a.distance;
  }
};

/** Sum of the squared differences between the components of x and y,
* divided by the square of the number of components. x and y must have
* the same size.
*/
inline double ComputeSquareDistance(const SampleType& x, const SampleType& y)
{
  assert(x.size()==y.size());
  double dist{0};
  for(size_t i=0; i<x.size(); ++i)
    {
    dist += (x[i]-y[i])*(x[i]-y[i]);
    }
  return dist/(x.size()*x.size());
}

using NNIndicesType = std::vector<NeighborType>;
using NNVectorType = std::vector<NNIndicesType>;
/** Returns the indices of the nearest neighbors for each input sample.
* nnVector[i] holds the neighbors of inSamples[i] sorted by increasing
* distance. It contains nbNeighbors elements, or all the other samples
* when fewer than nbNeighbors are available.
*/
inline void FindKNNIndices(const SampleVectorType& inSamples,
                           const size_t nbNeighbors,
                           NNVectorType& nnVector)
{
  const auto nbSamples = inSamples.size();
  nnVector.resize(nbSamples);
  // Each iteration only writes its own element of nnVector
#pragma omp parallel for
  for(size_t sampleIdx=0; sampleIdx<nbSamples; ++sampleIdx)
    {
    NNIndicesType nns;
    nns.reserve(nbSamples-1);
    for(size_t neighborIdx=0; neighborIdx<nbSamples; ++neighborIdx)
      {
      if(sampleIdx!=neighborIdx)
        nns.push_back({neighborIdx, ComputeSquareDistance(inSamples[sampleIdx],
                                                          inSamples[neighborIdx])});
      }
    // Never ask for more neighbors than there are candidates: the
    // middle iterator of partial_sort has to stay inside the range.
    const auto kept = std::min(nbNeighbors, nns.size());
    std::partial_sort(nns.begin(),
                      nns.begin()+static_cast<NNIndicesType::difference_type>(kept),
                      nns.end(), NeighborSorter{});
    nns.resize(kept);
    nnVector[sampleIdx] = std::move(nns);
    }
}

/** Generate the new sample in the line linking s1 and s2: position 0
* gives s1 and position 1 gives s2. s2 must have at least as many
* components as s1.
*/
inline SampleType SmoteCombine(const SampleType& s1, const SampleType& s2, double position)
{
  auto result = s1;
  for(size_t i=0; i<s1.size(); ++i)
    result[i] = s1[i]+(s2[i]-s1[i])*position;
  return result;
}

/** Create new samples using the SMOTE algorithm
Chawla, N. V., Bowyer, K. W., Hall, L. O., & Kegelmeyer, W. P., Smote:
synthetic minority over-sampling technique, Journal of artificial
intelligence research, 16(), 321–357 (2002).
http://dx.doi.org/10.1613/jair.953

Each new sample is drawn on the segment linking a randomly selected
input sample to one of its nearest neighbors (nbNeighbors of them, or
fewer if there are not enough input samples). An input sample without
any neighbor (single input sample, nbNeighbors <= 0) is copied as is.
The previous content of newSamples is overwritten. If inSamples is
empty, newSamples is emptied. The random draws use std::rand seeded
with seed.
*/
inline void Smote(const SampleVectorType& inSamples,
                  const size_t nbSamples,
                  SampleVectorType& newSamples,
                  const int nbNeighbors,
                  const int seed = std::time(nullptr))
{
  if(inSamples.empty())
    {
    newSamples.clear();
    return;
    }
  newSamples.resize(nbSamples);
  NNVectorType nnVector;
  FindKNNIndices(inSamples,
                 nbNeighbors > 0 ? static_cast<size_t>(nbNeighbors) : size_t{0},
                 nnVector);
  // The input samples are selected randomly with replacement
  std::srand(static_cast<unsigned int>(seed));
  // This loop is sequential on purpose: std::rand uses a shared state
  // and the output has to depend on the seed only.
  for(size_t i=0; i<nbSamples; ++i)
    {
    const auto sampleIdx = std::rand()%(inSamples.size());
    const auto& sample = inSamples[sampleIdx];
    const auto& candidates = nnVector[sampleIdx];
    if(candidates.empty())
      {
      newSamples[i] = sample;
      continue;
      }
    const auto neighborIdx = candidates[std::rand()%candidates.size()].index;
    const auto& neighbor = inSamples[neighborIdx];
    newSamples[i] = SmoteCombine(sample, neighbor,
                                 std::rand()/static_cast<double>(RAND_MAX));
    }
}

}//end namespaces sampleAugmentation
}//end namespace otb

#endif
Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef otbSampleAugmentationFilter_h +#define otbSampleAugmentationFilter_h + +#include "itkProcessObject.h" +#include "otbOGRDataSourceWrapper.h" +#include "otbSampleAugmentation.h" + +namespace otb +{ + + +/** + * \class SampleAugmentationFilter + * + * \brief Filter to generate synthetic samples from existing ones + * + * This class generates synthetic samples from existing ones either by + * replication, jitter (adding gaussian noise to the features of + * existing samples) or SMOTE (linear combination of pairs + * neighbouring samples of the same class. + * + * \ingroup OTBSampling + */ + +class ITK_EXPORT SampleAugmentationFilter : + public itk::ProcessObject +{ +public: + + /** typedef for the classes standards. */ + typedef SampleAugmentationFilter Self; + typedef itk::ProcessObject Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + + /** Method for management of the object factory. */ + itkNewMacro(Self); + + /** Return the name of the class. 
*/ + itkTypeMacro(SampleAugmentationFilter, ProcessObject); + + typedef ogr::DataSource OGRDataSourceType; + typedef typename OGRDataSourceType::Pointer OGRDataSourcePointerType; + typedef ogr::Layer OGRLayerType; + + typedef itk::ProcessObject::DataObjectPointerArraySizeType DataObjectPointerArraySizeType; + + using SampleType = sampleAugmentation::SampleType; + using SampleVectorType = sampleAugmentation::SampleVectorType; + + enum class Strategy { Replicate, Jitter, Smote }; + + /** Set/Get the input OGRDataSource of this process object. */ + using Superclass::SetInput; + virtual void SetInput(const OGRDataSourceType* ds); + const OGRDataSourceType* GetInput(unsigned int idx); + + virtual void SetOutputSamples(ogr::DataSource* data); + + /** Set the Field Name in which labels will be written. (default is "class") + * A field "ClassFieldName" of type integer is created in the output memory layer. + */ + itkSetMacro(ClassFieldName, std::string); + /** + * Return the Field name in which labels have been written. + */ + itkGetMacro(ClassFieldName, std::string); + + + itkSetMacro(Layer, size_t); + itkGetMacro(Layer, size_t); + itkSetMacro(Label, int); + itkGetMacro(Label, int); + void SetStrategy(Strategy s) + { + m_Strategy = s; + } + Strategy GetStrategy() const + { + return m_Strategy; + } + itkSetMacro(NumberOfSamples, int); + itkGetMacro(NumberOfSamples, int); + void SetExcludedFields(const std::vector<std::string>& ef) + { + m_ExcludedFields = ef; + } + std::vector<std::string> GetExcludedFields() const + { + return m_ExcludedFields; + } + itkSetMacro(StdFactor, double); + itkGetMacro(StdFactor, double); + itkSetMacro(SmoteNeighbors, size_t); + itkGetMacro(SmoteNeighbors, size_t); + itkSetMacro(Seed, int); + itkGetMacro(Seed, int); +/** + * Get the output \c ogr::DataSource which is a "memory" datasource. 
+ */ + const OGRDataSourceType * GetOutput(); + +protected: + SampleAugmentationFilter(); + ~SampleAugmentationFilter() ITK_OVERRIDE {} + + /** Generate Data method*/ + void GenerateData() ITK_OVERRIDE; + + /** DataObject pointer */ + typedef itk::DataObject::Pointer DataObjectPointer; + + DataObjectPointer MakeOutput(DataObjectPointerArraySizeType idx) ITK_OVERRIDE; + using Superclass::MakeOutput; + + + SampleVectorType ExtractSamples(const ogr::DataSource::Pointer vectors, + size_t layerName, + const std::string& classField, const int label, + const std::vector<std::string>& excludedFields = {}); + + void SampleToOGRFeatures(const ogr::DataSource::Pointer& vectors, + ogr::DataSource* output, + const SampleVectorType& samples, + const size_t layerName, + const std::string& classField, int label, + const std::vector<std::string>& excludedFields = {}); + + std::set<size_t> GetExcludedFieldsIds(const std::vector<std::string>& excludedFields, + const ogr::Layer& inputLayer); + bool IsNumericField(const ogr::Feature& feature, const int idx); + + ogr::Feature SelectTemplateFeature(const ogr::Layer& inputLayer, + const std::string& classField, int label); +private: + SampleAugmentationFilter(const Self &); //purposely not implemented + void operator =(const Self&); //purposely not implemented + + std::string m_ClassFieldName; + size_t m_Layer; + int m_Label; + std::vector<std::string> m_ExcludedFields; + Strategy m_Strategy; + int m_NumberOfSamples; + double m_StdFactor; + size_t m_SmoteNeighbors; + int m_Seed; + +}; + + +} // end namespace otb + +#ifndef OTB_MANUAL_INSTANTIATION +#include "otbSampleAugmentationFilter.txx" +#endif + +#endif diff --git a/Modules/Learning/Sampling/include/otbSampleAugmentationFilter.txx b/Modules/Learning/Sampling/include/otbSampleAugmentationFilter.txx new file mode 100644 index 0000000000000000000000000000000000000000..8976c1c55b82bbd4663d0ea00df49ddd4aedc4f2 --- /dev/null +++ 
b/Modules/Learning/Sampling/include/otbSampleAugmentationFilter.txx @@ -0,0 +1,273 @@ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef otbSampleAugmentationFilter_txx +#define otbSampleAugmentationFilter_txx + +#include "otbSampleAugmentationFilter.h" +#include "stdint.h" //needed for uintptr_t + +namespace otb +{ + +SampleAugmentationFilter +::SampleAugmentationFilter() : m_ClassFieldName{"class"}, m_Layer{0}, m_Label{1}, + m_Strategy{SampleAugmentationFilter::Strategy::Replicate}, + m_NumberOfSamples{100}, m_StdFactor{10.0}, + m_SmoteNeighbors{5}, m_Seed{0} +{ + this->SetNumberOfRequiredInputs(1); + this->SetNumberOfRequiredOutputs(1); + this->ProcessObject::SetNthOutput(0, this->MakeOutput(0) ); +} + + +typename SampleAugmentationFilter::DataObjectPointer +SampleAugmentationFilter +::MakeOutput(DataObjectPointerArraySizeType itkNotUsed(idx)) +{ + return static_cast< DataObjectPointer >(OGRDataSourceType::New().GetPointer()); +} + +const typename SampleAugmentationFilter::OGRDataSourceType * +SampleAugmentationFilter +::GetOutput() +{ + return static_cast< const OGRDataSourceType *>( + this->ProcessObject::GetOutput(0)); +} + +void +SampleAugmentationFilter +::SetInput(const otb::ogr::DataSource* ds) +{ + this->Superclass::SetNthInput(0, const_cast<otb::ogr::DataSource *>(ds)); +} + 
+const typename SampleAugmentationFilter::OGRDataSourceType * +SampleAugmentationFilter +::GetInput(unsigned int idx) +{ + return static_cast<const OGRDataSourceType *> + (this->itk::ProcessObject::GetInput(idx)); +} + +void +SampleAugmentationFilter +::SetOutputSamples(ogr::DataSource* data) +{ + this->SetNthOutput(0,data); +} + + +void +SampleAugmentationFilter +::GenerateData(void) +{ + + OGRDataSourcePointerType inputDS = dynamic_cast<OGRDataSourceType*>(this->itk::ProcessObject::GetInput(0)); + auto outputDS = static_cast<ogr::DataSource *>(this->itk::ProcessObject::GetOutput(0)); + auto inSamples = this->ExtractSamples(inputDS, m_Layer, + m_ClassFieldName, + m_Label, + m_ExcludedFields); + SampleVectorType newSamples; + switch (m_Strategy) + { + case Strategy::Replicate: + { + sampleAugmentation::ReplicateSamples(inSamples, m_NumberOfSamples, + newSamples); + } + break; + case Strategy::Jitter: + { + sampleAugmentation::JitterSamples(inSamples, m_NumberOfSamples, + newSamples, + m_StdFactor, + m_Seed); + } + break; + case Strategy::Smote: + { + sampleAugmentation::Smote(inSamples, m_NumberOfSamples, + newSamples, + m_SmoteNeighbors, + m_Seed); + } + break; + } + this->SampleToOGRFeatures(inputDS, outputDS, newSamples, m_Layer, + m_ClassFieldName, + m_Label, + m_ExcludedFields); + + + // this->SetNthOutput(0,outputDS); +} + +/** Extracts the samples of a single class from the vector data to a +* vector and excludes some unwanted features. 
+*/ +SampleAugmentationFilter::SampleVectorType +SampleAugmentationFilter +::ExtractSamples(const ogr::DataSource::Pointer vectors, + size_t layerName, + const std::string& classField, const int label, + const std::vector<std::string>& excludedFields) +{ + ogr::Layer layer = vectors->GetLayer(layerName); + auto featureIt = layer.begin(); + if(featureIt==layer.end()) + { + itkExceptionMacro("Layer " << layerName << " of input sample file is empty.\n"); + } + int cFieldIndex = (*featureIt).ogr().GetFieldIndex( classField.c_str() ); + if( cFieldIndex < 0 ) + { + itkExceptionMacro( "The field name for class label (" << classField + << ") has not been found in the vector file " ); + } + + auto numberOfFields = (*featureIt).ogr().GetFieldCount(); + auto excludedIds = this->GetExcludedFieldsIds(excludedFields, layer); + SampleVectorType samples; + int sampleCount{0}; + while( featureIt!=layer.end() ) + { + // Retrieve all the features for each field in the ogr layer. + if((*featureIt).ogr().GetFieldAsInteger(classField.c_str()) == label) + { + + SampleType mv; + for(auto idx=0; idx<numberOfFields; ++idx) + { + if(excludedIds.find(idx) == excludedIds.cend() && + this->IsNumericField((*featureIt), idx)) + mv.push_back((*featureIt).ogr().GetFieldAsDouble(idx)); + } + samples.push_back(mv); + ++sampleCount; + } + ++featureIt; + } + if(sampleCount==0) + { + itkExceptionMacro("Could not find any samples in layer " << layerName << + " with label " << label << '\n'); + } + return samples; +} + +void +SampleAugmentationFilter +::SampleToOGRFeatures(const ogr::DataSource::Pointer& vectors, + ogr::DataSource* output, + const SampleAugmentationFilter::SampleVectorType& samples, + const size_t layerName, + const std::string& classField, int label, + const std::vector<std::string>& excludedFields) +{ + + auto inputLayer = vectors->GetLayer(layerName); + auto excludedIds = this->GetExcludedFieldsIds(excludedFields, inputLayer); + + OGRSpatialReference * oSRS = nullptr; + if 
(inputLayer.GetSpatialRef()) + { + oSRS = inputLayer.GetSpatialRef()->Clone(); + } + OGRFeatureDefn &layerDefn = inputLayer.GetLayerDefn(); + + auto outputLayer = output->CreateLayer(inputLayer.GetName(), oSRS, + inputLayer.GetGeomType()); + for (int k=0 ; k < layerDefn.GetFieldCount() ; k++) + { + OGRFieldDefn originDefn(layerDefn.GetFieldDefn(k)); + ogr::FieldDefn fieldDefn(originDefn); + outputLayer.CreateField(fieldDefn); + } + + auto featureCount = outputLayer.GetFeatureCount(false); + auto templateFeature = this->SelectTemplateFeature(inputLayer, classField, label); + for(const auto& sample : samples) + { + ogr::Feature dstFeature(outputLayer.GetLayerDefn()); + dstFeature.SetFrom( templateFeature, TRUE ); + dstFeature.SetFID(++featureCount); + auto sampleFieldCounter = 0; + for (int k=0 ; k < layerDefn.GetFieldCount() ; k++) + { + if(excludedIds.find(k) == excludedIds.cend() && + this->IsNumericField(dstFeature, k)) + { + dstFeature.ogr().SetField(k, sample[sampleFieldCounter++]); + } + } + outputLayer.CreateFeature( dstFeature ); + } +} + +std::set<size_t> +SampleAugmentationFilter +::GetExcludedFieldsIds(const std::vector<std::string>& excludedFields, + const ogr::Layer& inputLayer) +{ + auto feature = *(inputLayer).begin(); + std::set<size_t> excludedIds; + if( excludedFields.size() != 0) + { + for(const auto& fieldName : excludedFields) + { + auto idx = feature.ogr().GetFieldIndex( fieldName.c_str() ); + excludedIds.insert(idx); + } + } + return excludedIds; +} + +bool +SampleAugmentationFilter +::IsNumericField(const ogr::Feature& feature, + const int idx) +{ + OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(idx)->GetType(); + return (fieldType == OFTInteger + || ogr::version_proxy::IsOFTInteger64( fieldType ) + || fieldType == OFTReal); +} + +ogr::Feature +SampleAugmentationFilter +::SelectTemplateFeature(const ogr::Layer& inputLayer, + const std::string& classField, int label) +{ + auto wh = std::find_if(inputLayer.begin(), inputLayer.end(), + 
[&](auto& featureIt) + { + return featureIt.ogr().GetFieldAsInteger(classField.c_str()) == + label; + }); + return (wh == inputLayer.end()) ? *(inputLayer.begin()) : *wh; + +} +} // end namespace otb + +#endif