From 078b7a5a0b4c25512c9d993251b558d1b1d2e54f Mon Sep 17 00:00:00 2001
From: Ludovic Hussonnois <ludovic.hussonnois@c-s.fr>
Date: Thu, 23 Feb 2017 16:55:15 +0100
Subject: [PATCH] ENH: Update TrainImagesClustering and Classifier to use
 SampleSelection.

SampleSelection can now be used without input vector data.
---
 .../app/otbTrainImagesClassifier.cxx          |  2 +-
 .../app/otbTrainImagesClustering.cxx          | 93 +++++--------------
 .../include/otbTrainImagesBase.h              | 13 +--
 3 files changed, 33 insertions(+), 75 deletions(-)

diff --git a/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx b/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx
index 3ed942dbc3..bb0f9176b2 100644
--- a/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx
+++ b/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx
@@ -115,7 +115,7 @@ public:
       ComputePolygonStatistics(imageList, validationVectorFileList, fileNames.polyStatValidOutputs);
       ComputeSamplingRate(fileNames.polyStatValidOutputs, fileNames.rateValidOut, rates.fmv);
       }
-    SelectAndExtractValidationSamples(fileNames, imageList, validationVectorFileList, dedicatedValidation);
+    SelectAndExtractValidationSamples(fileNames, imageList, validationVectorFileList);
 
 
     // Then train the model with extracted samples
diff --git a/Modules/Applications/AppClassification/app/otbTrainImagesClustering.cxx b/Modules/Applications/AppClassification/app/otbTrainImagesClustering.cxx
index fdabcd1b08..e4ffd3b5db 100644
--- a/Modules/Applications/AppClassification/app/otbTrainImagesClustering.cxx
+++ b/Modules/Applications/AppClassification/app/otbTrainImagesClustering.cxx
@@ -34,6 +34,14 @@ public:
     InitSampling();
     InitClassification( false );
 
+    AddParameter( ParameterType_Float, "sample.percent", "Percentage of samples extract in images for "
+            "training and validation when only images are provided." );
+    SetParameterDescription( "sample.percent", "Percentage of samples extract in images for "
+            "training and validation when only images are provided. This parameter is disable when vector data are provided" );
+    SetDefaultParameterFloat( "sample.percent", 100.0 );
+    SetMinimumParameterFloatValue( "sample.percent", 0.0 );
+    SetMaximumParameterFloatValue( "sample.percent", 100.0 );
+
     // Doc example parameter settings
     SetDocExampleParameterValue( "io.il", "QB_1_ortho.tif" );
     SetDocExampleParameterValue( "io.vd", "VectorData_QB1.shp" );
@@ -51,8 +59,13 @@ public:
   {
     if( HasValue( "io.vd" ) )
       {
+      MandatoryOff( "sample.percent" );
       UpdatePolygonClassStatisticsParameters();
       }
+    else
+      {
+      MandatoryOn( "sample.percent" );
+      }
   }
 
   void DoExecute() ITK_OVERRIDE
@@ -60,12 +73,12 @@ public:
     TrainFileNamesHandler fileNames;
     FloatVectorImageListType *imageList = GetParameterImageList( "io.il" );
     bool HasInputVector = IsParameterEnabled( "io.vd" ) && HasValue( "io.vd" );
-    std::vector<std::string> vectorFileList = GetVectorFileList( GetParameterString( "io.out" ), fileNames );
+    std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" );
 
 
     unsigned long nbInputs = imageList->Size();
 
-    if( nbInputs > vectorFileList.size() )
+    if( !vectorFileList.empty() && nbInputs > vectorFileList.size() )
       {
       otbAppLogFATAL( "Missing input vector data files to match number of images (" << nbInputs << ")." );
       }
@@ -90,28 +103,29 @@ public:
     SamplingRates rates = ComputeFinalMaximumSamplingRates( dedicatedValidation );
 
     if( HasInputVector )
