From 5cfcbd5784847351a937d19165053a007be6acae Mon Sep 17 00:00:00 2001
From: Ludovic Hussonnois <ludovic.hussonnois@c-s.fr>
Date: Fri, 17 Feb 2017 14:34:23 +0100
Subject: [PATCH] ENH: Add TrainImagesBase to be either Supervised or
 Unsupervised.

Create a common class with template to perform Supervised classification
(same as otbTrainImagesClassifier) or Unsupervised.
---
 .../include/otbTrainImagesBase.h              | 693 ++++++++++++++++++
 1 file changed, 693 insertions(+)
 create mode 100644 Modules/Applications/AppClassification/include/otbTrainImagesBase.h

diff --git a/Modules/Applications/AppClassification/include/otbTrainImagesBase.h b/Modules/Applications/AppClassification/include/otbTrainImagesBase.h
new file mode 100644
index 0000000000..5b6aca1460
--- /dev/null
+++ b/Modules/Applications/AppClassification/include/otbTrainImagesBase.h
@@ -0,0 +1,693 @@
+/*=========================================================================
+ Program:   ORFEO Toolbox
+ Language:  C++
+ Date:      $Date$
+ Version:   $Revision$
+
+
+ Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
+ See OTBCopyright.txt for details.
+
+
+ This software is distributed WITHOUT ANY WARRANTY; without even
+ the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+ PURPOSE.  See the above copyright notices for more information.
+
+ =========================================================================*/
+#ifndef otbTrainImagesBase_h
+#define otbTrainImagesBase_h
+
+#include "otbWrapperCompositeApplication.h"
+#include "otbWrapperApplicationFactory.h"
+
+#include "otbOGRDataToSamplePositionFilter.h"
+#include "otbSamplingRateCalculator.h"
+
+namespace otb
+{
+namespace Wrapper
+{
+
+template<bool IsSupervised = true>
+class TrainImagesBase : public CompositeApplication
+{
+public:
+  /** Standard class typedefs. */
+  typedef TrainImagesBase Self;
+  typedef CompositeApplication Superclass;
+  typedef itk::SmartPointer<Self> Pointer;
+  typedef itk::SmartPointer<const Self> ConstPointer;
+
+  /** Standard macro */
+  itkTypeMacro( TrainImagesBase, Superclass )
+
+  /** filters typedefs*/
+  typedef otb::OGRDataToSamplePositionFilter<FloatVectorImageType, UInt8ImageType, otb::PeriodicSampler> PeriodicSamplerType;
+
+  typedef otb::SamplingRateCalculator::MapRateType MapRateType;
+
+protected:
+
+private:
+  struct SamplingRates;
+
+  class TrainFileNamesHandler;
+
+  void InitSampling()
+  {
+    AddApplication( "PolygonClassStatistics", "polystat", "Polygon analysis" );
+    AddApplication( "MultiImageSamplingRate", "rates", "Sampling rates" );
+    AddApplication( "SampleSelection", "select", "Sample selection" );
+    AddApplication( "SampleExtraction", "extraction", "Sample extraction" );
+
+    // Sampling settings
+    AddParameter( ParameterType_Group, "sample", "Training and validation samples parameters" );
+    SetParameterDescription( "sample",
+                             "This group of parameters allows you to set training and validation sample lists parameters." );
+    AddParameter( ParameterType_Int, "sample.mt", "Maximum training sample size per class" );
+    SetDefaultParameterInt( "sample.mt", 1000 );
+    SetParameterDescription( "sample.mt", "Maximum size per class (in pixels) of "
+            "the training sample list (default = 1000) (no limit = -1). If equal to -1,"
+            " then the maximal size of the available training sample list per class "
+            "will be equal to the surface area of the smallest class multiplied by the"
+            " training sample ratio." );
+    AddParameter( ParameterType_Int, "sample.mv", "Maximum validation sample size per class" );
+    SetDefaultParameterInt( "sample.mv", 1000 );
+    SetParameterDescription( "sample.mv", "Maximum size per class (in pixels) of "
+            "the validation sample list (default = 1000) (no limit = -1). If equal to -1,"
+            " then the maximal size of the available validation sample list per class "
+            "will be equal to the surface area of the smallest class multiplied by the "
+            "validation sample ratio." );
+    AddParameter( ParameterType_Int, "sample.bm", "Bound sample number by minimum" );
+    SetDefaultParameterInt( "sample.bm", 1 );
+    SetParameterDescription( "sample.bm", "Bound the number of samples for each "
+            "class by the number of available samples by the smaller class. Proportions "
+            "between training and validation are respected. Default is true (=1)." );
+    AddParameter( ParameterType_Float, "sample.vtr", "Training and validation sample ratio" );
+    SetParameterDescription( "sample.vtr", "Ratio between training and validation samples (0.0 = all training, 1.0 = "
+            "all validation) (default = 0.5)." );
+    SetParameterFloat( "sample.vtr", 0.5, false );
+    SetMaximumParameterFloatValue( "sample.vtr", 1.0 );
+    SetMinimumParameterFloatValue( "sample.vtr", 0.0 );
+
+    ShareSamplingParameters();
+    ConnectSamplingParameters();
+  }
+
+  void ShareSamplingParameters()
+  {
+    // hide sampling parameters
+    //ShareParameter("sample.strategy","rates.strategy");
+    //ShareParameter("sample.mim","rates.mim");
+    ShareParameter( "ram", "polystat.ram" );
+    ShareParameter( "elev", "polystat.elev" );
+    ShareParameter( "sample.vfn", "polystat.field" );
+  }
+
+  void ConnectSamplingParameters()
+  {
+    Connect( "extraction.field", "polystat.field" );
+    Connect( "extraction.layer", "polystat.layer" );
+
+    Connect( "select.ram", "polystat.ram" );
+    Connect( "extraction.ram", "polystat.ram" );
+
+    Connect( "select.field", "polystat.field" );
+    Connect( "select.layer", "polystat.layer" );
+    Connect( "select.elev", "polystat.elev" );
+
+    Connect( "extraction.in", "select.in" );
+    Connect( "extraction.vec", "select.out" );
+  }
+
+  void InitClassification(bool supervised)
+  {
+    if( supervised )
+      AddApplication( "TrainVectorClassifier", "training", "Model training" );
+    else
+      AddApplication( "TrainVectorClustering", "training", "Model training" );
+
+    AddParameter( ParameterType_InputVectorDataList, "io.valid", "Validation Vector Data List" );
+    SetParameterDescription( "io.valid", "A list of vector data to select the training samples." );
+    MandatoryOff( "io.valid" );
+
+    ShareClassificationParams( supervised );
+    ConnectClassificationParams();
+  };
+
+  void ShareClassificationParams(bool supervised)
+  {
+    ShareParameter( "io.imstat", "training.io.stats" );
+    ShareParameter( "io.out", "training.io.out" );
+
+    ShareParameter( "classifier", "training.classifier" );
+    ShareParameter( "rand", "training.rand" );
+
+    if( supervised )
+      ShareParameter( "io.confmatout", "training.io.confmatout" );
+  }
+
+  void ConnectClassificationParams()
+  {
+    Connect( "training.cfield", "polystat.field" );
+    Connect( "select.rand", "training.rand" );
+  }
+
+  void DoUnsupervisedInit()
+  {
+    SetName( "TrainImagesClustering" );
+    SetDescription( "Train a classifier from multiple pairs of images and training vector data." );
+
+    // Documentation
+    SetDocName( "Train a classifier from multiple images" );
+    SetDocLongDescription( "TODO" );
+    SetDocLimitations( "None" );
+    SetDocAuthors( "OTB-Team" );
+    SetDocSeeAlso( "OpenCV documentation for machine learning http://docs.opencv.org/modules/ml/doc/ml.html " );
+
+    AddDocTag( Tags::Learning );
+
+    ClearApplications();
+    InitSampling();
+    InitClassification( IsSupervised );
+
+    // Hide sampling parameters if sample.vnf is not provided
+    MandatoryOn( "sample.mv" );
+    MandatoryOn( "sample.mt" );
+    MandatoryOn( "sample.vtr" );
+
+
+    // Doc example parameter settings
+    SetDocExampleParameterValue( "io.il", "QB_1_ortho.tif" );
+    SetDocExampleParameterValue( "io.vd", "VectorData_QB1.shp" );
+    SetDocExampleParameterValue( "io.imstat", "EstimateImageStatisticsQB1.xml" );
+    SetDocExampleParameterValue( "sample.mv", "100" );
+    SetDocExampleParameterValue( "sample.mt", "100" );
+    SetDocExampleParameterValue( "sample.vtr", "0.5" );
+    SetDocExampleParameterValue( "sample.vfn", "Class" );
+    SetDocExampleParameterValue( "classifier", "sharkkm" );
+    SetDocExampleParameterValue( "classifier.sharkkm.k", "2" );
+    SetDocExampleParameterValue( "io.out", "sharkKMModelQB1.txt" );
+  }
+
+  void DoSupervisedInit()
+  {
+    SetName( "TrainImagesClassifier" );
+    SetDescription( "Train a classifier from multiple pairs of images and training vector data." );
+
+    // Documentation
+    SetDocName( "Train a classifier from multiple images" );
+    SetDocLongDescription(
+            "This application performs a classifier training from multiple pairs of input images and training vector data. "
+                    "Samples are composed of pixel values in each band optionally centered and reduced using an XML statistics file produced by "
+                    "the ComputeImagesStatistics application.\n The training vector data must contain polygons with a positive integer field "
+                    "representing the class label. The name of this field can be set using the \"Class label field\" parameter. Training and validation "
+                    "sample lists are built such that each class is equally represented in both lists. One parameter allows controlling the ratio "
+                    "between the number of samples in training and validation sets. Two parameters allow managing the size of the training and "
+                    "validation sets per class and per image.\n Several classifier parameters can be set depending on the chosen classifier. In the "
+                    "validation process, the confusion matrix is organized the following way: rows = reference labels, columns = produced labels. "
+                    "In the header of the optional confusion matrix output file, the validation (reference) and predicted (produced) class labels"
+                    " are ordered according to the rows/columns of the confusion matrix.\n This application is based on LibSVM and OpenCV Machine Learning "
+                    "(2.3.1 and later)." );
+    SetDocLimitations( "None" );
+    SetDocAuthors( "OTB-Team" );
+    SetDocSeeAlso( "OpenCV documentation for machine learning http://docs.opencv.org/modules/ml/doc/ml.html " );
+
+    AddDocTag( Tags::Learning );
+
+    // Perform initialization
+    ClearApplications();
+    InitSampling();
+    InitClassification( IsSupervised );
+
+    // Doc example parameter settings
+    SetDocExampleParameterValue( "io.il", "QB_1_ortho.tif" );
+    SetDocExampleParameterValue( "io.vd", "VectorData_QB1.shp" );
+    SetDocExampleParameterValue( "io.imstat", "EstimateImageStatisticsQB1.xml" );
+    SetDocExampleParameterValue( "sample.mv", "100" );
+    SetDocExampleParameterValue( "sample.mt", "100" );
+    SetDocExampleParameterValue( "sample.vtr", "0.5" );
+    SetDocExampleParameterValue( "sample.vfn", "Class" );
+    SetDocExampleParameterValue( "classifier", "sharkkm" );
+    SetDocExampleParameterValue( "classifier.sharkkm.k", "2" );
+    SetDocExampleParameterValue( "io.out", "sharkKMModelQB1.txt" );
+  }
+
+  void DoInit() ITK_OVERRIDE
+  {
+    //Group IO
+    AddParameter( ParameterType_Group, "io", "Input and output data" );
+    SetParameterDescription( "io", "This group of parameters allows setting input and output data." );
+
+    AddParameter( ParameterType_InputImageList, "io.il", "Input Image List" );
+    SetParameterDescription( "io.il", "A list of input images." );
+    AddParameter( ParameterType_InputVectorDataList, "io.vd", "Input Vector Data List" );
+    SetParameterDescription( "io.vd", "A list of vector data to select the training samples." );
+
+    AddParameter( ParameterType_Empty, "cleanup", "Temporary files cleaning" );
+    EnableParameter( "cleanup" );
+    SetParameterDescription( "cleanup",
+                             "If activated, the application will try to clean all temporary files it created" );
+
+    if( IsSupervised )
+      DoSupervisedInit();
+    else
+      DoUnsupervisedInit();
+
+    MandatoryOff( "cleanup" );
+  }
+
+  void DoUpdateParameters() ITK_OVERRIDE
+  {
+    if( HasValue( "io.vd" ) )
+      {
+        std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" );
+        GetInternalApplication( "polystat" )->SetParameterString( "vec", vectorFileList[0], false );
+        UpdateInternalParameters( "polystat" );
+      }
+  }
+
+  void DoExecute() ITK_OVERRIDE
+  {
+    FloatVectorImageListType *imageList = GetParameterImageList( "io.il" );
+    std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" );
+    unsigned long nbInputs = imageList->Size();
+
+    if( nbInputs > vectorFileList.size() )
+      {
+      otbAppLogFATAL( "Missing input vector data files to match number of images (" << nbInputs << ")." );
+      }
+
+    // check if validation vectors are given
+    std::vector<std::string> validationVectorFileList;
+    bool dedicatedValidation = false;
+    if( IsParameterEnabled( "io.valid" ) && HasValue( "io.valid" ) )
+      {
+      validationVectorFileList = GetParameterStringList( "io.valid" );
+      if( nbInputs > validationVectorFileList.size() )
+        {
+        otbAppLogFATAL( "Missing validation vector data files to match number of images (" << nbInputs << ")." );
+        }
+
+      if( !IsParameterEnabled( "sample.vnf" ) || !HasValue( "sample.vnf" ) )
+      otbAppLogFATAL( "Missing class field name to use validation data." );
+
+      dedicatedValidation = true;
+      }
+
+    TrainFileNamesHandler fileNames( GetParameterString( "io.out" ), nbInputs, dedicatedValidation );
+
+    if( !IsSupervised && IsParameterEnabled( "sample.vfn" ) && HasValue( "sample.vfn" ) )
+      {
+      fileNames.sampleTrainOutputs = vectorFileList;
+      fileNames.sampleValidOutputs = validationVectorFileList;
+      TrainModel( fileNames, imageList );
+      }
+    else
+      {
+      ComputePolygonStatistics( fileNames, imageList, dedicatedValidation, vectorFileList, validationVectorFileList );
+      SamplingRates rates = ComputeSamplingRates( dedicatedValidation );
+      SamplingRateForTrainingAndValidation( fileNames, rates, dedicatedValidation );
+      SelectAndExtractSamples( fileNames, imageList, dedicatedValidation, vectorFileList, validationVectorFileList );
+      TrainModel( fileNames, imageList );
+      }
+
+
+    // cleanup
+    if( IsParameterEnabled( "cleanup" ) )
+      {
+      otbAppLogINFO( <<"Final clean-up ..." );
+      fileNames.clear();
+      }
+  }
+
+  /**
+   * Compute polygon statistics given provided strategy
+   * \param fileNames
+   * \param imageList
+   * \param dedicatedValidation
+   */
+  void ComputePolygonStatistics(TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList,
+                                bool dedicatedValidation, std::vector<std::string> vectorFileList,
+                                std::vector<std::string> validationVectorFileList)
+  {
+    for( unsigned int i = 0; i < imageList->Size(); i++ )
+      {
+      GetInternalApplication( "polystat" )->SetParameterInputImage( "in", imageList->GetNthElement( i ) );
+      GetInternalApplication( "polystat" )->SetParameterString( "vec", vectorFileList[i], false );
+      GetInternalApplication( "polystat" )->SetParameterString( "out", fileNames.polyStatTrainOutputs[i], false );
+      ExecuteInternal( "polystat" );
+      // analyse polygons given for validation
+      if( dedicatedValidation )
+        {
+        GetInternalApplication( "polystat" )->SetParameterString( "vec", validationVectorFileList[i], false );
+        GetInternalApplication( "polystat" )->SetParameterString( "out", fileNames.polyStatValidOutputs[i], false );
+        ExecuteInternal( "polystat" );
+        }
+      }
+  }
+
+  /**
+   * Compute sampling rates
+   * \param dedicatedValidation
+   * \return SamplingRates final maximum training and final maximum validation
+   */
+  SamplingRates ComputeSamplingRates(bool dedicatedValidation)
+  {
+    SamplingRates rates;
+    GetInternalApplication( "rates" )->SetParameterString( "mim", "proportional", false );
+    double vtr = GetParameterFloat( "sample.vtr" );
+    long mt = GetParameterInt( "sample.mt" );
+    long mv = GetParameterInt( "sample.mv" );
+    // compute final maximum training and final maximum validation
+    // By default take all samples (-1 means all samples)
+    rates.fmt = -1;
+    rates.fmv = -1;
+    if( GetParameterInt( "sample.bm" ) == 0 )
+      {
+      if( dedicatedValidation )
+        {
+        // fmt and fmv will be used separately
+        rates.fmt = mt;
+        rates.fmv = mv;
+        if( mt > -1 && mv <= -1 && vtr < 0.99999 )
+          {
+          rates.fmv = static_cast<long>(( double ) mt * vtr / ( 1.0 - vtr ));
+          }
+        if( mt <= -1 && mv > -1 && vtr > 0.00001 )
+          {
+          rates.fmt = static_cast<long>(( double ) mv * ( 1.0 - vtr ) / vtr);
+          }
+        }
+      else
+        {
+        // only fmt will be used for both training and validation samples
+        // So we try to compute the total number of samples given input
+        // parameters mt, mv and vtr.
+        if( mt > -1 && mv > -1 )
+          {
+          rates.fmt = mt + mv;
+          }
+        if( mt > -1 && mv <= -1 && vtr < 0.99999 )
+          {
+          rates.fmt = static_cast<long>(( double ) mt / ( 1.0 - vtr ));
+          }
+        if( mt <= -1 && mv > -1 && vtr > 0.00001 )
+          {
+          rates.fmt = static_cast<long>(( double ) mv / vtr);
+          }
+        }
+      }
+    return rates;
+  }
+
+  /**
+   * Provide input/output images and strategy for the MultiImageSamplingRate rate application
+   * \param fileNames
+   * \param rates
+   * \param dedicatedValidation
+   */
+  void
+  SamplingRateForTrainingAndValidation(TrainFileNamesHandler &fileNames, SamplingRates rates, bool dedicatedValidation)
+  {
+    // Sampling rates for training
+    GetInternalApplication( "rates" )->SetParameterStringList( "il", fileNames.polyStatTrainOutputs, false );
+    GetInternalApplication( "rates" )->SetParameterString( "out", fileNames.rateTrainOut, false );
+    if( GetParameterInt( "sample.bm" ) != 0 )
+      {
+      GetInternalApplication( "rates" )->SetParameterString( "strategy", "smallest", false );
+      }
+    else
+      {
+      if( rates.fmt > -1 )
+        {
+        GetInternalApplication( "rates" )->SetParameterString( "strategy", "constant", false );
+        GetInternalApplication( "rates" )->SetParameterInt( "strategy.constant.nb", static_cast<int>(rates.fmt),
+                                                            false );
+        }
+      else
+        {
+        GetInternalApplication( "rates" )->SetParameterString( "strategy", "all", false );
+        }
+      }
+    ExecuteInternal( "rates" );
+    // Sampling rates for validation
+    if( dedicatedValidation )
+      {
+      GetInternalApplication( "rates" )->SetParameterStringList( "il", fileNames.polyStatValidOutputs, false );
+      GetInternalApplication( "rates" )->SetParameterString( "out", fileNames.rateValidOut, false );
+      if( GetParameterInt( "sample.bm" ) != 0 )
+        {
+        GetInternalApplication( "rates" )->SetParameterString( "strategy", "smallest", false );
+        }
+      else
+        {
+        if( rates.fmv > -1 )
+          {
+          GetInternalApplication( "rates" )->SetParameterString( "strategy", "constant", false );
+          GetInternalApplication( "rates" )->SetParameterInt( "strategy.constant.nb", static_cast<int>(rates.fmv) );
+          }
+        else
+          {
+          GetInternalApplication( "rates" )->SetParameterString( "strategy", "all", false );
+          }
+        }
+      ExecuteInternal( "rates" );
+      }
+  }
+
+  /**
+   * Configure and extract samples for the SampleExtraction application.
+   * \param fileNames
+   * \param imageList
+   * \param dedicatedValidation
+   */
+  void SelectAndExtractSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList,
+                               bool dedicatedValidation, const std::vector<std::string> &vectorFileList,
+                               const std::vector<std::string> &validationVectorFileList)
+  {
+    GetInternalApplication( "select" )->SetParameterString( "sampler", "periodic", false );
+    GetInternalApplication( "select" )->SetParameterInt( "sampler.periodic.jitter", 50 );
+    GetInternalApplication( "select" )->SetParameterString( "strategy", "byclass", false );
+    GetInternalApplication( "extraction" )->SetParameterString( "outfield", "prefix", false );
+    GetInternalApplication( "extraction" )->SetParameterString( "outfield.prefix.name", "value_", false );
+    for( unsigned int i = 0; i < imageList->Size(); ++i )
+      {
+      GetInternalApplication( "select" )->SetParameterInputImage( "in", imageList->GetNthElement( i ) );
+      GetInternalApplication( "select" )->SetParameterString( "vec", vectorFileList[i], false );
+      GetInternalApplication( "select" )->SetParameterString( "out", fileNames.sampleOutputs[i], false );
+      GetInternalApplication( "select" )->SetParameterString( "instats", fileNames.polyStatTrainOutputs[i], false );
+      GetInternalApplication( "select" )->SetParameterString( "strategy.byclass.in", fileNames.ratesTrainOutputs[i],
+                                                              false );
+      // select sample positions
+      ExecuteInternal( "select" );
+      // extract sample descriptors
+      ExecuteInternal( "extraction" );
+
+      if( dedicatedValidation )
+        {
+        GetInternalApplication( "select" )->SetParameterString( "vec", validationVectorFileList[i], false );
+        GetInternalApplication( "select" )->SetParameterString( "out", fileNames.sampleValidOutputs[i], false );
+        GetInternalApplication( "select" )->SetParameterString( "instats", fileNames.polyStatValidOutputs[i], false );
+        GetInternalApplication( "select" )->SetParameterString( "strategy.byclass.in", fileNames.ratesValidOutputs[i],
+                                                                false );
+        // select sample positions
+        ExecuteInternal( "select" );
+        // extract sample descriptors
+        ExecuteInternal( "extraction" );
+        }
+      else
+        {
+        // Split between training and validation
+        ogr::DataSource::Pointer source = ogr::DataSource::New( fileNames.sampleOutputs[i],
+                                                                ogr::DataSource::Modes::Read );
+        ogr::DataSource::Pointer destTrain = ogr::DataSource::New( fileNames.sampleTrainOutputs[i],
+                                                                   ogr::DataSource::Modes::Overwrite );
+        ogr::DataSource::Pointer destValid = ogr::DataSource::New( fileNames.sampleValidOutputs[i],
+                                                                   ogr::DataSource::Modes::Overwrite );
+        // read sampling rates from ratesTrainOutputs[i]
+        SamplingRateCalculator::Pointer rateCalculator = SamplingRateCalculator::New();
+        rateCalculator->Read( fileNames.ratesTrainOutputs[i] );
+        // Compute sampling rates for train and valid
+        const MapRateType &inputRates = rateCalculator->GetRatesByClass();
+        MapRateType trainRates;
+        MapRateType validRates;
+        otb::SamplingRateCalculator::TripletType tpt;
+        for( MapRateType::const_iterator it = inputRates.begin(); it != inputRates.end(); ++it )
+          {
+          double vtr = GetParameterFloat( "sample.vtr" );
+          unsigned long total = std::min( it->second.Required, it->second.Tot );
+          unsigned long neededValid = static_cast<unsigned long>(( double ) total * vtr );
+          unsigned long neededTrain = total - neededValid;
+          tpt.Tot = total;
+          tpt.Required = neededTrain;
+          tpt.Rate = ( 1.0 - vtr );
+          trainRates[it->first] = tpt;
+          tpt.Tot = neededValid;
+          tpt.Required = neededValid;
+          tpt.Rate = 1.0;
+          validRates[it->first] = tpt;
+          }
+
+        // Use an otb::OGRDataToSamplePositionFilter with 2 outputs
+        PeriodicSamplerType::SamplerParameterType param;
+        param.Offset = 0;
+        param.MaxJitter = 0;
+        PeriodicSamplerType::Pointer splitter = PeriodicSamplerType::New();
+        splitter->SetInput( imageList->GetNthElement( i ) );
+        splitter->SetOGRData( source );
+        splitter->SetOutputPositionContainerAndRates( destTrain, trainRates, 0 );
+        splitter->SetOutputPositionContainerAndRates( destValid, validRates, 1 );
+        splitter->SetFieldName( this->GetParameterStringList( "sample.vfn" )[0] );
+        splitter->SetLayerIndex( 0 );
+        splitter->SetOriginFieldName( std::string( "" ) );
+        splitter->SetSamplerParameters( param );
+        splitter->GetStreamer()->SetAutomaticTiledStreaming(
+                static_cast<unsigned int>(this->GetParameterInt( "ram" )) );
+        AddProcess( splitter->GetStreamer(), "Split samples between training and validation..." );
+        splitter->Update();
+        }
+      }
+  }
+
+  /**
+   * Train the model with training and validation data samples
+   * \param fileNames files names used for filters
+   * \param imageList list of input images
+   */
+  void TrainModel(TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList)
+  {
+    GetInternalApplication( "training" )->SetParameterStringList( "io.vd", fileNames.sampleTrainOutputs, false );
+    GetInternalApplication( "training" )->SetParameterStringList( "valid.vd", fileNames.sampleValidOutputs, false );
+    UpdateInternalParameters( "training" );
+    // set field names
+    FloatVectorImageType::Pointer image = imageList->GetNthElement( 0 );
+    unsigned int nbBands = image->GetNumberOfComponentsPerPixel();
+    std::vector<std::string> selectedNames;
+    for( unsigned int i = 0; i < nbBands; i++ )
+      {
+      std::ostringstream oss;
+      oss << i;
+      selectedNames.push_back( "value_" + oss.str() );
+      }
+    GetInternalApplication( "training" )->SetParameterStringList( "feat", selectedNames, false );
+    ExecuteInternal( "training" );
+  }
+
+
+private:
+
+  struct SamplingRates
+  {
+    long int fmt;
+    long int fmv;
+  };
+
+  /**
+   * \class TrainFileNamesHandler
+   * This class is used to store file names requires for the application's input and output.
+   * And to clear temporary files generated by the applications
+   */
+  class TrainFileNamesHandler
+  {
+  public :
+    TrainFileNamesHandler(std::string outModel, size_t nbInputs, bool dedicatedValidation)
+    {
+
+      if( dedicatedValidation )
+        {
+        rateTrainOut = outModel + "_ratesTrain.csv";
+        }
+      else
+        {
+        rateTrainOut = outModel + "_rates.csv";
+        }
+
+      rateValidOut = outModel + "_ratesValid.csv";
+      for( unsigned int i = 0; i < nbInputs; i++ )
+        {
+        std::ostringstream oss;
+        oss << i + 1;
+        std::string strIndex( oss.str() );
+        if( dedicatedValidation )
+          {
+          polyStatTrainOutputs.push_back( outModel + "_statsTrain_" + strIndex + ".xml" );
+          polyStatValidOutputs.push_back( outModel + "_statsValid_" + strIndex + ".xml" );
+          ratesTrainOutputs.push_back( outModel + "_ratesTrain_" + strIndex + ".csv" );
+          ratesValidOutputs.push_back( outModel + "_ratesValid_" + strIndex + ".csv" );
+          sampleOutputs.push_back( outModel + "_samplesTrain_" + strIndex + ".shp" );
+          }
+        else
+          {
+          polyStatTrainOutputs.push_back( outModel + "_stats_" + strIndex + ".xml" );
+          ratesTrainOutputs.push_back( outModel + "_rates_" + strIndex + ".csv" );
+          sampleOutputs.push_back( outModel + "_samples_" + strIndex + ".shp" );
+          }
+        sampleTrainOutputs.push_back( outModel + "_samplesTrain_" + strIndex + ".shp" );
+        sampleValidOutputs.push_back( outModel + "_samplesValid_" + strIndex + ".shp" );
+        }
+
+    }
+
+    void clear()
+    {
+      for( unsigned int i = 0; i < polyStatTrainOutputs.size(); i++ )
+        RemoveFile( polyStatTrainOutputs[i] );
+      for( unsigned int i = 0; i < polyStatValidOutputs.size(); i++ )
+        RemoveFile( polyStatValidOutputs[i] );
+      for( unsigned int i = 0; i < ratesTrainOutputs.size(); i++ )
+        RemoveFile( ratesTrainOutputs[i] );
+      for( unsigned int i = 0; i < ratesValidOutputs.size(); i++ )
+        RemoveFile( ratesValidOutputs[i] );
+      for( unsigned int i = 0; i < sampleOutputs.size(); i++ )
+        RemoveFile( sampleOutputs[i] );
+      for( unsigned int i = 0; i < sampleTrainOutputs.size(); i++ )
+        RemoveFile( sampleTrainOutputs[i] );
+      for( unsigned int i = 0; i < sampleValidOutputs.size(); i++ )
+        RemoveFile( sampleValidOutputs[i] );
+    }
+
+  public:
+    std::vector<std::string> polyStatTrainOutputs;
+    std::vector<std::string> polyStatValidOutputs;
+    std::vector<std::string> ratesTrainOutputs;
+    std::vector<std::string> ratesValidOutputs;
+    std::vector<std::string> sampleOutputs;
+    std::vector<std::string> sampleTrainOutputs;
+    std::vector<std::string> sampleValidOutputs;
+    std::string rateValidOut;
+    std::string rateTrainOut;
+
+  private:
+    bool RemoveFile(std::string &filePath)
+    {
+      bool res = true;
+      if( itksys::SystemTools::FileExists( filePath.c_str() ) )
+        {
+        size_t posExt = filePath.rfind( '.' );
+        if( posExt != std::string::npos && filePath.compare( posExt, std::string::npos, ".shp" ) == 0 )
+          {
+          std::string shxPath = filePath.substr( 0, posExt ) + std::string( ".shx" );
+          std::string dbfPath = filePath.substr( 0, posExt ) + std::string( ".dbf" );
+          std::string prjPath = filePath.substr( 0, posExt ) + std::string( ".prj" );
+          RemoveFile( shxPath );
+          RemoveFile( dbfPath );
+          RemoveFile( prjPath );
+          }
+        res = itksys::SystemTools::RemoveFile( filePath.c_str() );
+        if( !res )
+          {
+          //otbAppLogINFO( <<"Unable to remove file  "<<filePath );
+          }
+        }
+      return res;
+    }
+  };
+
+};
+
+} // end namespace Wrapper
+} // end namespace otb
+
+
+#endif //otbTrainImagesBase_h
-- 
GitLab