From bec02c1fa304d05298e3f10442052fcf9f5444a0 Mon Sep 17 00:00:00 2001
From: Ludovic Hussonnois <ludovic.hussonnois@c-s.fr>
Date: Wed, 15 Feb 2017 15:11:51 +0100
Subject: [PATCH] REFAC: Refactoring TrainVectorClassifier.

Inherit TrainVectorClassifier from TrainVectorBase and use Non-Virtual
Function Idiom to provide common behavior for Unsupervised and
Supervised classification.
---
 .../app/otbTrainVectorClassifier.cxx          | 676 +++++-------------
 .../app/otbTrainVectorClustering.cxx          |  64 ++
 .../include/otbTrainVectorBase.h              | 207 ++++++
 .../include/otbTrainVectorBase.txx            | 305 ++++++++
 4 files changed, 772 insertions(+), 480 deletions(-)
 create mode 100644 Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx
 create mode 100644 Modules/Applications/AppClassification/include/otbTrainVectorBase.h
 create mode 100644 Modules/Applications/AppClassification/include/otbTrainVectorBase.txx

diff --git a/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx b/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx
index fa5209552e..61609f080e 100644
--- a/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx
+++ b/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx
@@ -14,60 +14,30 @@
  PURPOSE.  See the above copyright notices for more information.
 
  =========================================================================*/
-#include "otbWrapperApplication.h"
-#include "otbWrapperApplicationFactory.h"
-
-#include "otbLearningApplicationBase.h"
-
-#include "otbOGRDataSourceWrapper.h"
-#include "otbOGRFeatureWrapper.h"
-#include "otbStatisticsXMLFileWriter.h"
-
-#include "itkVariableLengthVector.h"
-#include "otbStatisticsXMLFileReader.h"
-
-#include "itkListSample.h"
-#include "otbShiftScaleSampleListFilter.h"
+#include "otbTrainVectorBase.h"
 
 // Validation
 #include "otbConfusionMatrixCalculator.h"
 