-    {
+      {
       // Select and Extract samples for training with computed statistics and rates
       ComputePolygonStatistics( imageList, vectorFileList, fileNames.polyStatTrainOutputs );
       ComputeSamplingRate( fileNames.polyStatTrainOutputs, fileNames.rateTrainOut, rates.fmt );
       SelectAndExtractTrainSamples( fileNames, imageList, vectorFileList, SamplingStrategy::CLASS );
-    }
+      }
     else
-    {
+      {
       SelectAndExtractTrainSamples( fileNames, imageList, vectorFileList, SamplingStrategy::GEOMETRIC );
-    }
+      }
 
     // Select and Extract samples for validation with computed statistics and rates
     // Validation samples could be empty if sample.vrt == 0 and if no dedicated validation are provided
-    if( dedicatedValidation ) {
-      ComputePolygonStatistics(imageList, validationVectorFileList, fileNames.polyStatValidOutputs);
-      ComputeSamplingRate(fileNames.polyStatValidOutputs, fileNames.rateValidOut, rates.fmv);
+    if( dedicatedValidation )
+      {
+      ComputePolygonStatistics( imageList, validationVectorFileList, fileNames.polyStatValidOutputs );
+      ComputeSamplingRate( fileNames.polyStatValidOutputs, fileNames.rateValidOut, rates.fmv );
       }
-    SelectAndExtractValidationSamples(fileNames, imageList, validationVectorFileList, dedicatedValidation);
+    SelectAndExtractValidationSamples( fileNames, imageList, validationVectorFileList );
 
 
     // Then train the model with extracted samples
-    TrainModel( imageList, fileNames.sampleTrainOutputs, fileNames.sampleValidOutputs);
+    TrainModel( imageList, fileNames.sampleTrainOutputs, fileNames.sampleValidOutputs );
 
     // cleanup
     if( IsParameterEnabled( "cleanup" ) )
@@ -130,63 +144,6 @@ private :
     UpdateInternalParameters( "polystat" );
   }
 
-
-  /**
-   * Retrieve input vector data if provided otherwise generate a default vector shape file for each image.
-   * \param output vector file path
-   * \param fileNames
-   * \return list of input vector data file names
-   */
-  std::vector<std::string> GetVectorFileList(std::string output, TrainFileNamesHandler &fileNames)
-  {
-    std::vector<std::string> vectorFileList;
-    bool HasInputVector = IsParameterEnabled( "io.vd" ) && HasValue( "io.vd" );
-
-    // Retrieve provided input vector data if available.
-    if( !HasInputVector )
-      {
-      FloatVectorImageListType *imageList = GetParameterImageList( "io.il" );
-      unsigned int nbInputs = static_cast<unsigned int>(imageList->Size());
-
-      for( unsigned int i = 0; i < nbInputs; ++i )
-        {
-        std::string name = output + "_vector_" + std::to_string( i ) + ".shp";
-        GenerateVectorDataFile( imageList->GetNthElement( i ), name );
-        fileNames.tmpVectorFileList.push_back( name );
-        }
-      vectorFileList = fileNames.tmpVectorFileList;
-      SetParameterStringList( "io.vd", vectorFileList, false );
-      UpdatePolygonClassStatisticsParameters();
-      GetInternalApplication( "polystat" )->SetParameterString( "field", "fid" );
-      }
-    else
-      {
-      vectorFileList = GetParameterStringList( "io.vd" );
-      }
-
-    return vectorFileList;
-  }
-
-
-
-  void GenerateVectorDataFile(const FloatVectorImageListType::ObjectPointerType &floatVectorImage, std::string name)
-  {
-    typedef otb::ImageToEnvelopeVectorDataFilter<FloatVectorImageType, VectorDataType> ImageToEnvelopeFilterType;
-    typedef ImageToEnvelopeFilterType::OutputVectorDataType OutputVectorData;
-    typedef otb::VectorDataFileWriter<OutputVectorData> VectorDataWriter;
-
-    ImageToEnvelopeFilterType::Pointer imageToEnvelopeVectorData = ImageToEnvelopeFilterType::New();
-    imageToEnvelopeVectorData->SetInput( floatVectorImage );
-    imageToEnvelopeVectorData->SetOutputProjectionRef( floatVectorImage->GetProjectionRef().c_str() );
-    OutputVectorData::Pointer vectorData = imageToEnvelopeVectorData->GetOutput();
-
-    // write temporary generated vector file to disk.
-    VectorDataWriter::Pointer vectorDataFileWriter = VectorDataWriter::New();
-    vectorDataFileWriter->SetInput( vectorData );
-    vectorDataFileWriter->SetFileName( name.c_str() );
-    vectorDataFileWriter->Write();
-  }
-
 };
 
 }
