From 90c054355860267f7b891f07c3fa6fa5bf0322b5 Mon Sep 17 00:00:00 2001
From: Ludovic Hussonnois <ludovic.hussonnois@c-s.fr>
Date: Thu, 23 Feb 2017 14:39:45 +0100
Subject: [PATCH] ENH: Select strategy depending on provided Vector and do some
 refac.

---
 .../app/otbTrainImagesClassifier.cxx          | 119 +++-
 .../app/otbTrainImagesClustering.cxx          | 177 +++++-
 .../app/otbTrainVectorClustering.cxx          |   5 +-
 .../include/otbTrainImagesBase.h              | 525 +++++++-----------
 4 files changed, 495 insertions(+), 331 deletions(-)

diff --git a/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx b/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx
index 4a3b98b206..3ed942dbc3 100644
--- a/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx
+++ b/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx
@@ -5,15 +5,130 @@ namespace otb
 namespace Wrapper
 {
 
-class TrainImagesClassifier : public TrainImagesBase<true>
+class TrainImagesClassifier : public TrainImagesBase
 {
 public:
   typedef TrainImagesClassifier         Self;
-  typedef TrainImagesBase<true>         Superclass;
+  typedef TrainImagesBase               Superclass;
   typedef itk::SmartPointer<Self>       Pointer;
   typedef itk::SmartPointer<const Self> ConstPointer;
   itkNewMacro( Self )
   itkTypeMacro( Self, Superclass )
+
+  void DoInit() ITK_OVERRIDE // Declares the application's name, documentation, parameters and doc example values.
+  {
+    SetName( "TrainImagesClassifier" );
+    SetDescription( "Train a classifier from multiple pairs of images and training vector data." );
+
+    // Documentation
+    SetDocName( "Train a classifier from multiple images" );
+    SetDocLongDescription(
+            "This application performs a classifier training from multiple pairs of input images and training vector data. "
+                    "Samples are composed of pixel values in each band optionally centered and reduced using an XML statistics file produced by "
+                    "the ComputeImagesStatistics application.\n The training vector data must contain polygons with a positive integer field "
+                    "representing the class label. The name of this field can be set using the \"Class label field\" parameter. Training and validation "
+                    "sample lists are built such that each class is equally represented in both lists. One parameter allows controlling the ratio "
+                    "between the number of samples in training and validation sets. Two parameters allow managing the size of the training and "
+                    "validation sets per class and per image.\n Several classifier parameters can be set depending on the chosen classifier. In the "
+                    "validation process, the confusion matrix is organized the following way: rows = reference labels, columns = produced labels. "
+                    "In the header of the optional confusion matrix output file, the validation (reference) and predicted (produced) class labels"
+                    " are ordered according to the rows/columns of the confusion matrix.\n This application is based on LibSVM and OpenCV Machine Learning "
+                    "(2.3.1 and later)." );
+    SetDocLimitations( "None" );
+    SetDocAuthors( "OTB-Team" );
+    SetDocSeeAlso( "OpenCV documentation for machine learning http://docs.opencv.org/modules/ml/doc/ml.html " );
+
+    AddDocTag( Tags::Learning );
+
+    // Perform initialization: declare the IO, sampling and classification parameters
+    ClearApplications();
+    InitIO();
+    InitSampling();
+    InitClassification( true ); // true = supervised mode
+
+
+    // Doc example parameter settings
+    SetDocExampleParameterValue("io.il", "QB_1_ortho.tif");
+    SetDocExampleParameterValue("io.vd", "VectorData_QB1.shp");
+    SetDocExampleParameterValue("io.imstat", "EstimateImageStatisticsQB1.xml");
+    SetDocExampleParameterValue("sample.mv", "100");
+    SetDocExampleParameterValue("sample.mt", "100");
+    SetDocExampleParameterValue("sample.vtr", "0.5");
+    SetDocExampleParameterValue("sample.vfn", "Class");
+    SetDocExampleParameterValue("classifier", "libsvm");
+    SetDocExampleParameterValue("classifier.libsvm.k", "linear");
+    SetDocExampleParameterValue("classifier.libsvm.c", "1");
+    SetDocExampleParameterValue("classifier.libsvm.opt", "false");
+    SetDocExampleParameterValue("io.out", "svmModelQB1.txt");
+    SetDocExampleParameterValue("io.confmatout", "svmConfusionMatrixQB1.csv");
+  }
+
+  void DoUpdateParameters() ITK_OVERRIDE // Keeps the internal "polystat" application in sync with "io.vd".
+  {
+    if( HasValue( "io.vd" ) )
+      {
+      std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" );
+      GetInternalApplication( "polystat" )->SetParameterString( "vec", vectorFileList[0], false ); // only the first vector file is forwarded here
+      UpdateInternalParameters( "polystat" );
+      }
+  }
+
+  void DoExecute() ITK_OVERRIDE // Runs polygon statistics, sampling rates, sample extraction and model training.
+  {
+    TrainFileNamesHandler fileNames; // collects the temporary file names used by the steps below
+    FloatVectorImageListType *imageList = GetParameterImageList( "io.il" );
+    std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" );
+    unsigned long nbInputs = imageList->Size();
+
+    if( nbInputs > vectorFileList.size() ) // at least one vector file per image is required
+      {
+      otbAppLogFATAL( "Missing input vector data files to match number of images (" << nbInputs << ")." );
+      }
+
+    // check if validation vectors are given
+    std::vector<std::string> validationVectorFileList;
+    bool dedicatedValidation = false;
+    if( IsParameterEnabled( "io.valid" ) && HasValue( "io.valid" ) )
+      {
+      validationVectorFileList = GetParameterStringList( "io.valid" );
+      if( nbInputs > validationVectorFileList.size() ) // same requirement for validation vector files
+        {
+        otbAppLogFATAL( "Missing validation vector data files to match number of images (" << nbInputs << ")." );
+        }
+
+      dedicatedValidation = true;
+      }
+
+    fileNames.CreateTemporaryFileNames( GetParameterString( "io.out" ), nbInputs, dedicatedValidation );
+
+    // Compute final maximum sampling rates for both training and validation samples
+    SamplingRates rates = ComputeFinalMaximumSamplingRates( dedicatedValidation );
+
+    // Select and Extract samples for training with computed statistics and rates
+    ComputePolygonStatistics(imageList, vectorFileList, fileNames.polyStatTrainOutputs);
+    ComputeSamplingRate(fileNames.polyStatTrainOutputs, fileNames.rateTrainOut, rates.fmt);
+    SelectAndExtractTrainSamples(fileNames, imageList, vectorFileList, SamplingStrategy::CLASS);
+
+    // Select and Extract samples for validation with computed statistics and rates
+    // Validation samples could be empty if sample.vtr == 0 and if no dedicated validation data is provided
+    if( dedicatedValidation ) {
+      ComputePolygonStatistics(imageList, validationVectorFileList, fileNames.polyStatValidOutputs);
+      ComputeSamplingRate(fileNames.polyStatValidOutputs, fileNames.rateValidOut, rates.fmv);
+      }
+    SelectAndExtractValidationSamples(fileNames, imageList, validationVectorFileList, dedicatedValidation);
+
+
+    // Then train the model with extracted samples
+    TrainModel( imageList, fileNames.sampleTrainOutputs, fileNames.sampleValidOutputs);
+
+    // cleanup
+    if( IsParameterEnabled( "cleanup" ) )
+      {
+      otbAppLogINFO( <<"Final clean-up ..." );
+      fileNames.clear();
+      }
+  }
+
 };
 
 }
diff --git a/Modules/Applications/AppClassification/app/otbTrainImagesClustering.cxx b/Modules/Applications/AppClassification/app/otbTrainImagesClustering.cxx
index fed5b0775a..fdabcd1b08 100644
--- a/Modules/Applications/AppClassification/app/otbTrainImagesClustering.cxx
+++ b/Modules/Applications/AppClassification/app/otbTrainImagesClustering.cxx
@@ -5,15 +5,188 @@ namespace otb
 namespace Wrapper
 {
 
-class TrainImagesClustering : public TrainImagesBase<false>
+class TrainImagesClustering : public TrainImagesBase
 {
 public:
   typedef TrainImagesClustering         Self;
-  typedef TrainImagesBase<false>        Superclass;
+  typedef TrainImagesBase               Superclass;
   typedef itk::SmartPointer<Self>       Pointer;
   typedef itk::SmartPointer<const Self> ConstPointer;
   itkNewMacro( Self )
   itkTypeMacro( Self, Superclass )
+
+  void DoInit() ITK_OVERRIDE // Declares the application's name, documentation, parameters and doc example values.
+  {
+    SetName( "TrainImagesClustering" );
+    SetDescription( "Train a classifier from multiple pairs of images and optional input training vector data." );
+
+    // Documentation (NOTE(review): the long description is still a "TODO" placeholder)
+    SetDocName( "Train a classifier from multiple images" );
+    SetDocLongDescription( "TODO" );
+    SetDocLimitations( "None" );
+    SetDocAuthors( "OTB-Team" );
+    SetDocSeeAlso( "OpenCV documentation for machine learning http://docs.opencv.org/modules/ml/doc/ml.html " );
+
+    AddDocTag( Tags::Learning );
+
+    ClearApplications();
+    InitIO();
+    InitSampling();
+    InitClassification( false ); // false = unsupervised mode, which makes "io.vd" optional
+
+    // Doc example parameter settings
+    SetDocExampleParameterValue( "io.il", "QB_1_ortho.tif" );
+    SetDocExampleParameterValue( "io.vd", "VectorData_QB1.shp" );
+    SetDocExampleParameterValue( "io.imstat", "EstimateImageStatisticsQB1.xml" );
+    SetDocExampleParameterValue( "sample.mv", "100" );
+    SetDocExampleParameterValue( "sample.mt", "100" );
+    SetDocExampleParameterValue( "sample.vtr", "0.5" );
+    SetDocExampleParameterValue( "sample.vfn", "Class" );
+    SetDocExampleParameterValue( "classifier", "sharkkm" );
+    SetDocExampleParameterValue( "classifier.sharkkm.k", "2" );
+    SetDocExampleParameterValue( "io.out", "sharkKMModelQB1.txt" );
+  }
+
+  void DoUpdateParameters() ITK_OVERRIDE // Keeps the internal "polystat" application in sync with "io.vd".
+  {
+    if( HasValue( "io.vd" ) ) // "io.vd" is optional for this application
+      {
+      UpdatePolygonClassStatisticsParameters();
+      }
+  }
+
+  void DoExecute() ITK_OVERRIDE // Extracts training/validation samples (with or without input vector data) and trains the model.
+  {
+    TrainFileNamesHandler fileNames; // collects the temporary file names used by the steps below
+    FloatVectorImageListType *imageList = GetParameterImageList( "io.il" );
+    bool HasInputVector = IsParameterEnabled( "io.vd" ) && HasValue( "io.vd" ); // read before GetVectorFileList, which sets "io.vd" when it is missing
+    std::vector<std::string> vectorFileList = GetVectorFileList( GetParameterString( "io.out" ), fileNames ); // provided files, or one generated file per image
+
+
+    unsigned long nbInputs = imageList->Size();
+
+    if( nbInputs > vectorFileList.size() ) // at least one vector file per image is required
+      {
+      otbAppLogFATAL( "Missing input vector data files to match number of images (" << nbInputs << ")." );
+      }
+
+    // check if validation vectors are given
+    std::vector<std::string> validationVectorFileList;
+    bool dedicatedValidation = false;
+    if( IsParameterEnabled( "io.valid" ) && HasValue( "io.valid" ) )
+      {
+      validationVectorFileList = GetParameterStringList( "io.valid" );
+      if( nbInputs > validationVectorFileList.size() ) // same requirement for validation vector files
+        {
+        otbAppLogFATAL( "Missing validation vector data files to match number of images (" << nbInputs << ")." );
+        }
+
+      dedicatedValidation = true;
+      }
+
+    fileNames.CreateTemporaryFileNames( GetParameterString( "io.out" ), nbInputs, dedicatedValidation );
+
+    // Compute final maximum sampling rates for both training and validation samples
+    SamplingRates rates = ComputeFinalMaximumSamplingRates( dedicatedValidation );
+
+    if( HasInputVector )
+    {
+      // Select and Extract samples for training with computed statistics and rates
+      ComputePolygonStatistics( imageList, vectorFileList, fileNames.polyStatTrainOutputs );
+      ComputeSamplingRate( fileNames.polyStatTrainOutputs, fileNames.rateTrainOut, rates.fmt );
+      SelectAndExtractTrainSamples( fileNames, imageList, vectorFileList, SamplingStrategy::CLASS );
+    }
+    else
+    {
+      SelectAndExtractTrainSamples( fileNames, imageList, vectorFileList, SamplingStrategy::GEOMETRIC ); // generated vector files: no polygon statistics or rates are computed first
+    }
+
+    // Select and Extract samples for validation with computed statistics and rates
+    // Validation samples could be empty if sample.vtr == 0 and if no dedicated validation data is provided
+    if( dedicatedValidation ) {
+      ComputePolygonStatistics(imageList, validationVectorFileList, fileNames.polyStatValidOutputs);
+      ComputeSamplingRate(fileNames.polyStatValidOutputs, fileNames.rateValidOut, rates.fmv);
+      }
+    SelectAndExtractValidationSamples(fileNames, imageList, validationVectorFileList, dedicatedValidation);
+
+
+    // Then train the model with extracted samples
+    TrainModel( imageList, fileNames.sampleTrainOutputs, fileNames.sampleValidOutputs);
+
+    // cleanup
+    if( IsParameterEnabled( "cleanup" ) )
+      {
+      otbAppLogINFO( <<"Final clean-up ..." );
+      fileNames.clear();
+      }
+  }
+
+private :
+
+  void UpdatePolygonClassStatisticsParameters() // Forwards the first "io.vd" entry to the internal "polystat" application.
+  {
+    std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" );
+    GetInternalApplication( "polystat" )->SetParameterString( "vec", vectorFileList[0], false ); // callers must ensure "io.vd" is not empty
+    UpdateInternalParameters( "polystat" );
+  }
+
+
+  /**
+   * Retrieve input vector data if provided otherwise generate a default vector shape file for each image.
+   * \param output prefix of the generated file names ("<output>_vector_<i>.shp")
+   * \param fileNames handler whose tmpVectorFileList receives the generated file names
+   * \return list of input vector data file names
+   */
+  std::vector<std::string> GetVectorFileList(std::string output, TrainFileNamesHandler &fileNames)
+  {
+    std::vector<std::string> vectorFileList;
+    bool HasInputVector = IsParameterEnabled( "io.vd" ) && HasValue( "io.vd" );
+
+    // No usable "io.vd" value: generate one vector file per input image.
+    if( !HasInputVector )
+      {
+      FloatVectorImageListType *imageList = GetParameterImageList( "io.il" );
+      unsigned int nbInputs = static_cast<unsigned int>(imageList->Size());
+
+      for( unsigned int i = 0; i < nbInputs; ++i )
+        {
+        std::string name = output + "_vector_" + std::to_string( i ) + ".shp";
+        GenerateVectorDataFile( imageList->GetNthElement( i ), name );
+        fileNames.tmpVectorFileList.push_back( name );
+        }
+      vectorFileList = fileNames.tmpVectorFileList;
+      SetParameterStringList( "io.vd", vectorFileList, false ); // expose the generated files as the "io.vd" value
+      UpdatePolygonClassStatisticsParameters();
+      GetInternalApplication( "polystat" )->SetParameterString( "field", "fid" ); // NOTE(review): presumably generated files have no class field, so "fid" is used - confirm
+      }
+    else
+      {
+      vectorFileList = GetParameterStringList( "io.vd" ); // use the files provided by the user
+      }
+
+    return vectorFileList;
+  }
+
+
+
+  void GenerateVectorDataFile(const FloatVectorImageListType::ObjectPointerType &floatVectorImage, std::string name) // Writes the output of ImageToEnvelopeVectorDataFilter for floatVectorImage to the file `name`.
+  {
+    typedef otb::ImageToEnvelopeVectorDataFilter<FloatVectorImageType, VectorDataType> ImageToEnvelopeFilterType;
+    typedef ImageToEnvelopeFilterType::OutputVectorDataType OutputVectorData;
+    typedef otb::VectorDataFileWriter<OutputVectorData> VectorDataWriter;
+
+    ImageToEnvelopeFilterType::Pointer imageToEnvelopeVectorData = ImageToEnvelopeFilterType::New();
+    imageToEnvelopeVectorData->SetInput( floatVectorImage );
+    imageToEnvelopeVectorData->SetOutputProjectionRef( floatVectorImage->GetProjectionRef().c_str() ); // keep the image's own projection for the output
+    OutputVectorData::Pointer vectorData = imageToEnvelopeVectorData->GetOutput();
+
+    // Write the temporary generated vector file to disk.
+    VectorDataWriter::Pointer vectorDataFileWriter = VectorDataWriter::New();
+    vectorDataFileWriter->SetInput( vectorData );
+    vectorDataFileWriter->SetFileName( name.c_str() );
+    vectorDataFileWriter->Write();
+  }
+
 };
 
 }
diff --git a/Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx b/Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx
index 49acbbc2b3..596dbef867 100644
--- a/Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx
+++ b/Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx
@@ -57,9 +57,8 @@ private:
 
     // Doc example parameter settings
     SetDocExampleParameterValue( "io.vd", "vectorData.shp" );
-    SetDocExampleParameterValue( "io.stats", "meanVar.xml" );
-    SetDocExampleParameterValue( "io.out", "svmModel.svm" );
-    SetDocExampleParameterValue( "feat", "perimeter  area  width" );
+    SetDocExampleParameterValue( "io.out", "kmeansModel.txt" );
+    SetDocExampleParameterValue( "feat", "perimeter width area" );
 
   }
 
diff --git a/Modules/Applications/AppClassification/include/otbTrainImagesBase.h b/Modules/Applications/AppClassification/include/otbTrainImagesBase.h
index 5b6aca1460..4f5dd82d3c 100644
--- a/Modules/Applications/AppClassification/include/otbTrainImagesBase.h
+++ b/Modules/Applications/AppClassification/include/otbTrainImagesBase.h
@@ -17,18 +17,21 @@
 #ifndef otbTrainImagesBase_h
 #define otbTrainImagesBase_h
 
+
+#include "otbVectorDataFileWriter.h"
 #include "otbWrapperCompositeApplication.h"
 #include "otbWrapperApplicationFactory.h"
 
-#include "otbOGRDataToSamplePositionFilter.h"
+#include "otbStatisticsXMLFileWriter.h"
+#include "otbImageToEnvelopeVectorDataFilter.h"
 #include "otbSamplingRateCalculator.h"
+#include "otbOGRDataToSamplePositionFilter.h"
 
 namespace otb
 {
 namespace Wrapper
 {
 
-template<bool IsSupervised = true>
 class TrainImagesBase : public CompositeApplication
 {
 public:
@@ -48,11 +51,32 @@ public:
 
 protected:
 
-private:
-  struct SamplingRates;
+  enum SamplingStrategy // Strategy selector passed to SelectAndExtractTrainSamples by the derived applications.
+  {
+    CLASS, GEOMETRIC // TrainImagesClustering passes GEOMETRIC when no "io.vd" is given, CLASS otherwise.
+  };
 
+  struct SamplingRates;
   class TrainFileNamesHandler;
 
+  void InitIO() // Declares the "io" group with its image and vector data lists, plus the "cleanup" flag.
+  {
+    //Group IO
+    AddParameter( ParameterType_Group, "io", "Input and output data" );
+    SetParameterDescription( "io", "This group of parameters allows setting input and output data." );
+
+    AddParameter( ParameterType_InputImageList, "io.il", "Input Image List" );
+    SetParameterDescription( "io.il", "A list of input images." );
+    AddParameter( ParameterType_InputVectorDataList, "io.vd", "Input Vector Data List" );
+    SetParameterDescription( "io.vd", "A list of vector data to select the training samples." );
+
+    AddParameter( ParameterType_Empty, "cleanup", "Temporary files cleaning" );
+    EnableParameter( "cleanup" ); // enabled at declaration time
+    SetParameterDescription( "cleanup",
+                             "If activated, the application will try to clean all temporary files it created" );
+    MandatoryOff( "cleanup" ); // optional parameter
+  }
+
   void InitSampling()
   {
     AddApplication( "PolygonClassStatistics", "polystat", "Polygon analysis" );
@@ -131,6 +155,9 @@ private:
     SetParameterDescription( "io.valid", "A list of vector data to select the training samples." );
     MandatoryOff( "io.valid" );
 
+    if( !supervised )
+      MandatoryOff( "io.vd" );
+
     ShareClassificationParams( supervised );
     ConnectClassificationParams();
   };
@@ -153,206 +180,31 @@ private:
     Connect( "select.rand", "training.rand" );
   }
 
-  void DoUnsupervisedInit()
-  {
-    SetName( "TrainImagesClustering" );
-    SetDescription( "Train a classifier from multiple pairs of images and training vector data." );
-
-    // Documentation
-    SetDocName( "Train a classifier from multiple images" );
-    SetDocLongDescription( "TODO" );
-    SetDocLimitations( "None" );
-    SetDocAuthors( "OTB-Team" );
-    SetDocSeeAlso( "OpenCV documentation for machine learning http://docs.opencv.org/modules/ml/doc/ml.html " );
-
-    AddDocTag( Tags::Learning );
-
-    ClearApplications();
-    InitSampling();
-    InitClassification( IsSupervised );
-
-    // Hide sampling parameters if sample.vnf is not provided
-    MandatoryOn( "sample.mv" );
-    MandatoryOn( "sample.mt" );
-    MandatoryOn( "sample.vtr" );
-
-
-    // Doc example parameter settings
-    SetDocExampleParameterValue( "io.il", "QB_1_ortho.tif" );
-    SetDocExampleParameterValue( "io.vd", "VectorData_QB1.shp" );
-    SetDocExampleParameterValue( "io.imstat", "EstimateImageStatisticsQB1.xml" );
-    SetDocExampleParameterValue( "sample.mv", "100" );
-    SetDocExampleParameterValue( "sample.mt", "100" );
-    SetDocExampleParameterValue( "sample.vtr", "0.5" );
-    SetDocExampleParameterValue( "sample.vfn", "Class" );
-    SetDocExampleParameterValue( "classifier", "sharkkm" );
-    SetDocExampleParameterValue( "classifier.sharkkm.k", "2" );
-    SetDocExampleParameterValue( "io.out", "sharkKMModelQB1.txt" );
-  }
-
-  void DoSupervisedInit()
-  {
-    SetName( "TrainImagesClassifier" );
-    SetDescription( "Train a classifier from multiple pairs of images and training vector data." );
-
-    // Documentation
-    SetDocName( "Train a classifier from multiple images" );
-    SetDocLongDescription(
-            "This application performs a classifier training from multiple pairs of input images and training vector data. "
-                    "Samples are composed of pixel values in each band optionally centered and reduced using an XML statistics file produced by "
-                    "the ComputeImagesStatistics application.\n The training vector data must contain polygons with a positive integer field "
-                    "representing the class label. The name of this field can be set using the \"Class label field\" parameter. Training and validation "
-                    "sample lists are built such that each class is equally represented in both lists. One parameter allows controlling the ratio "
-                    "between the number of samples in training and validation sets. Two parameters allow managing the size of the training and "
-                    "validation sets per class and per image.\n Several classifier parameters can be set depending on the chosen classifier. In the "
-                    "validation process, the confusion matrix is organized the following way: rows = reference labels, columns = produced labels. "
-                    "In the header of the optional confusion matrix output file, the validation (reference) and predicted (produced) class labels"
-                    " are ordered according to the rows/columns of the confusion matrix.\n This application is based on LibSVM and OpenCV Machine Learning "
-                    "(2.3.1 and later)." );
-    SetDocLimitations( "None" );
-    SetDocAuthors( "OTB-Team" );
-    SetDocSeeAlso( "OpenCV documentation for machine learning http://docs.opencv.org/modules/ml/doc/ml.html " );
-
-    AddDocTag( Tags::Learning );
-
-    // Perform initialization
-    ClearApplications();
-    InitSampling();
-    InitClassification( IsSupervised );
-
-    // Doc example parameter settings
-    SetDocExampleParameterValue( "io.il", "QB_1_ortho.tif" );
-    SetDocExampleParameterValue( "io.vd", "VectorData_QB1.shp" );
-    SetDocExampleParameterValue( "io.imstat", "EstimateImageStatisticsQB1.xml" );
-    SetDocExampleParameterValue( "sample.mv", "100" );
-    SetDocExampleParameterValue( "sample.mt", "100" );
-    SetDocExampleParameterValue( "sample.vtr", "0.5" );
-    SetDocExampleParameterValue( "sample.vfn", "Class" );
-    SetDocExampleParameterValue( "classifier", "sharkkm" );
-    SetDocExampleParameterValue( "classifier.sharkkm.k", "2" );
-    SetDocExampleParameterValue( "io.out", "sharkKMModelQB1.txt" );
-  }
-
-  void DoInit() ITK_OVERRIDE
-  {
-    //Group IO
-    AddParameter( ParameterType_Group, "io", "Input and output data" );
-    SetParameterDescription( "io", "This group of parameters allows setting input and output data." );
-
-    AddParameter( ParameterType_InputImageList, "io.il", "Input Image List" );
-    SetParameterDescription( "io.il", "A list of input images." );
-    AddParameter( ParameterType_InputVectorDataList, "io.vd", "Input Vector Data List" );
-    SetParameterDescription( "io.vd", "A list of vector data to select the training samples." );
-
-    AddParameter( ParameterType_Empty, "cleanup", "Temporary files cleaning" );
-    EnableParameter( "cleanup" );
-    SetParameterDescription( "cleanup",
-                             "If activated, the application will try to clean all temporary files it created" );
-
-    if( IsSupervised )
-      DoSupervisedInit();
-    else
-      DoUnsupervisedInit();
-
-    MandatoryOff( "cleanup" );
-  }
-
-  void DoUpdateParameters() ITK_OVERRIDE
-  {
-    if( HasValue( "io.vd" ) )
-      {
-        std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" );
-        GetInternalApplication( "polystat" )->SetParameterString( "vec", vectorFileList[0], false );
-        UpdateInternalParameters( "polystat" );
-      }
-  }
-
-  void DoExecute() ITK_OVERRIDE
-  {
-    FloatVectorImageListType *imageList = GetParameterImageList( "io.il" );
-    std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" );
-    unsigned long nbInputs = imageList->Size();
-
-    if( nbInputs > vectorFileList.size() )
-      {
-      otbAppLogFATAL( "Missing input vector data files to match number of images (" << nbInputs << ")." );
-      }
-
-    // check if validation vectors are given
-    std::vector<std::string> validationVectorFileList;
-    bool dedicatedValidation = false;
-    if( IsParameterEnabled( "io.valid" ) && HasValue( "io.valid" ) )
-      {
-      validationVectorFileList = GetParameterStringList( "io.valid" );
-      if( nbInputs > validationVectorFileList.size() )
-        {
-        otbAppLogFATAL( "Missing validation vector data files to match number of images (" << nbInputs << ")." );
-        }
-
-      if( !IsParameterEnabled( "sample.vnf" ) || !HasValue( "sample.vnf" ) )
-      otbAppLogFATAL( "Missing class field name to use validation data." );
-
-      dedicatedValidation = true;
-      }
-
-    TrainFileNamesHandler fileNames( GetParameterString( "io.out" ), nbInputs, dedicatedValidation );
-
-    if( !IsSupervised && IsParameterEnabled( "sample.vfn" ) && HasValue( "sample.vfn" ) )
-      {
-      fileNames.sampleTrainOutputs = vectorFileList;
-      fileNames.sampleValidOutputs = validationVectorFileList;
-      TrainModel( fileNames, imageList );
-      }
-    else
-      {
-      ComputePolygonStatistics( fileNames, imageList, dedicatedValidation, vectorFileList, validationVectorFileList );
-      SamplingRates rates = ComputeSamplingRates( dedicatedValidation );
-      SamplingRateForTrainingAndValidation( fileNames, rates, dedicatedValidation );
-      SelectAndExtractSamples( fileNames, imageList, dedicatedValidation, vectorFileList, validationVectorFileList );
-      TrainModel( fileNames, imageList );
-      }
-
-
-    // cleanup
-    if( IsParameterEnabled( "cleanup" ) )
-      {
-      otbAppLogINFO( <<"Final clean-up ..." );
-      fileNames.clear();
-      }
-  }
-
   /**
-   * Compute polygon statistics given provided strategy
-   * \param fileNames
-   * \param imageList
-   * \param dedicatedValidation
+   * Compute polygon statistics for each input image with the internal PolygonClassStatistics application
+   * \param imageList list of input images
+   * \param vectorFileNames list of input vector file names (one per image)
+   * \param statisticsFileNames list of output statistics file names (one per image)
    */
-  void ComputePolygonStatistics(TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList,
-                                bool dedicatedValidation, std::vector<std::string> vectorFileList,
-                                std::vector<std::string> validationVectorFileList)
+  void ComputePolygonStatistics(FloatVectorImageListType *imageList, const std::vector<std::string> &vectorFileNames,
+                                const std::vector<std::string> &statisticsFileNames)
   {
-    for( unsigned int i = 0; i < imageList->Size(); i++ )
+    unsigned int nbImages = static_cast<unsigned int>(imageList->Size());
+    for( unsigned int i = 0; i < nbImages; i++ ) // both file name lists are indexed by image position
       {
       GetInternalApplication( "polystat" )->SetParameterInputImage( "in", imageList->GetNthElement( i ) );
-      GetInternalApplication( "polystat" )->SetParameterString( "vec", vectorFileList[i], false );
-      GetInternalApplication( "polystat" )->SetParameterString( "out", fileNames.polyStatTrainOutputs[i], false );
+      GetInternalApplication( "polystat" )->SetParameterString( "vec", vectorFileNames[i], false );
+      GetInternalApplication( "polystat" )->SetParameterString( "out", statisticsFileNames[i], false );
       ExecuteInternal( "polystat" );
-      // analyse polygons given for validation
-      if( dedicatedValidation )
-        {
-        GetInternalApplication( "polystat" )->SetParameterString( "vec", validationVectorFileList[i], false );
-        GetInternalApplication( "polystat" )->SetParameterString( "out", fileNames.polyStatValidOutputs[i], false );
-        ExecuteInternal( "polystat" );
-        }
       }
   }
 
   /**
-   * Compute sampling rates
+   * Compute the final maximum sampling rates for training and validation
    * \param dedicatedValidation
    * \return SamplingRates final maximum training and final maximum validation
    */
-  SamplingRates ComputeSamplingRates(bool dedicatedValidation)
+  SamplingRates ComputeFinalMaximumSamplingRates(bool dedicatedValidation)
   {
     SamplingRates rates;
     GetInternalApplication( "rates" )->SetParameterString( "mim", "proportional", false );
@@ -401,29 +253,30 @@ private:
     return rates;
   }
 
+
   /**
-   * Provide input/output images and strategy for the MultiImageSamplingRate rate application
-   * \param fileNames
-   * \param rates
-   * \param dedicatedValidation
+   * Compute rates using MultiImageSamplingRate application
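+   * \param statisticsFileNames list of statistics file names given as input ("il") to the rates application
+   * \param ratesFileName output file name ("out") of the rates application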
+   * \param maximum final maximum value computed by ComputeFinalMaximumSamplingRates
+   * \sa ComputeFinalMaximumSamplingRates
    */
-  void
-  SamplingRateForTrainingAndValidation(TrainFileNamesHandler &fileNames, SamplingRates rates, bool dedicatedValidation)
+  void ComputeSamplingRate(const std::vector<std::string> &statisticsFileNames, const std::string &ratesFileName,
+                           long maximum)
   {
-    // Sampling rates for training
-    GetInternalApplication( "rates" )->SetParameterStringList( "il", fileNames.polyStatTrainOutputs, false );
-    GetInternalApplication( "rates" )->SetParameterString( "out", fileNames.rateTrainOut, false );
+    // Sampling rates
+    GetInternalApplication( "rates" )->SetParameterStringList( "il", statisticsFileNames, false );
+    GetInternalApplication( "rates" )->SetParameterString( "out", ratesFileName, false );
     if( GetParameterInt( "sample.bm" ) != 0 )
       {
       GetInternalApplication( "rates" )->SetParameterString( "strategy", "smallest", false );
       }
     else
       {
-      if( rates.fmt > -1 )
+      if( maximum > -1 )
         {
         GetInternalApplication( "rates" )->SetParameterString( "strategy", "constant", false );
-        GetInternalApplication( "rates" )->SetParameterInt( "strategy.constant.nb", static_cast<int>(rates.fmt),
-                                                            false );
+        GetInternalApplication( "rates" )->SetParameterInt( "strategy.constant.nb", static_cast<int>(maximum), false );
         }
       else
         {
@@ -431,151 +284,172 @@ private:
         }
       }
     ExecuteInternal( "rates" );
-    // Sampling rates for validation
-    if( dedicatedValidation )
+  }
+
+  /**
+   * Train the model with training and optional validation data samples
+   * \param imageList list of input images
+   * \param sampleTrainFileNames file names of the training samples
+   * \param sampleValidationFileNames file names of the validation samples (ignored when empty)
+   */
+  void TrainModel(FloatVectorImageListType *imageList, const std::vector<std::string> &sampleTrainFileNames,
+                  const std::vector<std::string> &sampleValidationFileNames)
+  {
+    GetInternalApplication( "training" )->SetParameterStringList( "io.vd", sampleTrainFileNames, false );
+    if( !sampleValidationFileNames.empty() )
+      GetInternalApplication( "training" )->SetParameterStringList( "valid.vd", sampleValidationFileNames, false );
+
+    UpdateInternalParameters( "training" );
+    // set field names
+    FloatVectorImageType::Pointer image = imageList->GetNthElement( 0 );
+    unsigned int nbBands = image->GetNumberOfComponentsPerPixel();
+    std::vector<std::string> selectedNames;
+    for( unsigned int i = 0; i < nbBands; i++ )
       {
-      GetInternalApplication( "rates" )->SetParameterStringList( "il", fileNames.polyStatValidOutputs, false );
-      GetInternalApplication( "rates" )->SetParameterString( "out", fileNames.rateValidOut, false );
-      if( GetParameterInt( "sample.bm" ) != 0 )
-        {
-        GetInternalApplication( "rates" )->SetParameterString( "strategy", "smallest", false );
-        }
-      else
-        {
-        if( rates.fmv > -1 )
-          {
-          GetInternalApplication( "rates" )->SetParameterString( "strategy", "constant", false );
-          GetInternalApplication( "rates" )->SetParameterInt( "strategy.constant.nb", static_cast<int>(rates.fmv) );
-          }
-        else
-          {
-          GetInternalApplication( "rates" )->SetParameterString( "strategy", "all", false );
-          }
-        }
-      ExecuteInternal( "rates" );
+      std::ostringstream oss;
+      oss << i;
+      selectedNames.push_back( "value_" + oss.str() );
       }
+    GetInternalApplication( "training" )->SetParameterStringList( "feat", selectedNames, false );
+    ExecuteInternal( "training" );
   }
 
   /**
-   * Configure and extract samples for the SampleExtraction application.
-   * \param fileNames
-   * \param imageList
-   * \param dedicatedValidation
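+   * Select sample positions using the by-class or the geometric strategy, then extract the sample descriptors
+   * \param image input image set as the "in" parameter of the "select" application
+   * \param vectorFileName vector file name set as the "vec" parameter of the "select" application
+   * \param sampleFileName file name set as the "out" parameter of the "select" application
+   * \param statisticsFileName statistics file name set as "instats"; not used when strategy is GEOMETRIC
+   * \param ratesFileName rates file name set as "strategy.byclass.in"; not used when strategy is GEOMETRIC
+   * \param strategy GEOMETRIC selects the "random" sampler with strategy "all"; any other value selects periodic by-class sampling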
    */
-  void SelectAndExtractSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList,
-                               bool dedicatedValidation, const std::vector<std::string> &vectorFileList,
-                               const std::vector<std::string> &validationVectorFileList)
+  void SelectAndExtractSamples(FloatVectorImageType *image, std::string vectorFileName, std::string sampleFileName,
+                               std::string statisticsFileName, std::string ratesFileName, SamplingStrategy strategy)
   {
-    GetInternalApplication( "select" )->SetParameterString( "sampler", "periodic", false );
-    GetInternalApplication( "select" )->SetParameterInt( "sampler.periodic.jitter", 50 );
-    GetInternalApplication( "select" )->SetParameterString( "strategy", "byclass", false );
+    GetInternalApplication( "select" )->SetParameterInputImage( "in", image );
+    GetInternalApplication( "select" )->SetParameterString( "vec", vectorFileName, false );
+    GetInternalApplication( "select" )->SetParameterString( "out", sampleFileName, false );
+
     GetInternalApplication( "extraction" )->SetParameterString( "outfield", "prefix", false );
     GetInternalApplication( "extraction" )->SetParameterString( "outfield.prefix.name", "value_", false );
+
+    // Change the selection strategy based on selected sampling strategy
+    switch( strategy )
+      {
+      case GEOMETRIC:
+        GetInternalApplication( "select" )->SetParameterString( "sampler", "random", false );
+        GetInternalApplication( "select" )->SetParameterString( "strategy", "all", false );
+        break;
+      case CLASS:
+      default:
+        GetInternalApplication( "select" )->SetParameterString( "instats", statisticsFileName, false );
+        GetInternalApplication( "select" )->SetParameterString( "sampler", "periodic", false );
+        GetInternalApplication( "select" )->SetParameterInt( "sampler.periodic.jitter", 50 );
+        GetInternalApplication( "select" )->SetParameterString( "strategy", "byclass", false );
+        GetInternalApplication( "select" )->SetParameterString( "strategy.byclass.in", ratesFileName, false );
+        break;
+      }
+
+    // select sample positions
+    ExecuteInternal( "select" );
+    // extract sample descriptors
+    ExecuteInternal( "extraction" );
+  }
+
+  /**
+   * Select and extract training samples with the SampleSelection and SampleExtraction applications.
+   */
+  void SelectAndExtractTrainSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList,
+                                    std::vector<std::string> vectorFileNames, SamplingStrategy strategy)
+  {
+
     for( unsigned int i = 0; i < imageList->Size(); ++i )
       {
-      GetInternalApplication( "select" )->SetParameterInputImage( "in", imageList->GetNthElement( i ) );
-      GetInternalApplication( "select" )->SetParameterString( "vec", vectorFileList[i], false );
-      GetInternalApplication( "select" )->SetParameterString( "out", fileNames.sampleOutputs[i], false );
-      GetInternalApplication( "select" )->SetParameterString( "instats", fileNames.polyStatTrainOutputs[i], false );
-      GetInternalApplication( "select" )->SetParameterString( "strategy.byclass.in", fileNames.ratesTrainOutputs[i],
-                                                              false );
-      // select sample positions
-      ExecuteInternal( "select" );
-      // extract sample descriptors
-      ExecuteInternal( "extraction" );
+      SelectAndExtractSamples( imageList->GetNthElement( i ), vectorFileNames[i], fileNames.sampleOutputs[i],
+                               fileNames.polyStatTrainOutputs[i], fileNames.ratesTrainOutputs[i], strategy );
+      }
+  }
 
-      if( dedicatedValidation )
+
+  void SelectAndExtractValidationSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList,
+                                         const std::vector<std::string> &validationVectorFileList,
+                                         bool dedicatedValidation)
+  {
+    // In dedicated validation mode, the by-class sampling strategy and the validation statistics are used.
+    // Otherwise, the selected samples are simply split into training and validation sets according to the sample.vtr ratio.
+    if( dedicatedValidation )
+      {
+      for( unsigned int i = 0; i < imageList->Size(); ++i )
         {
-        GetInternalApplication( "select" )->SetParameterString( "vec", validationVectorFileList[i], false );
-        GetInternalApplication( "select" )->SetParameterString( "out", fileNames.sampleValidOutputs[i], false );
-        GetInternalApplication( "select" )->SetParameterString( "instats", fileNames.polyStatValidOutputs[i], false );
-        GetInternalApplication( "select" )->SetParameterString( "strategy.byclass.in", fileNames.ratesValidOutputs[i],
-                                                                false );
-        // select sample positions
-        ExecuteInternal( "select" );
-        // extract sample descriptors
-        ExecuteInternal( "extraction" );
+        SelectAndExtractSamples( imageList->GetNthElement( i ), validationVectorFileList[i],
+                                 fileNames.sampleValidOutputs[i], fileNames.polyStatValidOutputs[i],
+                                 fileNames.ratesValidOutputs[i], SamplingStrategy::CLASS );
         }
-      else
+      }
+    else
+      {
+      for( unsigned int i = 0; i < imageList->Size(); ++i )
         {
-        // Split between training and validation
-        ogr::DataSource::Pointer source = ogr::DataSource::New( fileNames.sampleOutputs[i],
-                                                                ogr::DataSource::Modes::Read );
-        ogr::DataSource::Pointer destTrain = ogr::DataSource::New( fileNames.sampleTrainOutputs[i],
-                                                                   ogr::DataSource::Modes::Overwrite );
-        ogr::DataSource::Pointer destValid = ogr::DataSource::New( fileNames.sampleValidOutputs[i],
-                                                                   ogr::DataSource::Modes::Overwrite );
-        // read sampling rates from ratesTrainOutputs[i]
-        SamplingRateCalculator::Pointer rateCalculator = SamplingRateCalculator::New();
-        rateCalculator->Read( fileNames.ratesTrainOutputs[i] );
-        // Compute sampling rates for train and valid
-        const MapRateType &inputRates = rateCalculator->GetRatesByClass();
-        MapRateType trainRates;
-        MapRateType validRates;
-        otb::SamplingRateCalculator::TripletType tpt;
-        for( MapRateType::const_iterator it = inputRates.begin(); it != inputRates.end(); ++it )
-          {
-          double vtr = GetParameterFloat( "sample.vtr" );
-          unsigned long total = std::min( it->second.Required, it->second.Tot );
-          unsigned long neededValid = static_cast<unsigned long>(( double ) total * vtr );
-          unsigned long neededTrain = total - neededValid;
-          tpt.Tot = total;
-          tpt.Required = neededTrain;
-          tpt.Rate = ( 1.0 - vtr );
-          trainRates[it->first] = tpt;
-          tpt.Tot = neededValid;
-          tpt.Required = neededValid;
-          tpt.Rate = 1.0;
-          validRates[it->first] = tpt;
-          }
-
-        // Use an otb::OGRDataToSamplePositionFilter with 2 outputs
-        PeriodicSamplerType::SamplerParameterType param;
-        param.Offset = 0;
-        param.MaxJitter = 0;
-        PeriodicSamplerType::Pointer splitter = PeriodicSamplerType::New();
-        splitter->SetInput( imageList->GetNthElement( i ) );
-        splitter->SetOGRData( source );
-        splitter->SetOutputPositionContainerAndRates( destTrain, trainRates, 0 );
-        splitter->SetOutputPositionContainerAndRates( destValid, validRates, 1 );
-        splitter->SetFieldName( this->GetParameterStringList( "sample.vfn" )[0] );
-        splitter->SetLayerIndex( 0 );
-        splitter->SetOriginFieldName( std::string( "" ) );
-        splitter->SetSamplerParameters( param );
-        splitter->GetStreamer()->SetAutomaticTiledStreaming(
-                static_cast<unsigned int>(this->GetParameterInt( "ram" )) );
-        AddProcess( splitter->GetStreamer(), "Split samples between training and validation..." );
-        splitter->Update();
+        SplitTrainingAndValidationSamples( imageList->GetNthElement( i ), fileNames.sampleOutputs[i],
+                                           fileNames.sampleTrainOutputs[i], fileNames.sampleValidOutputs[i],
+                                           fileNames.ratesTrainOutputs[i] );
         }
       }
   }
 
-  /**
-   * Train the model with training and validation data samples
-   * \param fileNames files names used for filters
-   * \param imageList list of input images
-   */
-  void TrainModel(TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList)
+private:
+  void SplitTrainingAndValidationSamples(FloatVectorImageType *image, std::string sampleFileName,
+                                         std::string sampleTrainFileName, std::string sampleValidFileName,
+                                         std::string ratesTrainFileName)
   {
-    GetInternalApplication( "training" )->SetParameterStringList( "io.vd", fileNames.sampleTrainOutputs, false );
-    GetInternalApplication( "training" )->SetParameterStringList( "valid.vd", fileNames.sampleValidOutputs, false );
-    UpdateInternalParameters( "training" );
-    // set field names
-    FloatVectorImageType::Pointer image = imageList->GetNthElement( 0 );
-    unsigned int nbBands = image->GetNumberOfComponentsPerPixel();
-    std::vector<std::string> selectedNames;
-    for( unsigned int i = 0; i < nbBands; i++ )
+    // Split between training and validation
+    ogr::DataSource::Pointer source = ogr::DataSource::New( sampleFileName, ogr::DataSource::Modes::Read );
+    ogr::DataSource::Pointer destTrain = ogr::DataSource::New( sampleTrainFileName, ogr::DataSource::Modes::Overwrite );
+    ogr::DataSource::Pointer destValid = ogr::DataSource::New( sampleValidFileName, ogr::DataSource::Modes::Overwrite );
+    // read sampling rates from ratesTrainFileName
+    SamplingRateCalculator::Pointer rateCalculator = SamplingRateCalculator::New();
+    rateCalculator->Read( ratesTrainFileName );
+    // Compute sampling rates for train and valid
+    const MapRateType &inputRates = rateCalculator->GetRatesByClass();
+    MapRateType trainRates;
+    MapRateType validRates;
+    otb::SamplingRateCalculator::TripletType tpt;
+    for( MapRateType::const_iterator it = inputRates.begin(); it != inputRates.end(); ++it )
       {
-      std::ostringstream oss;
-      oss << i;
-      selectedNames.push_back( "value_" + oss.str() );
+      double vtr = GetParameterFloat( "sample.vtr" );
+      unsigned long total = std::min( it->second.Required, it->second.Tot );
+      unsigned long neededValid = static_cast<unsigned long>(( double ) total * vtr );
+      unsigned long neededTrain = total - neededValid;
+      tpt.Tot = total;
+      tpt.Required = neededTrain;
+      tpt.Rate = ( 1.0 - vtr );
+      trainRates[it->first] = tpt;
+      tpt.Tot = neededValid;
+      tpt.Required = neededValid;
+      tpt.Rate = 1.0;
+      validRates[it->first] = tpt;
       }
-    GetInternalApplication( "training" )->SetParameterStringList( "feat", selectedNames, false );
-    ExecuteInternal( "training" );
+
+    // Use an otb::OGRDataToSamplePositionFilter with 2 outputs
+    PeriodicSamplerType::SamplerParameterType param;
+    param.Offset = 0;
+    param.MaxJitter = 0;
+    PeriodicSamplerType::Pointer splitter = PeriodicSamplerType::New();
+    splitter->SetInput( image );
+    splitter->SetOGRData( source );
+    splitter->SetOutputPositionContainerAndRates( destTrain, trainRates, 0 );
+    splitter->SetOutputPositionContainerAndRates( destValid, validRates, 1 );
+    splitter->SetFieldName( this->GetParameterStringList( "sample.vfn" )[0] );
+    splitter->SetLayerIndex( 0 );
+    splitter->SetOriginFieldName( std::string( "" ) );
+    splitter->SetSamplerParameters( param );
+    splitter->GetStreamer()->SetAutomaticTiledStreaming( static_cast<unsigned int>(this->GetParameterInt( "ram" )) );
+    AddProcess( splitter->GetStreamer(), "Split samples between training and validation..." );
+    splitter->Update();
   }
 
 
-private:
+protected:
 
   struct SamplingRates
   {
@@ -591,7 +465,7 @@ private:
   class TrainFileNamesHandler
   {
   public :
-    TrainFileNamesHandler(std::string outModel, size_t nbInputs, bool dedicatedValidation)
+    void CreateTemporaryFileNames(std::string outModel, size_t nbInputs, bool dedicatedValidation)
     {
 
       if( dedicatedValidation )
@@ -645,6 +519,8 @@ private:
         RemoveFile( sampleTrainOutputs[i] );
       for( unsigned int i = 0; i < sampleValidOutputs.size(); i++ )
         RemoveFile( sampleValidOutputs[i] );
+      for( unsigned int i = 0; i < tmpVectorFileList.size(); i++ )
+        RemoveFile( tmpVectorFileList[i] );
     }
 
   public:
@@ -655,6 +531,7 @@ private:
     std::vector<std::string> sampleOutputs;
     std::vector<std::string> sampleTrainOutputs;
     std::vector<std::string> sampleValidOutputs;
+    std::vector<std::string> tmpVectorFileList;
     std::string rateValidOut;
     std::string rateTrainOut;
 
-- 
GitLab