diff --git a/Modules/Applications/AppClassification/include/otbTrainImagesBase.txx b/Modules/Applications/AppClassification/include/otbTrainImagesBase.txx index 243086d2de0a4039b4ce8b04e42cc1168890bc41..deb34bfc735224c66b64ac8a4d3d15d0b98ed7b3 100644 --- a/Modules/Applications/AppClassification/include/otbTrainImagesBase.txx +++ b/Modules/Applications/AppClassification/include/otbTrainImagesBase.txx @@ -196,17 +196,20 @@ TrainImagesBase::SamplingRates TrainImagesBase::ComputeFinalMaximumSamplingRates // only fmt will be used for both training and validation samples // So we try to compute the total number of samples given input // parameters mt, mv and vtr. - if( mt > -1 && mv > -1 ) - { - rates.fmt = mt + mv; - } - if( mt > -1 && mv <= -1 && vtr < 0.99999 ) + if( mt > -1 && vtr < 0.99999 ) { rates.fmt = static_cast<long>(( double ) mt / ( 1.0 - vtr )); } - if( mt <= -1 && mv > -1 && vtr > 0.00001 ) + if( mv > -1 && vtr > 0.00001 ) { - rates.fmt = static_cast<long>(( double ) mv / vtr); + if( rates.fmt > -1 ) + { + rates.fmt = std::min( rates.fmt, static_cast<long>(( double ) mv / vtr) ); + } + else + { + rates.fmt = static_cast<long>(( double ) mv / vtr); + } } } } @@ -228,8 +231,10 @@ void TrainImagesBase::ComputeSamplingRate(const std::vector<std::string> &statis { if( maximum > -1 ) { + std::ostringstream oss; + oss << maximum; GetInternalApplication( "rates" )->SetParameterString( "strategy", "constant", false ); - GetInternalApplication( "rates" )->SetParameterInt( "strategy.constant.nb", static_cast<int>(maximum), false ); + GetInternalApplication( "rates" )->SetParameterString( "strategy.constant.nb", oss.str(), false ); } else {