diff --git a/Modules/Applications/AppClassification/include/otbTrainImagesBase.h b/Modules/Applications/AppClassification/include/otbTrainImagesBase.h
index 4f5dd82d3c..4e9a0cf220 100644
--- a/Modules/Applications/AppClassification/include/otbTrainImagesBase.h
+++ b/Modules/Applications/AppClassification/include/otbTrainImagesBase.h
@@ -327,7 +327,6 @@ protected:
                                std::string statisticsFileName, std::string ratesFileName, SamplingStrategy strategy)
   {
     GetInternalApplication( "select" )->SetParameterInputImage( "in", image );
-    GetInternalApplication( "select" )->SetParameterString( "vec", vectorFileName, false );
     GetInternalApplication( "select" )->SetParameterString( "out", sampleFileName, false );
 
     GetInternalApplication( "extraction" )->SetParameterString( "outfield", "prefix", false );
@@ -338,10 +337,12 @@ protected:
       {
       case GEOMETRIC:
         GetInternalApplication( "select" )->SetParameterString( "sampler", "random", false );
-        GetInternalApplication( "select" )->SetParameterString( "strategy", "all", false );
+        GetInternalApplication( "select" )->SetParameterString( "strategy", "percent", false );
+        GetInternalApplication( "select" )->SetParameterFloat("strategy.percent.p", GetParameterFloat("sample.percent"), false);
         break;
       case CLASS:
       default:
+        GetInternalApplication( "select" )->SetParameterString( "vec", vectorFileName, false );
         GetInternalApplication( "select" )->SetParameterString( "instats", statisticsFileName, false );
         GetInternalApplication( "select" )->SetParameterString( "sampler", "periodic", false );
         GetInternalApplication( "select" )->SetParameterInt( "sampler.periodic.jitter", 50 );
@@ -365,19 +366,19 @@ protected:
 
     for( unsigned int i = 0; i < imageList->Size(); ++i )
       {
-      SelectAndExtractSamples( imageList->GetNthElement( i ), vectorFileNames[i], fileNames.sampleOutputs[i],
+      std::string vectorFileName = vectorFileNames.empty() ? "" : vectorFileNames[i];
+      SelectAndExtractSamples( imageList->GetNthElement( i ), vectorFileName, fileNames.sampleOutputs[i],
                                fileNames.polyStatTrainOutputs[i], fileNames.ratesTrainOutputs[i], strategy );
       }
   }
 
 
   void SelectAndExtractValidationSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList,
-                                         const std::vector<std::string> &validationVectorFileList,
-                                         bool dedicatedValidation)
+                                         const std::vector<std::string> &validationVectorFileList)
   {
     // In dedicated validation mode the by class sampling strategy and statistics are used.
     // Otherwise simply split training to validation samples corresponding to sample.vtr percentage.
-    if( dedicatedValidation )
+    if( !validationVectorFileList.empty() )
       {
       for( unsigned int i = 0; i < imageList->Size(); ++i )
         {
-- 
GitLab