diff --git a/Modules/Applications/AppClassification/app/otbTrainImagesRegression.cxx b/Modules/Applications/AppClassification/app/otbTrainImagesRegression.cxx
index d17462be96ecd5b0a6d46e034716330915713beb..b88b73c90860c1f8d5a0f23cb916a23e6d54d3c8 100644
--- a/Modules/Applications/AppClassification/app/otbTrainImagesRegression.cxx
+++ b/Modules/Applications/AppClassification/app/otbTrainImagesRegression.cxx
@@ -160,6 +160,10 @@ private:
     SetParameterDescription("sample.nv", "Number of validation samples.");
     MandatoryOff("sample.nv");
 
+    AddParameter(ParameterType_Float, "sample.ratio", "Training and validation sample ratio");
+    SetParameterDescription("sample.ratio", "Ratio between training and validation samples.");
+    SetDefaultParameterFloat("sample.ratio", 0.5);
+
     ShareParameter( "rand", "select.rand" );
     
     ShareParameter( "ram", "polystat.ram" );
@@ -356,55 +360,65 @@ private:
     }
   }
   
-  void SplitTrainingAndValidationSamples(FloatVectorImageType *image, std::string sampleFileName,
-                                                          std::string sampleTrainFileName,
-                                                          std::string sampleValidFileName,
-                                                          std::string ratesTrainFileName)
+  void SplitTrainingAndValidationSamples(const std::string & inputSampleFilePrefix)
   {
-    // Split between training and validation
-    ogr::DataSource::Pointer source = ogr::DataSource::New( sampleFileName, ogr::DataSource::Modes::Read );
-    ogr::DataSource::Pointer destTrain = ogr::DataSource::New( sampleTrainFileName, ogr::DataSource::Modes::Overwrite );
-    ogr::DataSource::Pointer destValid = ogr::DataSource::New( sampleValidFileName, ogr::DataSource::Modes::Overwrite );
-    // read sampling rates from ratesTrainOutputs
-    SamplingRateCalculator::Pointer rateCalculator = SamplingRateCalculator::New();
-    rateCalculator->Read( ratesTrainFileName );
-    // Compute sampling rates for train and valid
-    const otb::SamplingRateCalculator::MapRateType &inputRates = rateCalculator->GetRatesByClass();
-    otb::SamplingRateCalculator::MapRateType trainRates;
-    otb::SamplingRateCalculator::MapRateType validRates;
-    otb::SamplingRateCalculator::TripletType tpt;
-    for( MapRateType::const_iterator it = inputRates.begin(); it != inputRates.end(); ++it )
+    auto ImageList =  GetParameterImageList("io.il");
+    const auto& inputSampleFiles = m_FileHandler[inputSampleFilePrefix + "samples"];
+    auto& trainSampleFiles = m_FileHandler["trainsamples"];
+    auto& validSampleFiles = m_FileHandler["validsamples"];
+    const auto&          rateFiles         = m_FileHandler[inputSampleFilePrefix + "rateFiles"];
+    
+    for (unsigned int i = 0; i < ImageList->Size(); i++)
+    {
+      trainSampleFiles.push_back(GetParameterString("io.out") + "_trainsamples" + std::to_string(i) + ".shp");
+      validSampleFiles.push_back(GetParameterString("io.out") + "_validsamples" + std::to_string(i) + ".shp");
+      
+      // Split between training and validation
+      auto image = ImageList->GetNthElement(i);
+      ogr::DataSource::Pointer source = ogr::DataSource::New( inputSampleFiles[i], ogr::DataSource::Modes::Read );
+      ogr::DataSource::Pointer destTrain = ogr::DataSource::New( trainSampleFiles[i], ogr::DataSource::Modes::Overwrite );
+      ogr::DataSource::Pointer destValid = ogr::DataSource::New( validSampleFiles[i], ogr::DataSource::Modes::Overwrite );
+      // read sampling rates from ratesTrainOutputs
+      SamplingRateCalculator::Pointer rateCalculator = SamplingRateCalculator::New();
+      rateCalculator->Read( rateFiles[i] );
+      // Compute sampling rates for train and valid
+      const otb::SamplingRateCalculator::MapRateType &inputRates = rateCalculator->GetRatesByClass();
+      otb::SamplingRateCalculator::MapRateType trainRates;
+      otb::SamplingRateCalculator::MapRateType validRates;
+      otb::SamplingRateCalculator::TripletType tpt;
+      for( MapRateType::const_iterator it = inputRates.begin(); it != inputRates.end(); ++it )
       {
-      double vtr = GetParameterFloat( "sample.vtr" );
-      unsigned long total = std::min( it->second.Required, it->second.Tot );
-      unsigned long neededValid = static_cast<unsigned long>(( double ) total * vtr );
-      unsigned long neededTrain = total - neededValid;
-      tpt.Tot = total;
-      tpt.Required = neededTrain;
-      tpt.Rate = ( 1.0 - vtr );
-      trainRates[it->first] = tpt;
-      tpt.Tot = neededValid;
-      tpt.Required = neededValid;
-      tpt.Rate = 1.0;
-      validRates[it->first] = tpt;
+        double vtr = GetParameterFloat( "sample.ratio" );
+        unsigned long total = std::min( it->second.Required, it->second.Tot );
+        unsigned long neededValid = static_cast<unsigned long>(( double ) total * vtr );
+        unsigned long neededTrain = total - neededValid;
+        tpt.Tot = total;
+        tpt.Required = neededTrain;
+        tpt.Rate = ( 1.0 - vtr );
+        trainRates[it->first] = tpt;
+        tpt.Tot = neededValid;
+        tpt.Required = neededValid;
+        tpt.Rate = 1.0;
+        validRates[it->first] = tpt;
       }
 
-    // Use an otb::OGRDataToSamplePositionFilter with 2 outputs
-    PeriodicSamplerType::SamplerParameterType param;
-    param.Offset = 0;
-    param.MaxJitter = 0;
-    PeriodicSamplerType::Pointer splitter = PeriodicSamplerType::New();
-    splitter->SetInput( image );
-    splitter->SetOGRData( source );
-    splitter->SetOutputPositionContainerAndRates( destTrain, trainRates, 0 );
-    splitter->SetOutputPositionContainerAndRates( destValid, validRates, 1 );
-    splitter->SetFieldName( this->GetParameterStringList( "sample.vfn" )[0] );
-    splitter->SetLayerIndex( 0 );
-    splitter->SetOriginFieldName( std::string( "" ) );
-    splitter->SetSamplerParameters( param );
-    splitter->GetStreamer()->SetAutomaticTiledStreaming( static_cast<unsigned int>(this->GetParameterInt( "ram" )) );
-    AddProcess( splitter->GetStreamer(), "Split samples between training and validation..." );
-    splitter->Update();
+      // Use an otb::OGRDataToSamplePositionFilter with 2 outputs
+      PeriodicSamplerType::SamplerParameterType param;
+      param.Offset = 0;
+      param.MaxJitter = 0;
+      PeriodicSamplerType::Pointer splitter = PeriodicSamplerType::New();
+      splitter->SetInput( image );
+      splitter->SetOGRData( source );
+      splitter->SetOutputPositionContainerAndRates( destTrain, trainRates, 0 );
+      splitter->SetOutputPositionContainerAndRates( destValid, validRates, 1 );
+      splitter->SetFieldName( m_ClassFieldName );
+      splitter->SetLayerIndex( 0 );
+      splitter->SetOriginFieldName( std::string( "" ) );
+      splitter->SetSamplerParameters( param );
+      splitter->GetStreamer()->SetAutomaticTiledStreaming( static_cast<unsigned int>(this->GetParameterInt( "ram" )) );
+      AddProcess( splitter->GetStreamer(), "Split samples between training and validation..." );
+      splitter->Update();
+    }
   }
 
   /** Configure and execute TrainVectorClassifier. Note that many parameters of TrainVectorClassifier
@@ -419,13 +433,13 @@ private:
     {
       featureNames.push_back(m_FeaturePrefix + std::to_string(i));
     }
-
+    
     trainVectorRegression->SetParameterStringList("io.vd", trainSampleFileNameList);
     trainVectorRegression->UpdateParameters();
     trainVectorRegression->SetParameterString("cfield", m_PredictionFieldName);
     trainVectorRegression->SetParameterStringList("feat", featureNames);
 
-    if (IsParameterEnabled("io.valid") && HasValue("io.valid"))
+    if ((IsParameterEnabled("io.valid") && HasValue("io.valid")) || GetParameterFloat( "sample.ratio" ) >0)
     {
       trainVectorRegression->SetParameterStringList("valid.vd", m_FileHandler["validsamples"]);
     }
@@ -484,7 +498,7 @@ private:
   void DoExecute() override
   {
     SamplingParameters trainParams;
-    trainParams.filePrefix      = "train";
+    trainParams.filePrefix      = "vd";
     
     if (HasValue("sample.nt"))
       trainParams.numberOfSamples = GetParameterInt("sample.nt");
@@ -502,17 +516,29 @@ private:
     
     PerformSampling(trainParams);
 
+    // User validation data
     if (IsParameterEnabled("io.valid") && HasValue("io.valid"))
     {
+      m_FileHandler["trainsamples"] = m_FileHandler[trainParams.filePrefix + "samples"];
+      
       SamplingParameters validParams;
       validParams.inputVectorList = GetParameterStringList("io.valid");
       validParams.filePrefix      = "valid";
       if (HasValue("sample.nv"))
         validParams.numberOfSamples = GetParameterInt("sample.nv");
-
+      
       PerformSampling(validParams);
     }
-
+    // Split train and validation data
+    else if (GetParameterFloat( "sample.ratio" ) >0)
+    {
+      SplitTrainingAndValidationSamples( trainParams.filePrefix);
+    }
+    else
+    {
+      m_FileHandler["trainsamples"] = m_FileHandler[trainParams.filePrefix + "samples"];
+    }
+    
     otbAppLogINFO("Sampling Done.");
 
     TrainModel();