-#include <algorithm>
-#include <locale>
-
 namespace otb
 {
 namespace Wrapper
 {
 
-/** Utility function to negate std::isalnum */
-bool IsNotAlphaNum(char c)
-  {
-  return !std::isalnum(c);
-  }
-
-class TrainVectorClassifier : public LearningApplicationBase<float,int>
+class TrainVectorClassifier : public TrainVectorBase
 {
 public:
   typedef TrainVectorClassifier Self;
-  typedef LearningApplicationBase<float, int> Superclass;
+  typedef TrainVectorBase Superclass;
   typedef itk::SmartPointer<Self> Pointer;
   typedef itk::SmartPointer<const Self> ConstPointer;
-  itkNewMacro(Self)
-
-  itkTypeMacro(Self, Superclass)
-
-  typedef Superclass::SampleType              SampleType;
-  typedef Superclass::ListSampleType          ListSampleType;
-  typedef Superclass::TargetListSampleType    TargetListSampleType;
-  typedef Superclass::SampleImageType         SampleImageType;
-  
-  typedef double ValueType;
-  typedef itk::VariableLengthVector<ValueType> MeasurementType;
+  itkNewMacro( Self )
 
-  typedef otb::StatisticsXMLFileReader<SampleType> StatisticsReader;
+  itkTypeMacro( Self, Superclass )
 
-  typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType;
+  typedef Superclass::SampleType SampleType;
+  typedef Superclass::ListSampleType ListSampleType;
+  typedef Superclass::TargetListSampleType TargetListSampleType;
 
   // Estimate performance on validation sample
   typedef otb::ConfusionMatrixCalculator<TargetListSampleType, TargetListSampleType> ConfusionMatrixCalculatorType;
@@ -75,503 +45,249 @@ public:
   typedef ConfusionMatrixCalculatorType::MapOfIndicesType MapOfIndicesType;
   typedef ConfusionMatrixCalculatorType::ClassLabelType ClassLabelType;
 
+
 private:
-  void DoInit()
+  void DoTrainInit()
   {
-    SetName("TrainVectorClassifier");
-    SetDescription("Train a classifier based on labeled geometries and a list of features to consider.");
-
-    SetDocName("Train Vector Classifier");
-    SetDocLongDescription("This application trains a classifier based on "
-      "labeled geometries and a list of features to consider for classification.");
-    SetDocLimitations(" ");
-    SetDocAuthors("OTB Team");
-    SetDocSeeAlso(" ");
-   
-    //Group IO
-    AddParameter(ParameterType_Group, "io", "Input and output data");
-    SetParameterDescription("io", "This group of parameters allows setting input and output data.");
-
-    AddParameter(ParameterType_InputVectorDataList, "io.vd", "Input Vector Data");
-    SetParameterDescription("io.vd", "Input geometries used for training (note : all geometries from the layer will be used)");
-
-    AddParameter(ParameterType_InputFilename, "io.stats", "Input XML image statistics file");
-    MandatoryOff("io.stats");
-    SetParameterDescription("io.stats", "XML file containing mean and variance of each feature.");
-
-    AddParameter(ParameterType_OutputFilename, "io.confmatout", "Output confusion matrix");
-    SetParameterDescription("io.confmatout", "Output file containing the confusion matrix (.csv format).");
-    MandatoryOff("io.confmatout");
-
-    AddParameter(ParameterType_OutputFilename, "io.out", "Output model");
-    SetParameterDescription("io.out", "Output file containing the model estimated (.txt format).");
-
-    AddParameter(ParameterType_ListView,  "feat", "Field names for training features.");
-    SetParameterDescription("feat","List of field names in the input vector data to be used as features for training.");
-
-    AddParameter(ParameterType_ListView,"cfield","Field containing the class id for supervision");
-    SetParameterDescription("cfield","Field containing the class id for supervision. "
-      "Only geometries with this field available will be taken into account.");
-    SetListViewSingleSelectionMode("cfield",true);
-      
-    AddParameter(ParameterType_Int, "layer", "Layer Index");
-    SetParameterDescription("layer", "Index of the layer to use in the input vector file.");
-    MandatoryOff("layer");
-    SetDefaultParameterInt("layer",0);
-
-    AddParameter(ParameterType_Group, "valid", "Validation data");
-    SetParameterDescription("valid", "This group of parameters defines validation data.");
-
-    AddParameter(ParameterType_InputVectorDataList, "valid.vd", "Validation Vector Data");
-    SetParameterDescription("valid.vd", "Geometries used for validation "
-      "(must contain the same fields used for training, all geometries from the layer will be used)");
-    MandatoryOff("valid.vd");
-
-    AddParameter(ParameterType_Int, "valid.layer", "Layer Index");
-    SetParameterDescription("valid.layer", "Index of the layer to use in the validation vector file.");
-    MandatoryOff("valid.layer");
-    SetDefaultParameterInt("valid.layer",0);
-
-    // Add parameters for the classifier choice
-    Superclass::DoInit();
-
-    AddRANDParameter();
+    SetName( "TrainVectorClassifier" );
+    SetDescription( "Train a classifier based on labeled geometries and a list of features to consider." );
+
+    SetDocName( "Train Vector Classifier" );
+    SetDocLongDescription( "This application trains a classifier based on "
+                                   "labeled geometries and a list of features to consider for classification." );
+    SetDocLimitations( " " );
+    SetDocAuthors( "OTB Team" );
+    SetDocSeeAlso( " " );
+
+    // Add a new parameter to compute confusion matrix
+    AddParameter( ParameterType_OutputFilename, "io.confmatout", "Output confusion matrix" );
+    SetParameterDescription( "io.confmatout", "Output file containing the confusion matrix (.csv format)." );
+    MandatoryOff( "io.confmatout" );
+
     // Doc example parameter settings
-    SetDocExampleParameterValue("io.vd", "vectorData.shp");
-    SetDocExampleParameterValue("io.stats", "meanVar.xml");
-    SetDocExampleParameterValue("io.out", "svmModel.svm");
-    SetDocExampleParameterValue("feat", "perimeter  area  width");
-    SetDocExampleParameterValue("cfield", "predicted");
+    SetDocExampleParameterValue( "io.vd", "vectorData.shp" );
+    SetDocExampleParameterValue( "io.stats", "meanVar.xml" );
+    SetDocExampleParameterValue( "io.out", "svmModel.svm" );
+    SetDocExampleParameterValue( "feat", "perimeter  area  width" );
+    SetDocExampleParameterValue( "cfield", "predicted" );
+
   }
 
-  void DoUpdateParameters()
+  void DoTrainUpdateParameters()
   {
-    if ( HasValue("io.vd") )
-      {
-      std::vector<std::string> vectorFileList = GetParameterStringList("io.vd");
-      ogr::DataSource::Pointer ogrDS =
-        ogr::DataSource::New(vectorFileList[0], ogr::DataSource::Modes::Read);
-      ogr::Layer layer = ogrDS->GetLayer(this->GetParameterInt("layer"));
-      ogr::Feature feature = layer.ogr().GetNextFeature();
-
-      ClearChoices("feat");
-      ClearChoices("cfield");
-      
-      for(int iField=0; iField<feature.ogr().GetFieldCount(); iField++)
-        {
-        std::string key, item = feature.ogr().GetFieldDefnRef(iField)->GetNameRef();
-        key = item;
-        std::string::iterator end = std::remove_if(key.begin(),key.end(),IsNotAlphaNum);
-        std::transform(key.begin(), end, key.begin(), tolower);
-        
-        OGRFieldType fieldType = feature.ogr().GetFieldDefnRef(iField)->GetType();
-        
-        if(fieldType == OFTInteger ||  ogr::version_proxy::IsOFTInteger64(fieldType) || fieldType == OFTReal)
-          {
-          std::string tmpKey="feat."+key.substr(0, end - key.begin());
-          AddChoice(tmpKey,item);
-          }
-        if(fieldType == OFTString || fieldType == OFTInteger || ogr::version_proxy::IsOFTInteger64(fieldType))
-          {
-          std::string tmpKey="cfield."+key.substr(0, end - key.begin());
-          AddChoice(tmpKey,item);
-          }
-        }
-      }
+    // Nothing to do here
+  }
+
+  void DoTrainExecute()
+  {
+    ConfusionMatrixCalculatorType::Pointer confMatCalc = ComputeConfusionmatrix( predictedList,
+                                                                                 classificationListSamples.labeledListSample );
+    WriteConfusionMatrix( confMatCalc );
   }
 
 
-void LogConfusionMatrix(ConfusionMatrixCalculatorType* confMatCalc)
-{
-  ConfusionMatrixCalculatorType::ConfusionMatrixType matrix = confMatCalc->GetConfusionMatrix();
+  ConfusionMatrixCalculatorType::Pointer
+  ComputeConfusionmatrix(const TargetListSampleType::Pointer &predictedListSample,
+                         const TargetListSampleType::Pointer &performanceLabeledListSample)
+  {
+    ConfusionMatrixCalculatorType::Pointer confMatCalc = ConfusionMatrixCalculatorType::New();
+
+    otbAppLogINFO( "Predicted list size : " << predictedListSample->Size() );
+    otbAppLogINFO( "ValidationLabeledListSample size : " << performanceLabeledListSample->Size() );
+    confMatCalc->SetReferenceLabels( performanceLabeledListSample );
+    confMatCalc->SetProducedLabels( predictedListSample );
+    confMatCalc->Compute();
 
-  // Compute minimal width
-  size_t minwidth = 0;
+    otbAppLogINFO( "training performances" );
+    LogConfusionMatrix( confMatCalc );
 
-  for (unsigned int i = 0; i < matrix.Rows(); i++)
-    {
-    for (unsigned int j = 0; j < matrix.Cols(); j++)
+    for( unsigned int itClasses = 0; itClasses < confMatCalc->GetNumberOfClasses(); itClasses++ )
       {
-      std::ostringstream os;
-      os << matrix(i, j);
-      size_t size = os.str().size();
+      ConfusionMatrixCalculatorType::ClassLabelType classLabel = confMatCalc->GetMapOfIndices()[itClasses];
 
-      if (size > minwidth)
-        {
-        minwidth = size;
-        }
+      otbAppLogINFO( "Precision of class [" << classLabel << "] vs all: " << confMatCalc->GetPrecisions()[itClasses] );
+      otbAppLogINFO( "Recall of class    [" << classLabel << "] vs all: " << confMatCalc->GetRecalls()[itClasses] );
+      otbAppLogINFO(
+              "F-score of class   [" << classLabel << "] vs all: " << confMatCalc->GetFScores()[itClasses] << "\n" );
       }
-    }
+    otbAppLogINFO( "Global performance, Kappa index: " << confMatCalc->GetKappaIndex() );
+    return confMatCalc;
+  }
 
-  MapOfIndicesType mapOfIndices = confMatCalc->GetMapOfIndices();
+  /**
+   * Write the confidence matrix into a file if output is provided.
+   * \param confMatCalc the input matrix to write.
+   */
+  void WriteConfusionMatrix(const ConfusionMatrixCalculatorType::Pointer &confMatCalc)
+  {
+    if( this->HasValue( "io.confmatout" ) )
+      {
+      // Writing the confusion matrix in the output .CSV file
 
-  MapOfIndicesType::const_iterator it = mapOfIndices.begin();
-  MapOfIndicesType::const_iterator end = mapOfIndices.end();
+      MapOfIndicesType::iterator itMapOfIndicesValid, itMapOfIndicesPred;
+      ClassLabelType labelValid = 0;
 
-  for (; it != end; ++it)
-    {
-    std::ostringstream os;
-    os << "[" << it->second << "]";
+      ConfusionMatrixType confusionMatrix = confMatCalc->GetConfusionMatrix();
+      MapOfIndicesType mapOfIndicesValid = confMatCalc->GetMapOfIndices();
 
-    size_t size = os.str().size();
-    if (size > minwidth)
-      {
-      minwidth = size;
-      }
-    }
+      unsigned long nbClassesPred = mapOfIndicesValid.size();
 
-  // Generate matrix string, with 'minwidth' as size specifier
-  std::ostringstream os;
+      /////////////////////////////////////////////
+      // Filling the 2 headers for the output file
+      const std::string commentValidStr = "#Reference labels (rows):";
+      const std::string commentPredStr = "#Produced labels (columns):";
+      const char separatorChar = ',';
+      std::ostringstream ossHeaderValidLabels, ossHeaderPredLabels;
 
-  // Header line
-  for (size_t i = 0; i < minwidth; ++i)
-    os << " ";
-  os << " ";
-
-  it = mapOfIndices.begin();
-  end = mapOfIndices.end();
-  for (; it != end; ++it)
-    {
-    os << "[" << it->second << "]" << " ";
-    }
-
-  os << std::endl;
-
-  // Each line of confusion matrix
-  for (unsigned int i = 0; i < matrix.Rows(); i++)
-    {
-    ConfusionMatrixCalculatorType::ClassLabelType label = mapOfIndices[i];
-    os << "[" << std::setw(minwidth - 2) << label << "]" << " ";
-    for (unsigned int j = 0; j < matrix.Cols(); j++)
-      {
-      os << std::setw(minwidth) << matrix(i, j) << " ";
-      }
-    os << std::endl;
-    }
+      // Filling ossHeaderValidLabels and ossHeaderPredLabels for the output file
+      ossHeaderValidLabels << commentValidStr;
+      ossHeaderPredLabels << commentPredStr;
 
-  otbAppLogINFO("Confusion matrix (rows = reference labels, columns = produced labels):\n" << os.str());
-}
+      itMapOfIndicesValid = mapOfIndicesValid.begin();
 
+      while( itMapOfIndicesValid != mapOfIndicesValid.end() )
+        {
+        // labels labelValid of mapOfIndicesValid are already sorted in otbConfusionMatrixCalculator
+        labelValid = itMapOfIndicesValid->second;
 
-void DoExecute()
-  {
-  typedef int LabelPixelType;
-  typedef itk::FixedArray<LabelPixelType,1> LabelSampleType;
-  typedef itk::Statistics::ListSample <LabelSampleType> LabelListSampleType;
-
-  // Prepare selected field names (their position may change between two inputs)
-  std::vector<int> selectedIdx = GetSelectedItems("feat");
-  std::vector<int> selectedCFieldIdx = GetSelectedItems("cfield");
-  
-  if(selectedIdx.empty())
-    {
-    otbAppLogFATAL(<<"No features have been selected to train the classifier on!");
-    }
-  
-  if(selectedCFieldIdx.empty())
-    {
-    otbAppLogFATAL(<<"No field has been selected for data labelling!");
-    }
-
-  const unsigned int nbFeatures = selectedIdx.size();
-  std::vector<std::string> fieldNames = GetChoiceNames("feat");
-  std::vector<std::string> cFieldNames = GetChoiceNames("cfield");
-  std::vector<std::string> selectedNames(nbFeatures);
-  for (unsigned int i=0 ; i<nbFeatures ; i++)
-    {
-    selectedNames[i] = fieldNames[selectedIdx[i]];
-    }
-
-  std::string selectedCFieldName = cFieldNames[selectedCFieldIdx.front()];
-
-  std::vector<int> featureFieldIndex(nbFeatures, -1);
-  int cFieldIndex = -1;
-
-  // Statistics for shift/scale
-  MeasurementType meanMeasurementVector;
-  MeasurementType stddevMeasurementVector;
-  if (HasValue("io.stats") && IsParameterEnabled("io.stats"))
-    {
-    StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
-    std::string XMLfile = GetParameterString("io.stats");
-    statisticsReader->SetFileName(XMLfile);
-    meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
-    stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
-    }
-  else
-    {
-    meanMeasurementVector.SetSize(nbFeatures);
-    meanMeasurementVector.Fill(0.);
-    stddevMeasurementVector.SetSize(nbFeatures);
-    stddevMeasurementVector.Fill(1.);
-    }
-
-  ListSampleType::Pointer input = ListSampleType::New();
-  LabelListSampleType::Pointer target = LabelListSampleType::New();
-  input->SetMeasurementVectorSize(nbFeatures);
-
-  std::vector<std::string> vectorFileList = GetParameterStringList("io.vd");
-  for (unsigned int k=0 ; k<vectorFileList.size() ; k++)
-    {
-    otbAppLogINFO("Reading input vector file "<<k+1<<"/"<<vectorFileList.size());
-    ogr::DataSource::Pointer source = ogr::DataSource::New(vectorFileList[k], ogr::DataSource::Modes::Read);
-    ogr::Layer layer = source->GetLayer(this->GetParameterInt("layer"));
-    ogr::Feature feature = layer.ogr().GetNextFeature();
-    bool goesOn = feature.addr() != 0;
-    if (!goesOn)
-      {
-      otbAppLogWARNING("The layer "<<GetParameterInt("layer")<<" of "
-        <<vectorFileList[k]<<" is empty, input is skipped.");
-      continue;
-      }
+        otbAppLogINFO( "mapOfIndicesValid[" << itMapOfIndicesValid->first << "] = " << labelValid );
 
-    // Check all needed fields are present :
-    //   - check class field
-    cFieldIndex = feature.ogr().GetFieldIndex(selectedCFieldName.c_str());
-    if (cFieldIndex < 0)
-      otbAppLogFATAL("The field name for class label ("<<selectedCFieldName
-        <<") has not been found in the input vector file "<<vectorFileList[k]);
-    //   - check feature fields
-    for (unsigned int i=0 ; i<nbFeatures ; i++)
-      {
-      featureFieldIndex[i] = feature.ogr().GetFieldIndex(selectedNames[i].c_str());
-      if (featureFieldIndex[i] < 0)
-        otbAppLogFATAL("The field name for feature "<<selectedNames[i]
-        <<" has not been found in the input vector file "<<vectorFileList[k]);
-      }
+        ossHeaderValidLabels << labelValid;
+        ossHeaderPredLabels << labelValid;
 
-    while(goesOn)
-      {
-      if(feature.ogr().IsFieldSet(cFieldIndex))
-        {
-        MeasurementType mv;
-        mv.SetSize(nbFeatures);
-        for(unsigned int idx=0; idx < nbFeatures; ++idx)
-          mv[idx] = feature.ogr().GetFieldAsDouble(featureFieldIndex[idx]);
+        ++itMapOfIndicesValid;
 
-        input->PushBack(mv);
-        target->PushBack(feature.ogr().GetFieldAsInteger(cFieldIndex));
-        }
-      feature = layer.ogr().GetNextFeature();
-      goesOn = feature.addr() != 0;
-      }
-    }
-
-  ShiftScaleFilterType::Pointer trainingShiftScaleFilter = ShiftScaleFilterType::New();
-  trainingShiftScaleFilter->SetInput(input);
-  trainingShiftScaleFilter->SetShifts(meanMeasurementVector);
-  trainingShiftScaleFilter->SetScales(stddevMeasurementVector);
-  trainingShiftScaleFilter->Update();
-
-  ListSampleType::Pointer trainingListSample= trainingShiftScaleFilter->GetOutput();
-  TargetListSampleType::Pointer trainingLabeledListSample = target;
-
-  //--------------------------
-  // Estimate model
-  //--------------------------
-  this->Train(trainingListSample,trainingLabeledListSample,GetParameterString("io.out"));
-
-  //--------------------------
-  // Performances estimation
-  //--------------------------
-  ListSampleType::Pointer validationListSample=ListSampleType::New();
-  TargetListSampleType::Pointer validationLabeledListSample = TargetListSampleType::New();
-
-  // Import validation data
-  if (HasValue("valid.vd") && IsParameterEnabled("valid.vd"))
-    {
-    input = ListSampleType::New();
-    target = LabelListSampleType::New();
-    input->SetMeasurementVectorSize(nbFeatures);
-
-    std::vector<std::string> validFileList = this->GetParameterStringList("valid.vd");
-    for (unsigned int k=0 ; k<validFileList.size() ; k++)
-      {
-      otbAppLogINFO("Reading validation vector file "<<k+1<<"/"<<validFileList.size());
-      ogr::DataSource::Pointer source = ogr::DataSource::New(validFileList[k], ogr::DataSource::Modes::Read);
-      ogr::Layer layer = source->GetLayer(this->GetParameterInt("valid.layer"));
-      ogr::Feature feature = layer.ogr().GetNextFeature();
-      bool goesOn = feature.addr() != 0;
-      if (!goesOn)
-        {
-        otbAppLogWARNING("The layer "<<GetParameterInt("valid.layer")<<" of "
-          <<validFileList[k]<<" is empty, input is skipped.");
-        continue;
+        if( itMapOfIndicesValid != mapOfIndicesValid.end() )
+          {
+          ossHeaderValidLabels << separatorChar;
+          ossHeaderPredLabels << separatorChar;
+          }
+        else
+          {
+          ossHeaderValidLabels << std::endl;
+          ossHeaderPredLabels << std::endl;
+          }
         }
 
-      // Check all needed fields are present :
-      //   - check class field
-      cFieldIndex = feature.ogr().GetFieldIndex(selectedCFieldName.c_str());
-      if (cFieldIndex < 0)
-        otbAppLogFATAL("The field name for class label ("<<selectedCFieldName
-          <<") has not been found in the validation vector file "<<validFileList[k]);
-      //   - check feature fields
-      for (unsigned int i=0 ; i<nbFeatures ; i++)
-        {
-        featureFieldIndex[i] = feature.ogr().GetFieldIndex(selectedNames[i].c_str());
-        if (featureFieldIndex[i] < 0)
-          otbAppLogFATAL("The field name for feature "<<selectedNames[i]
-          <<" has not been found in the validation vector file "<<validFileList[k]);
-        }
+      std::ofstream outFile;
+      outFile.open( this->GetParameterString( "io.confmatout" ).c_str() );
+      outFile << std::fixed;
+      outFile.precision( 10 );
+
+      /////////////////////////////////////
+      // Writing the 2 headers
+      outFile << ossHeaderValidLabels.str();
+      outFile << ossHeaderPredLabels.str();
+      /////////////////////////////////////
+
+      unsigned int indexLabelValid = 0, indexLabelPred = 0;
 
-      while(goesOn)
+      for( itMapOfIndicesValid = mapOfIndicesValid.begin();
+           itMapOfIndicesValid != mapOfIndicesValid.end(); ++itMapOfIndicesValid )
         {
-        if(feature.ogr().IsFieldSet(cFieldIndex))
-          {
-          MeasurementType mv;
-          mv.SetSize(nbFeatures);
-          for(unsigned int idx=0; idx < nbFeatures; ++idx)
-            mv[idx] = feature.ogr().GetFieldAsDouble(featureFieldIndex[idx]);
+        indexLabelPred = 0;
 
-          input->PushBack(mv);
-          target->PushBack(feature.ogr().GetFieldAsInteger(cFieldIndex));
+        for( itMapOfIndicesPred = mapOfIndicesValid.begin();
+             itMapOfIndicesPred != mapOfIndicesValid.end(); ++itMapOfIndicesPred )
+          {
+          // Writing the confusion matrix (sorted in otbConfusionMatrixCalculator) in the output file
+          outFile << confusionMatrix( indexLabelValid, indexLabelPred );
+          if( indexLabelPred < ( nbClassesPred - 1 ) )
+            {
+            outFile << separatorChar;
+            }
+          else
+            {
+            outFile << std::endl;
+            }
+          ++indexLabelPred;
           }
-        feature = layer.ogr().GetNextFeature();
-        goesOn = feature.addr() != 0;
+
+        ++indexLabelValid;
         }
+
+      outFile.close();
       }
+  }
+
+  /**
+   * Display the log of the confusion matrix computed with
+   * \param confMatCalc the input confusion matrix to display
+   */
+  void LogConfusionMatrix(ConfusionMatrixCalculatorType *confMatCalc)
+  {
+    ConfusionMatrixCalculatorType::ConfusionMatrixType matrix = confMatCalc->GetConfusionMatrix();
+
+    // Compute minimal width
+    size_t minwidth = 0;
 
-    ShiftScaleFilterType::Pointer validShiftScaleFilter = ShiftScaleFilterType::New();
-    validShiftScaleFilter->SetInput(input);
-    validShiftScaleFilter->SetShifts(meanMeasurementVector);
-    validShiftScaleFilter->SetScales(stddevMeasurementVector);
-    validShiftScaleFilter->Update();
-
-    validationListSample = validShiftScaleFilter->GetOutput();
-    validationLabeledListSample = target;
-    }
- 
-  //Test the input validation set size
-  TargetListSampleType::Pointer predictedList = TargetListSampleType::New();
-  ListSampleType::Pointer performanceListSample;
-  TargetListSampleType::Pointer performanceLabeledListSample;
-  if(validationLabeledListSample->Size() != 0)
-    {
-    performanceListSample = validationListSample;
-    performanceLabeledListSample = validationLabeledListSample;
-    }
-  else
-    {
-    otbAppLogWARNING("The validation set is empty. The performance estimation is done using the input training set in this case.");
-    performanceListSample = trainingListSample;
-    performanceLabeledListSample = trainingLabeledListSample;
-    }
-
-  this->Classify(performanceListSample, predictedList, GetParameterString("io.out"));
-
-  ConfusionMatrixCalculatorType::Pointer confMatCalc = ConfusionMatrixCalculatorType::New();
-
-  otbAppLogINFO("Predicted list size : " << predictedList->Size());
-  otbAppLogINFO("ValidationLabeledListSample size : " << performanceLabeledListSample->Size());
-  confMatCalc->SetReferenceLabels(performanceLabeledListSample);
-  confMatCalc->SetProducedLabels(predictedList);
-  confMatCalc->Compute();
-
-  otbAppLogINFO("training performances");
-  LogConfusionMatrix(confMatCalc);
-
-  for (unsigned int itClasses = 0; itClasses < confMatCalc->GetNumberOfClasses(); itClasses++)
-    {
-    ConfusionMatrixCalculatorType::ClassLabelType classLabel = confMatCalc->GetMapOfIndices()[itClasses];
-
-    otbAppLogINFO("Precision of class [" << classLabel << "] vs all: " << confMatCalc->GetPrecisions()[itClasses]);
-    otbAppLogINFO("Recall of class    [" << classLabel << "] vs all: " << confMatCalc->GetRecalls()[itClasses]);
-    otbAppLogINFO(
-      "F-score of class   [" << classLabel << "] vs all: " << confMatCalc->GetFScores()[itClasses] << "\n");
-    }
-  otbAppLogINFO("Global performance, Kappa index: " << confMatCalc->GetKappaIndex());
-
-
-  if (this->HasValue("io.confmatout"))
-    {
-    // Writing the confusion matrix in the output .CSV file
-
-    MapOfIndicesType::iterator itMapOfIndicesValid, itMapOfIndicesPred;
-    ClassLabelType labelValid = 0;
-
-    ConfusionMatrixType confusionMatrix = confMatCalc->GetConfusionMatrix();
-    MapOfIndicesType mapOfIndicesValid = confMatCalc->GetMapOfIndices();
-
-    unsigned int nbClassesPred = mapOfIndicesValid.size();
-
-    /////////////////////////////////////////////
-    // Filling the 2 headers for the output file
-    const std::string commentValidStr = "#Reference labels (rows):";
-    const std::string commentPredStr = "#Produced labels (columns):";
-    const char separatorChar = ',';
-    std::ostringstream ossHeaderValidLabels, ossHeaderPredLabels;
-
-    // Filling ossHeaderValidLabels and ossHeaderPredLabels for the output file
-    ossHeaderValidLabels << commentValidStr;
-    ossHeaderPredLabels << commentPredStr;
-
-    itMapOfIndicesValid = mapOfIndicesValid.begin();
-
-    while (itMapOfIndicesValid != mapOfIndicesValid.end())
+    for( unsigned int i = 0; i < matrix.Rows(); i++ )
       {
-      // labels labelValid of mapOfIndicesValid are already sorted in otbConfusionMatrixCalculator
-      labelValid = itMapOfIndicesValid->second;
+      for( unsigned int j = 0; j < matrix.Cols(); j++ )
+        {
+        std::ostringstream os;
+        os << matrix( i, j );
+        size_t size = os.str().size();
 
-      otbAppLogINFO("mapOfIndicesValid[" << itMapOfIndicesValid->first << "] = " << labelValid);
+        if( size > minwidth )
+          {
+          minwidth = size;
+          }
+        }
+      }
 
-      ossHeaderValidLabels << labelValid;
-      ossHeaderPredLabels << labelValid;
+    MapOfIndicesType mapOfIndices = confMatCalc->GetMapOfIndices();
 
-      ++itMapOfIndicesValid;
+    MapOfIndicesType::const_iterator it = mapOfIndices.begin();
+    MapOfIndicesType::const_iterator end = mapOfIndices.end();
 
-      if (itMapOfIndicesValid != mapOfIndicesValid.end())
-        {
-        ossHeaderValidLabels << separatorChar;
-        ossHeaderPredLabels << separatorChar;
-        }
-      else
+    for( ; it != end; ++it )
+      {
+      std::ostringstream os;
+      os << "[" << it->second << "]";
+
+      size_t size = os.str().size();
+      if( size > minwidth )
         {
-        ossHeaderValidLabels << std::endl;
-        ossHeaderPredLabels << std::endl;
+        minwidth = size;
         }
       }
 
-    std::ofstream outFile;
-    outFile.open(this->GetParameterString("io.confmatout").c_str());
-    outFile << std::fixed;
-    outFile.precision(10);
-
-    /////////////////////////////////////
-    // Writing the 2 headers
-    outFile << ossHeaderValidLabels.str();
-    outFile << ossHeaderPredLabels.str();
-    /////////////////////////////////////
+    // Generate matrix string, with 'minwidth' as size specifier
+    std::ostringstream os;
 
-    unsigned int indexLabelValid = 0, indexLabelPred = 0;
+    // Header line
+    for( size_t i = 0; i < minwidth; ++i )
+      os << " ";
+    os << " ";
 
-    for (itMapOfIndicesValid = mapOfIndicesValid.begin(); itMapOfIndicesValid != mapOfIndicesValid.end(); ++itMapOfIndicesValid)
+    it = mapOfIndices.begin();
+    end = mapOfIndices.end();
+    for( ; it != end; ++it )
       {
-      indexLabelPred = 0;
+      os << "[" << it->second << "]" << " ";
+      }
+
+    os << std::endl;
 
-      for (itMapOfIndicesPred = mapOfIndicesValid.begin(); itMapOfIndicesPred != mapOfIndicesValid.end(); ++itMapOfIndicesPred)
+    // Each line of confusion matrix
+    for( unsigned int i = 0; i < matrix.Rows(); i++ )
+      {
+      ConfusionMatrixCalculatorType::ClassLabelType label = mapOfIndices[i];
+      os << "[" << std::setw( minwidth - 2 ) << label << "]" << " ";
+      for( unsigned int j = 0; j < matrix.Cols(); j++ )
         {
-        // Writing the confusion matrix (sorted in otbConfusionMatrixCalculator) in the output file
-        outFile << confusionMatrix(indexLabelValid, indexLabelPred);
-        if (indexLabelPred < (nbClassesPred - 1))
-          {
-          outFile << separatorChar;
-          }
-        else
-          {
-          outFile << std::endl;
-          }
-        ++indexLabelPred;
+        os << std::setw( minwidth ) << matrix( i, j ) << " ";
         }
-
-      ++indexLabelValid;
+      os << std::endl;
       }
 
-    outFile.close();
-    } // END if (this->HasValue("io.confmatout"))
+    otbAppLogINFO( "Confusion matrix (rows = reference labels, columns = produced labels):\n" << os.str() );
   }
 
+
 };
 }
 }
 
-OTB_APPLICATION_EXPORT(otb::Wrapper::TrainVectorClassifier)
+OTB_APPLICATION_EXPORT( otb::Wrapper::TrainVectorClassifier )
diff --git a/Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx b/Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx
new file mode 100644
index 0000000000..eec6927bd0
--- /dev/null
+++ b/Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx
@@ -0,0 +1,64 @@
+/*=========================================================================
+ Program:   ORFEO Toolbox
+ Language:  C++
+ Date:      $Date$
+ Version:   $Revision$
+
+
+ Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
+ See OTBCopyright.txt for details.
+
+
+ This software is distributed WITHOUT ANY WARRANTY; without even
+ the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+ PURPOSE.  See the above copyright notices for more information.
+
+ =========================================================================*/
+#include "otbTrainVectorBase.h"
+
+// Validation
+#include "otbConfusionMatrixCalculator.h"
+
+namespace otb
+{
+namespace Wrapper
+{
+
+class TrainVectorClassifier : public TrainVectorBase
+{
+public:
+  typedef TrainVectorClassifier Self;
+  typedef TrainVectorBase Superclass;
+  typedef itk::SmartPointer<Self> Pointer;
+  typedef itk::SmartPointer<const Self> ConstPointer;
+  itkNewMacro( Self )
+
+  itkTypeMacro( Self, Superclass )
+
+  typedef Superclass::SampleType SampleType;
+  typedef Superclass::ListSampleType ListSampleType;
+  typedef Superclass::TargetListSampleType TargetListSampleType;
+
+private:
+  void DoTrainInit()
+  {
+    // Nothing to do here
+  }
+
+  void DoTrainUpdateParameters()
+  {
+    // Nothing to do here
+  }
+
+  void DoTrainExecute()
+  {
+    // Nothing to do here
+  }
+
+
+
+};
+}
+}
+
+OTB_APPLICATION_EXPORT( otb::Wrapper::TrainVectorClustering )
diff --git a/Modules/Applications/AppClassification/include/otbTrainVectorBase.h b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h
new file mode 100644
index 0000000000..0bff79f538
--- /dev/null
+++ b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h
@@ -0,0 +1,207 @@
+/*=========================================================================
+ Program:   ORFEO Toolbox
+ Language:  C++
+ Date:      $Date$
+ Version:   $Revision$
+
+
+ Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
+ See OTBCopyright.txt for details.
+
+
+ This software is distributed WITHOUT ANY WARRANTY; without even
+ the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+ PURPOSE.  See the above copyright notices for more information.
+
+ =========================================================================*/
+#ifndef otbTrainVectorBase_h
+#define otbTrainVectorBase_h
+
+#include "otbLearningApplicationBase.h"
+#include "otbWrapperApplication.h"
+#include "otbWrapperApplicationFactory.h"
+
+#include "otbOGRDataSourceWrapper.h"
+#include "otbOGRFeatureWrapper.h"
+#include "otbStatisticsXMLFileWriter.h"
+
+#include "itkVariableLengthVector.h"
+#include "otbStatisticsXMLFileReader.h"
+
+#include "itkListSample.h"
+#include "otbShiftScaleSampleListFilter.h"
+
+#include <algorithm>
+#include <locale>
+
+namespace otb
+{
+namespace Wrapper
+{
+
+/** Utility function to negate std::isalnum */
+bool IsNotAlphaNum(char c)
+{
+  return !std::isalnum( c );
+}
+
+class TrainVectorBase : public LearningApplicationBase<float, int>
+{
+public:
+  typedef TrainVectorBase Self;
+  typedef LearningApplicationBase<float, int> Superclass;
+  typedef itk::SmartPointer <Self> Pointer;
+  typedef itk::SmartPointer<const Self> ConstPointer;
+
+  itkTypeMacro(Self, Superclass)
+
+  typedef Superclass::SampleType SampleType;
+  typedef Superclass::ListSampleType ListSampleType;
+  typedef Superclass::TargetListSampleType TargetListSampleType;
+
+  typedef double ValueType;
+  typedef itk::VariableLengthVector <ValueType> MeasurementType;
+
+  typedef otb::StatisticsXMLFileReader<SampleType> StatisticsReader;
+
+  typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType;
+
+protected:
+
+  /** Class used to store statistics Measurment (mean/stddev) */
+  class StatisticsMeasurement
+  {
+  public:
+    MeasurementType meanMeasurementVector;
+    MeasurementType stddevMeasurementVector;
+  };
+
+  /** Class used to store a list of sample and the corresponding label */
+  class ListSamples
+  {
+  public:
+    ListSampleType::Pointer listSample;
+    TargetListSampleType::Pointer labeledListSample;
+    ListSamples()
+    {
+      listSample = ListSampleType::New();
+      labeledListSample = TargetListSampleType::New();
+    }
+  };
+
+  /**
+   * Features information class used to store informations
+   * about the field and class name/id of an input vector
+   */
+  class FeaturesInfo
+  {
+  public:
+    /** Index for class field */
+    std::vector<int> m_SelectedCFieldIdx;
+    /** Selected Index */
+    std::vector<int> m_SelectedIdx;
+    /** Selected class field name */
+    std::string m_SelectedCFieldName;
+    /** Selected names */
+    std::vector <std::string> m_SelectedNames;
+    unsigned int m_NbFeatures;
+
+    FeaturesInfo(std::vector <std::string> fieldNames, std::vector <std::string> cFieldNames,
+                 std::vector<int> selectedIdx, std::vector<int> selectedCFieldIdx)
+            : m_SelectedIdx( selectedIdx ), m_SelectedCFieldIdx( selectedCFieldIdx )
+    {
+      m_NbFeatures = static_cast<unsigned int>(selectedIdx.size());
+      m_SelectedNames = std::vector<std::string>( m_NbFeatures );
+      for( unsigned int i = 0; i < m_NbFeatures; ++i )
+        {
+        m_SelectedNames[i] = fieldNames[selectedIdx[i]];
+        }
+
+      m_SelectedCFieldName = cFieldNames[selectedCFieldIdx.front()];
+
+    }
+  };
+
+
+protected:
+
+  /**
+   * Function which extract and store all samples for Training, Classification and Validation.
+   * \param measurement statics measurement (mean/stddev)
+   * \param featuresInfo information about the features
+   * \return sample list used for training
+   */
+  virtual void ExtractSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo);
+
+  /**
+  * Extract the training sample list
+  * \param measurement statics measurement (mean/stddev)
+  * \param featuresInfo information about the features
+  * \return sample list used for training
+  */
+  virtual ListSamples ExtractTrainingListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo);
+
+  /**
+  * Extract the validation sample list
+  * \param measurement statics measurement (mean/stddev)
+  * \param featuresInfo information about the features
+  * \return sample list used for validation
+  */
+  virtual ListSamples ExtractValidationListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo);
+
+  /**
+   * Extract the sample list classification
+   * \param measurement statics measurement (mean/stddev)
+   * \param featuresInfo information about the features
+   * \return sample list used for classification
+   */
+  virtual ListSamples ExtractClassificationListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo);
+
+  ListSamples trainingListSamples;
+  ListSamples validationListSamples;
+  ListSamples classificationListSamples;
+  TargetListSampleType::Pointer predictedList;
+
+private:
+  virtual void DoTrainInit() = 0;
+  virtual void DoTrainExecute() = 0;
+  virtual void DoTrainUpdateParameters() = 0;
+
+  void DoInit();
+  void DoUpdateParameters();
+  void DoExecute();
+
+  /** Extract samples from input file for corresponding field name
+ *
+ * \param parameterName the name of the input file option in the input application parameters
+ * \param parameterLayer the name of the layer option in the input application parameters
+ * \param measurement statics measurement (mean/stddev)
+ * \param nbFeatures the number of features.
+ * \return the list of samples and their corresponding labels.
+ */
+  ListSamples ExtractListSamples(std::string parameterName, std::string parameterLayer,
+                                                  const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo);
+
+
+
+  ListSamples ExtractClassificationListSamples(ListSamples &validationListSamples, ListSamples &trainingListSamples);
+
+
+  /**
+  * Retrieve statistics mean and standard deviation if input statistics are provided.
+  * Otherwise mean is set to 0 and standard deviation to 1 for each Features.
+  * \param nbFeatures
+  */
+  StatisticsMeasurement ComputeStatistics(unsigned int nbFeatures);
+
+};
+
+}
+}
+
+#ifndef OTB_MANUAL_INSTANTIATION
+#include "otbTrainVectorBase.txx"
+#endif
+
+#endif
+
diff --git a/Modules/Applications/AppClassification/include/otbTrainVectorBase.txx b/Modules/Applications/AppClassification/include/otbTrainVectorBase.txx
new file mode 100644
index 0000000000..c45d51ad78
--- /dev/null
+++ b/Modules/Applications/AppClassification/include/otbTrainVectorBase.txx
@@ -0,0 +1,305 @@
+/*=========================================================================
+ Program:   ORFEO Toolbox
+ Language:  C++
+ Date:      $Date$
+ Version:   $Revision$
+
+
+ Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
+ See OTBCopyright.txt for details.
+
+
+ This software is distributed WITHOUT ANY WARRANTY; without even
+ the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+ PURPOSE.  See the above copyright notices for more information.
+
+ =========================================================================*/
+#ifndef otbTrainVectorBase_txx
+#define otbTrainVectorBase_txx
+
+#include "otbTrainVectorBase.h"
+
+namespace otb
+{
+namespace Wrapper
+{
+
+void TrainVectorBase::DoInit()
+{
+  SetName( "TrainVectorClassifier" );
+  SetDescription( "Train a classifier based on labeled geometries and a list of features to consider." );
+
+  SetDocName( "Train Vector Classifier" );
+  SetDocLongDescription( "This application trains a classifier based on "
+                                 "labeled geometries and a list of features to consider for classification." );
+  SetDocLimitations( " " );
+  SetDocAuthors( "OTB Team" );
+  SetDocSeeAlso( " " );
+
+  // Common Parameters for all Learning Application
+  AddParameter( ParameterType_Group, "io", "Input and output data" );
+  SetParameterDescription( "io", "This group of parameters allows setting input and output data." );
+
+  AddParameter( ParameterType_InputVectorDataList, "io.vd", "Input Vector Data" );
+  SetParameterDescription( "io.vd",
+                           "Input geometries used for training (note : all geometries from the layer will be used)" );
+
+  AddParameter( ParameterType_InputFilename, "io.stats", "Input XML image statistics file" );
+  MandatoryOff( "io.stats" );
+  SetParameterDescription( "io.stats", "XML file containing mean and variance of each feature." );
+
+  AddParameter( ParameterType_OutputFilename, "io.out", "Output model" );
+  SetParameterDescription( "io.out", "Output file containing the model estimated (.txt format)." );
+
+  AddParameter( ParameterType_Int, "layer", "Layer Index" );
+  SetParameterDescription( "layer", "Index of the layer to use in the input vector file." );
+  MandatoryOff( "layer" );
+  SetDefaultParameterInt( "layer", 0 );
+
+  //Can be in both Supervised and Unsupervised ?
+  AddParameter( ParameterType_Group, "valid", "Validation data" );
+  SetParameterDescription( "valid", "This group of parameters defines validation data." );
+
+  AddParameter( ParameterType_InputVectorDataList, "valid.vd", "Validation Vector Data" );
+  SetParameterDescription( "valid.vd", "Geometries used for validation "
+          "(must contain the same fields used for training, all geometries from the layer will be used)" );
+  MandatoryOff( "valid.vd" );
+
+  AddParameter( ParameterType_Int, "valid.layer", "Layer Index" );
+  SetParameterDescription( "valid.layer", "Index of the layer to use in the validation vector file." );
+  MandatoryOff( "valid.layer" );
+  SetDefaultParameterInt( "valid.layer", 0 );
+
+  AddParameter(ParameterType_ListView,  "feat", "Field names for training features.");
+  SetParameterDescription("feat","List of field names in the input vector data to be used as features for training.");
+
+  AddParameter(ParameterType_ListView,"cfield","Field containing the class id for supervision");
+  SetParameterDescription("cfield","Field containing the class id for supervision. "
+          "Only geometries with this field available will be taken into account.");
+  SetListViewSingleSelectionMode("cfield",true);
+
+  // Add parameters for the classifier choice
+  Superclass::DoInit();
+
+  AddRANDParameter();
+
+  DoTrainInit();
+}
+
+void TrainVectorBase::DoUpdateParameters()
+{
+  if( HasValue( "io.vd" ) )
+    {
+    std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" );
+    ogr::DataSource::Pointer ogrDS = ogr::DataSource::New( vectorFileList[0], ogr::DataSource::Modes::Read );
+    ogr::Layer layer = ogrDS->GetLayer( this->GetParameterInt( "layer" ) );
+    ogr::Feature feature = layer.ogr().GetNextFeature();
+
+    ClearChoices( "feat" );
+    ClearChoices( "cfield" );
+
+    for( int iField = 0; iField < feature.ogr().GetFieldCount(); iField++ )
+      {
+      std::string key, item = feature.ogr().GetFieldDefnRef( iField )->GetNameRef();
+      key = item;
+      std::string::iterator end = std::remove_if( key.begin(), key.end(), IsNotAlphaNum );
+      std::transform( key.begin(), end, key.begin(), tolower );
+
+      OGRFieldType fieldType = feature.ogr().GetFieldDefnRef( iField )->GetType();
+
+      if( fieldType == OFTInteger || ogr::version_proxy::IsOFTInteger64( fieldType ) || fieldType == OFTReal )
+        {
+        std::string tmpKey = "feat." + key.substr( 0, end - key.begin() );
+        AddChoice( tmpKey, item );
+        }
+      if( fieldType == OFTString || fieldType == OFTInteger || ogr::version_proxy::IsOFTInteger64( fieldType ) )
+        {
+        std::string tmpKey = "cfield." + key.substr( 0, end - key.begin() );
+        AddChoice( tmpKey, item );
+        }
+      }
+    }
+
+  DoTrainUpdateParameters();
+}
+
+void TrainVectorBase::DoExecute()
+{
+  typedef int LabelPixelType;
+  typedef itk::FixedArray<LabelPixelType, 1> LabelSampleType;
+  typedef itk::Statistics::ListSample<LabelSampleType> LabelListSampleType;
+
+  FeaturesInfo featuresInfo( GetChoiceNames( "feat" ), GetChoiceNames( "cfield" ), GetSelectedItems( "feat" ),
+                             GetSelectedItems( "cfield" ) );
+
+  // Check input parameters
+  if( featuresInfo.m_SelectedIdx.empty() )
+    {
+    otbAppLogFATAL( << "No features have been selected to train the classifier on!" );
+    }
+
+  // Todo only Log warning and set CFieldName to 0, 1, 2, 3... (default behavior)
+  if( featuresInfo.m_SelectedCFieldIdx.empty() )
+    {
+    otbAppLogFATAL( << "No field has been selected for data labelling!" );
+    }
+
+  StatisticsMeasurement measurement = ComputeStatistics( featuresInfo.m_NbFeatures );
+  ExtractSamples(measurement, featuresInfo);
+
+  this->Train( trainingListSamples.listSample, trainingListSamples.labeledListSample, GetParameterString( "io.out" ) );
+
+  predictedList = TargetListSampleType::New();
+  this->Classify( classificationListSamples.listSample, predictedList, GetParameterString( "io.out" ) );
+
+  DoTrainExecute();
+}
+
+
+void TrainVectorBase::ExtractSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo)
+{
+  trainingListSamples = ExtractTrainingListSamples(measurement, featuresInfo);
+  validationListSamples = ExtractValidationListSamples(measurement, featuresInfo);
+  classificationListSamples = ExtractClassificationListSamples(measurement, featuresInfo);
+}
+
+TrainVectorBase::ListSamples
+TrainVectorBase::ExtractTrainingListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo)
+{
+  return ExtractListSamples( "io.vd", "layer", measurement, featuresInfo );
+}
+
+TrainVectorBase::ListSamples
+TrainVectorBase::ExtractValidationListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo)
+{
+  return ExtractListSamples( "valid.vd", "valid.layer", measurement, featuresInfo );
+}
+
+
+TrainVectorBase::ListSamples
+TrainVectorBase::ExtractClassificationListSamples(const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo)
+{
+  ListSamples performanceSample;
+
+  //Test the input validation set size
+  if( validationListSamples.labeledListSample->Size() != 0 )
+    {
+    performanceSample.listSample = validationListSamples.listSample;
+    performanceSample.labeledListSample = validationListSamples.labeledListSample;
+    }
+  else
+    {
+    otbAppLogWARNING(
+            "The validation set is empty. The performance estimation is done using the input training set in this case." );
+    performanceSample.listSample = trainingListSamples.listSample;
+    performanceSample.labeledListSample = trainingListSamples.labeledListSample;
+    }
+
+  return performanceSample;
+}
+
+
+TrainVectorBase::StatisticsMeasurement
+TrainVectorBase::ComputeStatistics(unsigned int nbFeatures)
+{
+  StatisticsMeasurement measurement = StatisticsMeasurement();
+  if( HasValue( "io.stats" ) && IsParameterEnabled( "io.stats" ) )
+    {
+    StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
+    std::string XMLfile = GetParameterString( "io.stats" );
+    statisticsReader->SetFileName( XMLfile.c_str() );
+    measurement.meanMeasurementVector = statisticsReader->GetStatisticVectorByName( "mean" );
+    measurement.stddevMeasurementVector = statisticsReader->GetStatisticVectorByName( "stddev" );
+    }
+  else
+    {
+    measurement.meanMeasurementVector.SetSize( nbFeatures );
+    measurement.meanMeasurementVector.Fill( 0. );
+    measurement.stddevMeasurementVector.SetSize( nbFeatures );
+    measurement.stddevMeasurementVector.Fill( 1. );
+    }
+  return measurement;
+}
+
+
+TrainVectorBase::ListSamples
+TrainVectorBase::ExtractListSamples(std::string parameterName, std::string parameterLayer,
+                                    const StatisticsMeasurement &measurement, FeaturesInfo &featuresInfo)
+{
+  ListSamples listSamples;
+  if( HasValue( parameterName ) && IsParameterEnabled( parameterName ) )
+    {
+    ListSampleType::Pointer input = ListSampleType::New();
+    TargetListSampleType::Pointer target = TargetListSampleType::New();
+    input->SetMeasurementVectorSize( featuresInfo.m_NbFeatures );
+
+    std::vector<std::string> validFileList = this->GetParameterStringList( parameterName );
+    for( unsigned int k = 0; k < validFileList.size(); k++ )
+      {
+      otbAppLogINFO( "Reading validation vector file " << k + 1 << "/" << validFileList.size() );
+      ogr::DataSource::Pointer source = ogr::DataSource::New( validFileList[k], ogr::DataSource::Modes::Read );
+      ogr::Layer layer = source->GetLayer( static_cast<size_t>(this->GetParameterInt( parameterLayer )) );
+      ogr::Feature feature = layer.ogr().GetNextFeature();
+      bool goesOn = feature.addr() != 0;
+      if( !goesOn )
+        {
+        otbAppLogWARNING( "The layer " << GetParameterInt( parameterLayer ) << " of " << validFileList[k]
+                                       << " is empty, input is skipped." );
+        continue;
+        }
+
+      // Check all needed fields are present :
+      //   - check class field
+      int cFieldIndex = feature.ogr().GetFieldIndex( featuresInfo.m_SelectedCFieldName.c_str() );
+      if( cFieldIndex < 0 )
+        otbAppLogFATAL( "The field name for class label (" << featuresInfo.m_SelectedCFieldName
+                                                           << ") has not been found in the vector file "
+                                                           << validFileList[k] );
+      //   - check feature fields
+      std::vector<int> featureFieldIndex( featuresInfo.m_NbFeatures, -1 );
+      for( unsigned int i = 0; i < featuresInfo.m_NbFeatures; i++ )
+        {
+        featureFieldIndex[i] = feature.ogr().GetFieldIndex( featuresInfo.m_SelectedNames[i].c_str() );
+        if( featureFieldIndex[i] < 0 )
+          otbAppLogFATAL( "The field name for feature " << featuresInfo.m_SelectedNames[i]
+                                                        << " has not been found in the vector file "
+                                                        << validFileList[k] );
+        }
+
+      while( goesOn )
+        {
+        if( feature.ogr().IsFieldSet( cFieldIndex ) )
+          {
+          MeasurementType mv;
+          mv.SetSize( featuresInfo.m_NbFeatures );
+          for( unsigned int idx = 0; idx < featuresInfo.m_NbFeatures; ++idx )
+            mv[idx] = feature.ogr().GetFieldAsDouble( featureFieldIndex[idx] );
+
+          input->PushBack( mv );
+          target->PushBack( feature.ogr().GetFieldAsInteger( cFieldIndex ) );
+          }
+        feature = layer.ogr().GetNextFeature();
+        goesOn = feature.addr() != 0;
+        }
+      }
+
+    ShiftScaleFilterType::Pointer shiftScaleFilter = ShiftScaleFilterType::New();
+    shiftScaleFilter->SetInput( input );
+    shiftScaleFilter->SetShifts( measurement.meanMeasurementVector );
+    shiftScaleFilter->SetScales( measurement.stddevMeasurementVector );
+    shiftScaleFilter->Update();
+
+    listSamples.listSample = shiftScaleFilter->GetOutput();
+    listSamples.labeledListSample = target;
+    }
+
+  return listSamples;
+}
+
+
+}
+}
+
+#endif
+
+
-- 
GitLab