diff --git a/Modules/Applications/AppClassification/app/otbSampleSelection.cxx b/Modules/Applications/AppClassification/app/otbSampleSelection.cxx index eba28624b32752750a8f748353593b8b5673c60c..bf208d8bc430ec573ea39f3b5bd40b9a3e765606 100644 --- a/Modules/Applications/AppClassification/app/otbSampleSelection.cxx +++ b/Modules/Applications/AppClassification/app/otbSampleSelection.cxx @@ -166,6 +166,15 @@ private: AddParameter(ParameterType_Int, "strategy.constant.nb", "Number of samples for all classes"); SetParameterDescription("strategy.constant.nb", "Number of samples for all classes"); + AddChoice("strategy.percent","Use a percentage of the samples available for each class"); + SetParameterDescription("strategy.percent","Use a percentage of the samples available for each class"); + + AddParameter(ParameterType_Float,"strategy.percent.p","The percentage to use"); + SetParameterDescription("strategy.percent.p","The percentage to use"); + SetMinimumParameterFloatValue("strategy.percent.p",0); + SetMaximumParameterFloatValue("strategy.percent.p",1); + SetDefaultParameterFloat("strategy.percent.p",0.5); + AddChoice("strategy.smallest","Set same number of samples for all classes, with the smallest class fully sampled"); SetParameterDescription("strategy.smallest","Set same number of samples for all classes, with the smallest class fully sampled"); @@ -234,15 +243,22 @@ private: m_RateCalculator->SetNbOfSamplesAllClasses(GetParameterInt("strategy.constant.nb")); } break; - // smallest class + // percent case 2: + { + otbAppLogINFO("Sampluing strategy: set a percentage of samples for each class."); + m_RateCalculator->SetPercentageOfSamples(this->GetParameterFloat("strategy.percent.p")); + } + break; + // smallest class + case 3: { otbAppLogINFO("Sampling strategy : fit the number of samples based on the smallest class"); m_RateCalculator->SetMinimumNbOfSamplesByClass(); } break; // all samples - case 3: + case 4: { otbAppLogINFO("Sampling strategy : take all samples"); m_RateCalculator->SetAllSamples(); diff --git a/Modules/Learning/Sampling/include/otbSamplingRateCalculator.h b/Modules/Learning/Sampling/include/otbSamplingRateCalculator.h index 8743a95553695f3f3e5bff34a1e5ab0378f30bbf..692e7f176686d024356e97b9b0ba454cc5cd399f 100644 --- a/Modules/Learning/Sampling/include/otbSamplingRateCalculator.h +++ b/Modules/Learning/Sampling/include/otbSamplingRateCalculator.h @@ -68,6 +68,9 @@ public: /** Method to set the same number of required samples in each class */ void SetNbOfSamplesAllClasses(unsigned long); + /** Method to set a percentage of samples for each class */ + void SetPercentageOfSamples(double percent); + /** Method to choose a sampling strategy based on the smallest class. * The number of samples in each class is set to this minimum size*/ void SetMinimumNbOfSamplesByClass(void); diff --git a/Modules/Learning/Sampling/src/otbSamplingRateCalculator.cxx b/Modules/Learning/Sampling/src/otbSamplingRateCalculator.cxx index 65a85589b9450b7d33d3c801d5a03ae3816369c5..050965bab7dc488e866b3397197b157afac5eaba 100644 --- a/Modules/Learning/Sampling/src/otbSamplingRateCalculator.cxx +++ b/Modules/Learning/Sampling/src/otbSamplingRateCalculator.cxx @@ -110,6 +110,19 @@ SamplingRateCalculator } } + +void SamplingRateCalculator +::SetPercentageOfSamples(double percent) +{ + MapRateType::iterator it = m_RatesByClass.begin(); + for (; it != m_RatesByClass.end() ; ++it) + { + it->second.Required = static_cast<unsigned long>(vcl_floor(0.5+percent * it->second.Tot)); + it->second.Rate = percent; + } + +} + void SamplingRateCalculator ::Write(std::string filename)