diff --git a/Modules/Filtering/Statistics/include/otbListSampleGenerator.txx b/Modules/Filtering/Statistics/include/otbListSampleGenerator.txx index cc980403131a91298cef2727853f6cd75d9ba124..85fcebc15ecc6fd7d35979d38ec15fd41d19048c 100644 --- a/Modules/Filtering/Statistics/include/otbListSampleGenerator.txx +++ b/Modules/Filtering/Statistics/include/otbListSampleGenerator.txx @@ -400,23 +400,30 @@ ListSampleGenerator itmap != m_ClassesSize.end(); ++itmap) { - m_ClassesProbTraining[itmap->first] = minSizeTraining / itmap->second; - m_ClassesProbValidation[itmap->first] = minSizeValidation / itmap->second; - if(!m_BoundByMin) + if (m_BoundByMin) + { + m_ClassesProbTraining[itmap->first] = minSizeTraining / itmap->second; + m_ClassesProbValidation[itmap->first] = minSizeValidation / itmap->second; + } + else { long int maxSizeT = (itmap->second)*(1.0 - m_ValidationTrainingProportion); long int maxSizeV = (itmap->second)*m_ValidationTrainingProportion; - maxSizeT = (m_MaxTrainingSize == -1)?maxSizeT:m_MaxTrainingSize; - maxSizeV = (m_MaxValidationSize == -1)?maxSizeV:m_MaxValidationSize; - - //not enough samples to respect the bounds - if(maxSizeT+maxSizeV > itmap->second) + + // Check if max sizes respect the maximum bounds + double correctionRatioTrain = 1.0; + if((m_MaxTrainingSize > -1) && (m_MaxTrainingSize < maxSizeT)) + { + correctionRatioTrain = (double)(m_MaxTrainingSize) / (double)(maxSizeT); + } + double correctionRatioValid = 1.0; + if((m_MaxValidationSize > -1) && (m_MaxValidationSize < maxSizeV)) { - maxSizeT = (itmap->second)*(1.0 - m_ValidationTrainingProportion); - maxSizeV = (itmap->second)*m_ValidationTrainingProportion; + correctionRatioValid = (double)(m_MaxValidationSize) / (double)(maxSizeV); } - m_ClassesProbTraining[itmap->first] = maxSizeT/(itmap->second); - m_ClassesProbValidation[itmap->first] = maxSizeV/(itmap->second); + double correctionRatio = std::min(correctionRatioTrain,correctionRatioValid); + m_ClassesProbTraining[itmap->first] = correctionRatio*(1.0 - m_ValidationTrainingProportion); + m_ClassesProbValidation[itmap->first] = correctionRatio*m_ValidationTrainingProportion; } } }