From 90c054355860267f7b891f07c3fa6fa5bf0322b5 Mon Sep 17 00:00:00 2001 From: Ludovic Hussonnois <ludovic.hussonnois@c-s.fr> Date: Thu, 23 Feb 2017 14:39:45 +0100 Subject: [PATCH] ENH: Select strategy depending on provided Vector and do some refac. --- .../app/otbTrainImagesClassifier.cxx | 119 +++- .../app/otbTrainImagesClustering.cxx | 177 +++++- .../app/otbTrainVectorClustering.cxx | 5 +- .../include/otbTrainImagesBase.h | 525 +++++++----------- 4 files changed, 495 insertions(+), 331 deletions(-) diff --git a/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx b/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx index 4a3b98b206..3ed942dbc3 100644 --- a/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx @@ -5,15 +5,130 @@ namespace otb namespace Wrapper { -class TrainImagesClassifier : public TrainImagesBase<true> +class TrainImagesClassifier : public TrainImagesBase { public: typedef TrainImagesClassifier Self; - typedef TrainImagesBase<true> Superclass; + typedef TrainImagesBase Superclass; typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; itkNewMacro( Self ) itkTypeMacro( Self, Superclass ) + + void DoInit() ITK_OVERRIDE + { + SetName( "TrainImagesClassifier" ); + SetDescription( "Train a classifier from multiple pairs of images and training vector data." ); + + // Documentation + SetDocName( "Train a classifier from multiple images" ); + SetDocLongDescription( + "This application performs a classifier training from multiple pairs of input images and training vector data. " + "Samples are composed of pixel values in each band optionally centered and reduced using an XML statistics file produced by " + "the ComputeImagesStatistics application.\n The training vector data must contain polygons with a positive integer field " + "representing the class label. The name of this field can be set using the \"Class label field\" parameter. Training and validation " + "sample lists are built such that each class is equally represented in both lists. One parameter allows controlling the ratio " + "between the number of samples in training and validation sets. Two parameters allow managing the size of the training and " + "validation sets per class and per image.\n Several classifier parameters can be set depending on the chosen classifier. In the " + "validation process, the confusion matrix is organized the following way: rows = reference labels, columns = produced labels. " + "In the header of the optional confusion matrix output file, the validation (reference) and predicted (produced) class labels" + " are ordered according to the rows/columns of the confusion matrix.\n This application is based on LibSVM and OpenCV Machine Learning " + "(2.3.1 and later)." ); + SetDocLimitations( "None" ); + SetDocAuthors( "OTB-Team" ); + SetDocSeeAlso( "OpenCV documentation for machine learning http://docs.opencv.org/modules/ml/doc/ml.html " ); + + AddDocTag( Tags::Learning ); + + // Perform initialization + ClearApplications(); + InitIO(); + InitSampling(); + InitClassification( true ); + + + // Doc example parameter settings + SetDocExampleParameterValue("io.il", "QB_1_ortho.tif"); + SetDocExampleParameterValue("io.vd", "VectorData_QB1.shp"); + SetDocExampleParameterValue("io.imstat", "EstimateImageStatisticsQB1.xml"); + SetDocExampleParameterValue("sample.mv", "100"); + SetDocExampleParameterValue("sample.mt", "100"); + SetDocExampleParameterValue("sample.vtr", "0.5"); + SetDocExampleParameterValue("sample.vfn", "Class"); + SetDocExampleParameterValue("classifier", "libsvm"); + SetDocExampleParameterValue("classifier.libsvm.k", "linear"); + SetDocExampleParameterValue("classifier.libsvm.c", "1"); + SetDocExampleParameterValue("classifier.libsvm.opt", "false"); + SetDocExampleParameterValue("io.out", "svmModelQB1.txt"); + SetDocExampleParameterValue("io.confmatout", "svmConfusionMatrixQB1.csv"); + } + + void DoUpdateParameters() ITK_OVERRIDE + { + if( HasValue( "io.vd" ) ) + { + std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" ); + GetInternalApplication( "polystat" )->SetParameterString( "vec", vectorFileList[0], false ); + UpdateInternalParameters( "polystat" ); + } + } + + void DoExecute() ITK_OVERRIDE + { + TrainFileNamesHandler fileNames; + FloatVectorImageListType *imageList = GetParameterImageList( "io.il" ); + std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" ); + unsigned long nbInputs = imageList->Size(); + + if( nbInputs > vectorFileList.size() ) + { + otbAppLogFATAL( "Missing input vector data files to match number of images (" << nbInputs << ")." ); + } + + // check if validation vectors are given + std::vector<std::string> validationVectorFileList; + bool dedicatedValidation = false; + if( IsParameterEnabled( "io.valid" ) && HasValue( "io.valid" ) ) + { + validationVectorFileList = GetParameterStringList( "io.valid" ); + if( nbInputs > validationVectorFileList.size() ) + { + otbAppLogFATAL( "Missing validation vector data files to match number of images (" << nbInputs << ")." ); + } + + dedicatedValidation = true; + } + + fileNames.CreateTemporaryFileNames( GetParameterString( "io.out" ), nbInputs, dedicatedValidation ); + + // Compute final maximum sampling rates for both training and validation samples + SamplingRates rates = ComputeFinalMaximumSamplingRates( dedicatedValidation ); + + // Select and Extract samples for training with computed statistics and rates + ComputePolygonStatistics(imageList, vectorFileList, fileNames.polyStatTrainOutputs); + ComputeSamplingRate(fileNames.polyStatTrainOutputs, fileNames.rateTrainOut, rates.fmt); + SelectAndExtractTrainSamples(fileNames, imageList, vectorFileList, SamplingStrategy::CLASS); + + // Select and Extract samples for validation with computed statistics and rates + // Validation samples could be empty if sample.vrt == 0 and if no dedicated validation are provided + if( dedicatedValidation ) { + ComputePolygonStatistics(imageList, validationVectorFileList, fileNames.polyStatValidOutputs); + ComputeSamplingRate(fileNames.polyStatValidOutputs, fileNames.rateValidOut, rates.fmv); + } + SelectAndExtractValidationSamples(fileNames, imageList, validationVectorFileList, dedicatedValidation); + + + // Then train the model with extracted samples + TrainModel( imageList, fileNames.sampleTrainOutputs, fileNames.sampleValidOutputs); + + // cleanup + if( IsParameterEnabled( "cleanup" ) ) + { + otbAppLogINFO( <<"Final clean-up ..." ); + fileNames.clear(); + } + } + }; } diff --git a/Modules/Applications/AppClassification/app/otbTrainImagesClustering.cxx b/Modules/Applications/AppClassification/app/otbTrainImagesClustering.cxx index fed5b0775a..fdabcd1b08 100644 --- a/Modules/Applications/AppClassification/app/otbTrainImagesClustering.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainImagesClustering.cxx @@ -5,15 +5,188 @@ namespace otb namespace Wrapper { -class TrainImagesClustering : public TrainImagesBase<false> +class TrainImagesClustering : public TrainImagesBase { public: typedef TrainImagesClustering Self; - typedef TrainImagesBase<false> Superclass; + typedef TrainImagesBase Superclass; typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; itkNewMacro( Self ) itkTypeMacro( Self, Superclass ) + + void DoInit() ITK_OVERRIDE + { + SetName( "TrainImagesClustering" ); + SetDescription( "Train a classifier from multiple pairs of images and optional input training vector data." ); + + // Documentation + SetDocName( "Train a classifier from multiple images" ); + SetDocLongDescription( "TODO" ); + SetDocLimitations( "None" ); + SetDocAuthors( "OTB-Team" ); + SetDocSeeAlso( "OpenCV documentation for machine learning http://docs.opencv.org/modules/ml/doc/ml.html " ); + + AddDocTag( Tags::Learning ); + + ClearApplications(); + InitIO(); + InitSampling(); + InitClassification( false ); + + // Doc example parameter settings + SetDocExampleParameterValue( "io.il", "QB_1_ortho.tif" ); + SetDocExampleParameterValue( "io.vd", "VectorData_QB1.shp" ); + SetDocExampleParameterValue( "io.imstat", "EstimateImageStatisticsQB1.xml" ); + SetDocExampleParameterValue( "sample.mv", "100" ); + SetDocExampleParameterValue( "sample.mt", "100" ); + SetDocExampleParameterValue( "sample.vtr", "0.5" ); + SetDocExampleParameterValue( "sample.vfn", "Class" ); + SetDocExampleParameterValue( "classifier", "sharkkm" ); + SetDocExampleParameterValue( "classifier.sharkkm.k", "2" ); + SetDocExampleParameterValue( "io.out", "sharkKMModelQB1.txt" ); + } + + void DoUpdateParameters() ITK_OVERRIDE + { + if( HasValue( "io.vd" ) ) + { + UpdatePolygonClassStatisticsParameters(); + } + } + + void DoExecute() ITK_OVERRIDE + { + TrainFileNamesHandler fileNames; + FloatVectorImageListType *imageList = GetParameterImageList( "io.il" ); + bool HasInputVector = IsParameterEnabled( "io.vd" ) && HasValue( "io.vd" ); + std::vector<std::string> vectorFileList = GetVectorFileList( GetParameterString( "io.out" ), fileNames ); + + + unsigned long nbInputs = imageList->Size(); + + if( nbInputs > vectorFileList.size() ) + { + otbAppLogFATAL( "Missing input vector data files to match number of images (" << nbInputs << ")." ); + } + + // check if validation vectors are given + std::vector<std::string> validationVectorFileList; + bool dedicatedValidation = false; + if( IsParameterEnabled( "io.valid" ) && HasValue( "io.valid" ) ) + { + validationVectorFileList = GetParameterStringList( "io.valid" ); + if( nbInputs > validationVectorFileList.size() ) + { + otbAppLogFATAL( "Missing validation vector data files to match number of images (" << nbInputs << ")." ); + } + + dedicatedValidation = true; + } + + fileNames.CreateTemporaryFileNames( GetParameterString( "io.out" ), nbInputs, dedicatedValidation ); + + // Compute final maximum sampling rates for both training and validation samples + SamplingRates rates = ComputeFinalMaximumSamplingRates( dedicatedValidation ); + + if( HasInputVector ) + { + // Select and Extract samples for training with computed statistics and rates + ComputePolygonStatistics( imageList, vectorFileList, fileNames.polyStatTrainOutputs ); + ComputeSamplingRate( fileNames.polyStatTrainOutputs, fileNames.rateTrainOut, rates.fmt ); + SelectAndExtractTrainSamples( fileNames, imageList, vectorFileList, SamplingStrategy::CLASS ); + } + else + { + SelectAndExtractTrainSamples( fileNames, imageList, vectorFileList, SamplingStrategy::GEOMETRIC ); + } + + // Select and Extract samples for validation with computed statistics and rates + // Validation samples could be empty if sample.vrt == 0 and if no dedicated validation are provided + if( dedicatedValidation ) { + ComputePolygonStatistics(imageList, validationVectorFileList, fileNames.polyStatValidOutputs); + ComputeSamplingRate(fileNames.polyStatValidOutputs, fileNames.rateValidOut, rates.fmv); + } + SelectAndExtractValidationSamples(fileNames, imageList, validationVectorFileList, dedicatedValidation); + + + // Then train the model with extracted samples + TrainModel( imageList, fileNames.sampleTrainOutputs, fileNames.sampleValidOutputs); + + // cleanup + if( IsParameterEnabled( "cleanup" ) ) + { + otbAppLogINFO( <<"Final clean-up ..." ); + fileNames.clear(); + } + } + +private : + + void UpdatePolygonClassStatisticsParameters() + { + std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" ); + GetInternalApplication( "polystat" )->SetParameterString( "vec", vectorFileList[0], false ); + UpdateInternalParameters( "polystat" ); + } + + + /** + * Retrieve input vector data if provided otherwise generate a default vector shape file for each image. + * \param output vector file path + * \param fileNames + * \return list of input vector data file names + */ + std::vector<std::string> GetVectorFileList(std::string output, TrainFileNamesHandler &fileNames) + { + std::vector<std::string> vectorFileList; + bool HasInputVector = IsParameterEnabled( "io.vd" ) && HasValue( "io.vd" ); + + // Retrieve provided input vector data if available. + if( !HasInputVector ) + { + FloatVectorImageListType *imageList = GetParameterImageList( "io.il" ); + unsigned int nbInputs = static_cast<unsigned int>(imageList->Size()); + + for( unsigned int i = 0; i < nbInputs; ++i ) + { + std::string name = output + "_vector_" + std::to_string( i ) + ".shp"; + GenerateVectorDataFile( imageList->GetNthElement( i ), name ); + fileNames.tmpVectorFileList.push_back( name ); + } + vectorFileList = fileNames.tmpVectorFileList; + SetParameterStringList( "io.vd", vectorFileList, false ); + UpdatePolygonClassStatisticsParameters(); + GetInternalApplication( "polystat" )->SetParameterString( "field", "fid" ); + } + else + { + vectorFileList = GetParameterStringList( "io.vd" ); + } + + return vectorFileList; + } + + + + void GenerateVectorDataFile(const FloatVectorImageListType::ObjectPointerType &floatVectorImage, std::string name) + { + typedef otb::ImageToEnvelopeVectorDataFilter<FloatVectorImageType, VectorDataType> ImageToEnvelopeFilterType; + typedef ImageToEnvelopeFilterType::OutputVectorDataType OutputVectorData; + typedef otb::VectorDataFileWriter<OutputVectorData> VectorDataWriter; + + ImageToEnvelopeFilterType::Pointer imageToEnvelopeVectorData = ImageToEnvelopeFilterType::New(); + imageToEnvelopeVectorData->SetInput( floatVectorImage ); + imageToEnvelopeVectorData->SetOutputProjectionRef( floatVectorImage->GetProjectionRef().c_str() ); + OutputVectorData::Pointer vectorData = imageToEnvelopeVectorData->GetOutput(); + + // write temporary generated vector file to disk. + VectorDataWriter::Pointer vectorDataFileWriter = VectorDataWriter::New(); + vectorDataFileWriter->SetInput( vectorData ); + vectorDataFileWriter->SetFileName( name.c_str() ); + vectorDataFileWriter->Write(); + } + }; } diff --git a/Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx b/Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx index 49acbbc2b3..596dbef867 100644 --- a/Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx @@ -57,9 +57,8 @@ private: // Doc example parameter settings SetDocExampleParameterValue( "io.vd", "vectorData.shp" ); - SetDocExampleParameterValue( "io.stats", "meanVar.xml" ); - SetDocExampleParameterValue( "io.out", "svmModel.svm" ); - SetDocExampleParameterValue( "feat", "perimeter area width" ); + SetDocExampleParameterValue( "io.out", "kmeansModel.txt" ); + SetDocExampleParameterValue( "feat", "perimeter width area" ); } diff --git a/Modules/Applications/AppClassification/include/otbTrainImagesBase.h b/Modules/Applications/AppClassification/include/otbTrainImagesBase.h index 5b6aca1460..4f5dd82d3c 100644 --- a/Modules/Applications/AppClassification/include/otbTrainImagesBase.h +++ b/Modules/Applications/AppClassification/include/otbTrainImagesBase.h @@ -17,18 +17,21 @@ #ifndef otbTrainImagesBase_h #define otbTrainImagesBase_h + +#include "otbVectorDataFileWriter.h" #include "otbWrapperCompositeApplication.h" #include "otbWrapperApplicationFactory.h" -#include "otbOGRDataToSamplePositionFilter.h" +#include "otbStatisticsXMLFileWriter.h" +#include "otbImageToEnvelopeVectorDataFilter.h" #include "otbSamplingRateCalculator.h" +#include "otbOGRDataToSamplePositionFilter.h" namespace otb { namespace Wrapper { -template<bool IsSupervised = true> class TrainImagesBase : public CompositeApplication { public: @@ -48,11 +51,32 @@ public: protected: -private: - struct SamplingRates; + enum SamplingStrategy + { + CLASS, GEOMETRIC + }; + struct SamplingRates; class TrainFileNamesHandler; + void InitIO() + { + //Group IO + AddParameter( ParameterType_Group, "io", "Input and output data" ); + SetParameterDescription( "io", "This group of parameters allows setting input and output data." ); + + AddParameter( ParameterType_InputImageList, "io.il", "Input Image List" ); + SetParameterDescription( "io.il", "A list of input images." ); + AddParameter( ParameterType_InputVectorDataList, "io.vd", "Input Vector Data List" ); + SetParameterDescription( "io.vd", "A list of vector data to select the training samples." ); + + AddParameter( ParameterType_Empty, "cleanup", "Temporary files cleaning" ); + EnableParameter( "cleanup" ); + SetParameterDescription( "cleanup", + "If activated, the application will try to clean all temporary files it created" ); + MandatoryOff( "cleanup" ); + } + void InitSampling() { AddApplication( "PolygonClassStatistics", "polystat", "Polygon analysis" ); @@ -131,6 +155,9 @@ private: SetParameterDescription( "io.valid", "A list of vector data to select the training samples." ); MandatoryOff( "io.valid" ); + if( !supervised ) + MandatoryOff( "io.vd" ); + ShareClassificationParams( supervised ); ConnectClassificationParams(); }; @@ -153,206 +180,31 @@ private: Connect( "select.rand", "training.rand" ); } - void DoUnsupervisedInit() - { - SetName( "TrainImagesClustering" ); - SetDescription( "Train a classifier from multiple pairs of images and training vector data." ); - - // Documentation - SetDocName( "Train a classifier from multiple images" ); - SetDocLongDescription( "TODO" ); - SetDocLimitations( "None" ); - SetDocAuthors( "OTB-Team" ); - SetDocSeeAlso( "OpenCV documentation for machine learning http://docs.opencv.org/modules/ml/doc/ml.html " ); - - AddDocTag( Tags::Learning ); - - ClearApplications(); - InitSampling(); - InitClassification( IsSupervised ); - - // Hide sampling parameters if sample.vnf is not provided - MandatoryOn( "sample.mv" ); - MandatoryOn( "sample.mt" ); - MandatoryOn( "sample.vtr" ); - - - // Doc example parameter settings - SetDocExampleParameterValue( "io.il", "QB_1_ortho.tif" ); - SetDocExampleParameterValue( "io.vd", "VectorData_QB1.shp" ); - SetDocExampleParameterValue( "io.imstat", "EstimateImageStatisticsQB1.xml" ); - SetDocExampleParameterValue( "sample.mv", "100" ); - SetDocExampleParameterValue( "sample.mt", "100" ); - SetDocExampleParameterValue( "sample.vtr", "0.5" ); - SetDocExampleParameterValue( "sample.vfn", "Class" ); - SetDocExampleParameterValue( "classifier", "sharkkm" ); - SetDocExampleParameterValue( "classifier.sharkkm.k", "2" ); - SetDocExampleParameterValue( "io.out", "sharkKMModelQB1.txt" ); - } - - void DoSupervisedInit() - { - SetName( "TrainImagesClassifier" ); - SetDescription( "Train a classifier from multiple pairs of images and training vector data." ); - - // Documentation - SetDocName( "Train a classifier from multiple images" ); - SetDocLongDescription( - "This application performs a classifier training from multiple pairs of input images and training vector data. " - "Samples are composed of pixel values in each band optionally centered and reduced using an XML statistics file produced by " - "the ComputeImagesStatistics application.\n The training vector data must contain polygons with a positive integer field " - "representing the class label. The name of this field can be set using the \"Class label field\" parameter. Training and validation " - "sample lists are built such that each class is equally represented in both lists. One parameter allows controlling the ratio " - "between the number of samples in training and validation sets. Two parameters allow managing the size of the training and " - "validation sets per class and per image.\n Several classifier parameters can be set depending on the chosen classifier. In the " - "validation process, the confusion matrix is organized the following way: rows = reference labels, columns = produced labels. " - "In the header of the optional confusion matrix output file, the validation (reference) and predicted (produced) class labels" - " are ordered according to the rows/columns of the confusion matrix.\n This application is based on LibSVM and OpenCV Machine Learning " - "(2.3.1 and later)." ); - SetDocLimitations( "None" ); - SetDocAuthors( "OTB-Team" ); - SetDocSeeAlso( "OpenCV documentation for machine learning http://docs.opencv.org/modules/ml/doc/ml.html " ); - - AddDocTag( Tags::Learning ); - - // Perform initialization - ClearApplications(); - InitSampling(); - InitClassification( IsSupervised ); - - // Doc example parameter settings - SetDocExampleParameterValue( "io.il", "QB_1_ortho.tif" ); - SetDocExampleParameterValue( "io.vd", "VectorData_QB1.shp" ); - SetDocExampleParameterValue( "io.imstat", "EstimateImageStatisticsQB1.xml" ); - SetDocExampleParameterValue( "sample.mv", "100" ); - SetDocExampleParameterValue( "sample.mt", "100" ); - SetDocExampleParameterValue( "sample.vtr", "0.5" ); - SetDocExampleParameterValue( "sample.vfn", "Class" ); - SetDocExampleParameterValue( "classifier", "sharkkm" ); - SetDocExampleParameterValue( "classifier.sharkkm.k", "2" ); - SetDocExampleParameterValue( "io.out", "sharkKMModelQB1.txt" ); - } - - void DoInit() ITK_OVERRIDE - { - //Group IO - AddParameter( ParameterType_Group, "io", "Input and output data" ); - SetParameterDescription( "io", "This group of parameters allows setting input and output data." ); - - AddParameter( ParameterType_InputImageList, "io.il", "Input Image List" ); - SetParameterDescription( "io.il", "A list of input images." ); - AddParameter( ParameterType_InputVectorDataList, "io.vd", "Input Vector Data List" ); - SetParameterDescription( "io.vd", "A list of vector data to select the training samples." ); - - AddParameter( ParameterType_Empty, "cleanup", "Temporary files cleaning" ); - EnableParameter( "cleanup" ); - SetParameterDescription( "cleanup", - "If activated, the application will try to clean all temporary files it created" ); - - if( IsSupervised ) - DoSupervisedInit(); - else - DoUnsupervisedInit(); - - MandatoryOff( "cleanup" ); - } - - void DoUpdateParameters() ITK_OVERRIDE - { - if( HasValue( "io.vd" ) ) - { - std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" ); - GetInternalApplication( "polystat" )->SetParameterString( "vec", vectorFileList[0], false ); - UpdateInternalParameters( "polystat" ); - } - } - - void DoExecute() ITK_OVERRIDE - { - FloatVectorImageListType *imageList = GetParameterImageList( "io.il" ); - std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" ); - unsigned long nbInputs = imageList->Size(); - - if( nbInputs > vectorFileList.size() ) - { - otbAppLogFATAL( "Missing input vector data files to match number of images (" << nbInputs << ")." ); - } - - // check if validation vectors are given - std::vector<std::string> validationVectorFileList; - bool dedicatedValidation = false; - if( IsParameterEnabled( "io.valid" ) && HasValue( "io.valid" ) ) - { - validationVectorFileList = GetParameterStringList( "io.valid" ); - if( nbInputs > validationVectorFileList.size() ) - { - otbAppLogFATAL( "Missing validation vector data files to match number of images (" << nbInputs << ")." ); - } - - if( !IsParameterEnabled( "sample.vnf" ) || !HasValue( "sample.vnf" ) ) - otbAppLogFATAL( "Missing class field name to use validation data." ); - - dedicatedValidation = true; - } - - TrainFileNamesHandler fileNames( GetParameterString( "io.out" ), nbInputs, dedicatedValidation ); - - if( !IsSupervised && IsParameterEnabled( "sample.vfn" ) && HasValue( "sample.vfn" ) ) - { - fileNames.sampleTrainOutputs = vectorFileList; - fileNames.sampleValidOutputs = validationVectorFileList; - TrainModel( fileNames, imageList ); - } - else - { - ComputePolygonStatistics( fileNames, imageList, dedicatedValidation, vectorFileList, validationVectorFileList ); - SamplingRates rates = ComputeSamplingRates( dedicatedValidation ); - SamplingRateForTrainingAndValidation( fileNames, rates, dedicatedValidation ); - SelectAndExtractSamples( fileNames, imageList, dedicatedValidation, vectorFileList, validationVectorFileList ); - TrainModel( fileNames, imageList ); - } - - - // cleanup - if( IsParameterEnabled( "cleanup" ) ) - { - otbAppLogINFO( <<"Final clean-up ..." ); - fileNames.clear(); - } - } - /** - * Compute polygon statistics given provided strategy - * \param fileNames - * \param imageList - * \param dedicatedValidation + * Compute polygon statistics given provided strategy with PolygonClassStatistics class + * \param imageList list of input images + * \param vectorFileNames list of input vector file names + * \param statisticsFileNames list of out */ - void ComputePolygonStatistics(TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList, - bool dedicatedValidation, std::vector<std::string> vectorFileList, - std::vector<std::string> validationVectorFileList) + void ComputePolygonStatistics(FloatVectorImageListType *imageList, const std::vector<std::string> &vectorFileNames, + const std::vector<std::string> &statisticsFileNames) { - for( unsigned int i = 0; i < imageList->Size(); i++ ) + unsigned int nbImages = static_cast<unsigned int>(imageList->Size()); + for( unsigned int i = 0; i < nbImages; i++ ) { GetInternalApplication( "polystat" )->SetParameterInputImage( "in", imageList->GetNthElement( i ) ); - GetInternalApplication( "polystat" )->SetParameterString( "vec", vectorFileList[i], false ); - GetInternalApplication( "polystat" )->SetParameterString( "out", fileNames.polyStatTrainOutputs[i], false ); + GetInternalApplication( "polystat" )->SetParameterString( "vec", vectorFileNames[i], false ); + GetInternalApplication( "polystat" )->SetParameterString( "out", statisticsFileNames[i], false ); ExecuteInternal( "polystat" ); - // analyse polygons given for validation - if( dedicatedValidation ) - { - GetInternalApplication( "polystat" )->SetParameterString( "vec", validationVectorFileList[i], false ); - GetInternalApplication( "polystat" )->SetParameterString( "out", fileNames.polyStatValidOutputs[i], false ); - ExecuteInternal( "polystat" ); - } } } /** - * Compute sampling rates + * Compute final maximum training and validation * \param dedicatedValidation * \return SamplingRates final maximum training and final maximum validation */ - SamplingRates ComputeSamplingRates(bool dedicatedValidation) + SamplingRates ComputeFinalMaximumSamplingRates(bool dedicatedValidation) { SamplingRates rates; GetInternalApplication( "rates" )->SetParameterString( "mim", "proportional", false ); @@ -401,29 +253,30 @@ private: return rates; } + /** - * Provide input/output images and strategy for the MultiImageSamplingRate rate application - * \param fileNames - * \param rates - * \param dedicatedValidation + * Compute rates using MultiImageSamplingRate application + * \param statisticsFileNames + * \param ratesFileName + * \param maximum final maximum value computed by ComputeFinalMaximumSamplingRates + * \sa ComputeFinalMaximumSamplingRates */ - void - SamplingRateForTrainingAndValidation(TrainFileNamesHandler &fileNames, SamplingRates rates, bool dedicatedValidation) + void ComputeSamplingRate(const std::vector<std::string> &statisticsFileNames, const std::string &ratesFileName, + long maximum) { - // Sampling rates for training - GetInternalApplication( "rates" )->SetParameterStringList( "il", fileNames.polyStatTrainOutputs, false ); - GetInternalApplication( "rates" )->SetParameterString( "out", fileNames.rateTrainOut, false ); + // Sampling rates + GetInternalApplication( "rates" )->SetParameterStringList( "il", statisticsFileNames, false ); + GetInternalApplication( "rates" )->SetParameterString( "out", ratesFileName, false ); if( GetParameterInt( "sample.bm" ) != 0 ) { GetInternalApplication( "rates" )->SetParameterString( "strategy", "smallest", false ); } else { - if( rates.fmt > -1 ) + if( maximum > -1 ) { GetInternalApplication( "rates" )->SetParameterString( "strategy", "constant", false ); - GetInternalApplication( "rates" )->SetParameterInt( "strategy.constant.nb", static_cast<int>(rates.fmt), - false ); + GetInternalApplication( "rates" )->SetParameterInt( "strategy.constant.nb", static_cast<int>(maximum), false ); } else { @@ -431,151 +284,172 @@ private: } } ExecuteInternal( "rates" ); - // Sampling rates for validation - if( dedicatedValidation ) + } + + /** + * Train the model with training and optional validation data samples + * \param imageList list of input images + * \param sampleTrainFileNames files names of the training samples + * \param sampleValidationFileNames file names of the validation sample + */ + void TrainModel(FloatVectorImageListType *imageList, const std::vector<std::string> &sampleTrainFileNames, + const std::vector<std::string> &sampleValidationFileNames) + { + GetInternalApplication( "training" )->SetParameterStringList( "io.vd", sampleTrainFileNames, false ); + if( !sampleValidationFileNames.empty() ) + GetInternalApplication( "training" )->SetParameterStringList( "valid.vd", sampleValidationFileNames, false ); + + UpdateInternalParameters( "training" ); + // set field names + FloatVectorImageType::Pointer image = imageList->GetNthElement( 0 ); + unsigned int nbBands = image->GetNumberOfComponentsPerPixel(); + std::vector<std::string> selectedNames; + for( unsigned int i = 0; i < nbBands; i++ ) { - GetInternalApplication( "rates" )->SetParameterStringList( "il", fileNames.polyStatValidOutputs, false ); - GetInternalApplication( "rates" )->SetParameterString( "out", fileNames.rateValidOut, false ); - if( GetParameterInt( "sample.bm" ) != 0 ) - { - GetInternalApplication( "rates" )->SetParameterString( "strategy", "smallest", false ); - } - else - { - if( rates.fmv > -1 ) - { - GetInternalApplication( "rates" )->SetParameterString( "strategy", "constant", false ); - GetInternalApplication( "rates" )->SetParameterInt( "strategy.constant.nb", static_cast<int>(rates.fmv) ); - } - else - { - GetInternalApplication( "rates" )->SetParameterString( "strategy", "all", false ); - } - } - ExecuteInternal( "rates" ); + std::ostringstream oss; + oss << i; + selectedNames.push_back( "value_" + oss.str() ); } + GetInternalApplication( "training" )->SetParameterStringList( "feat", selectedNames, false ); + ExecuteInternal( "training" ); } /** - * Configure and extract samples for the SampleExtraction application. - * \param fileNames - * \param imageList - * \param dedicatedValidation + * Select samples by class or by geographic strategy + * \param image + * \param vectorFileName + * \param sampleFileName + * \param statisticsFileName + * \param ratesFileName + * \param strategy */ - void SelectAndExtractSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList, - bool dedicatedValidation, const std::vector<std::string> &vectorFileList, - const std::vector<std::string> &validationVectorFileList) + void SelectAndExtractSamples(FloatVectorImageType *image, std::string vectorFileName, std::string sampleFileName, + std::string statisticsFileName, std::string ratesFileName, SamplingStrategy strategy) { - GetInternalApplication( "select" )->SetParameterString( "sampler", "periodic", false ); - GetInternalApplication( "select" )->SetParameterInt( "sampler.periodic.jitter", 50 ); - GetInternalApplication( "select" )->SetParameterString( "strategy", "byclass", false ); + GetInternalApplication( "select" )->SetParameterInputImage( "in", image ); + GetInternalApplication( "select" )->SetParameterString( "vec", vectorFileName, false ); + GetInternalApplication( "select" )->SetParameterString( "out", sampleFileName, false ); + GetInternalApplication( "extraction" )->SetParameterString( "outfield", "prefix", false ); GetInternalApplication( "extraction" )->SetParameterString( "outfield.prefix.name", "value_", false ); + + // Change the selection strategy based on selected sampling strategy + switch( strategy ) + { + case GEOMETRIC: + GetInternalApplication( "select" )->SetParameterString( "sampler", "random", false ); + GetInternalApplication( "select" )->SetParameterString( "strategy", "all", false ); + break; + case CLASS: + default: + GetInternalApplication( "select" )->SetParameterString( "instats", statisticsFileName, false ); + GetInternalApplication( "select" )->SetParameterString( "sampler", "periodic", false ); + GetInternalApplication( "select" )->SetParameterInt( "sampler.periodic.jitter", 50 ); + GetInternalApplication( "select" )->SetParameterString( "strategy", "byclass", false ); + GetInternalApplication( "select" )->SetParameterString( "strategy.byclass.in", ratesFileName, false ); + break; + } + + // select sample positions + ExecuteInternal( "select" ); + // extract sample descriptors + ExecuteInternal( "extraction" ); + } + + /** + * Select and extract samples with the SampleSelection and SampleExtraction application. + */ + void SelectAndExtractTrainSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList, + std::vector<std::string> vectorFileNames, SamplingStrategy strategy) + { + for( unsigned int i = 0; i < imageList->Size(); ++i ) { - GetInternalApplication( "select" )->SetParameterInputImage( "in", imageList->GetNthElement( i ) ); - GetInternalApplication( "select" )->SetParameterString( "vec", vectorFileList[i], false ); - GetInternalApplication( "select" )->SetParameterString( "out", fileNames.sampleOutputs[i], false ); - GetInternalApplication( "select" )->SetParameterString( "instats", fileNames.polyStatTrainOutputs[i], false ); - GetInternalApplication( "select" )->SetParameterString( "strategy.byclass.in", fileNames.ratesTrainOutputs[i], - false ); - // select sample positions - ExecuteInternal( "select" ); - // extract sample descriptors - ExecuteInternal( "extraction" ); + SelectAndExtractSamples( imageList->GetNthElement( i ), vectorFileNames[i], fileNames.sampleOutputs[i], + fileNames.polyStatTrainOutputs[i], fileNames.ratesTrainOutputs[i], strategy ); + } + } - if( dedicatedValidation ) + + void SelectAndExtractValidationSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList, + const std::vector<std::string> &validationVectorFileList, + bool dedicatedValidation) + { + // In dedicated validation mode the by class sampling strategy and statistics are used. + // Otherwise simply split training to validation samples corresponding to sample.vtr percentage. + if( dedicatedValidation ) + { + for( unsigned int i = 0; i < imageList->Size(); ++i ) { - GetInternalApplication( "select" )->SetParameterString( "vec", validationVectorFileList[i], false ); - GetInternalApplication( "select" )->SetParameterString( "out", fileNames.sampleValidOutputs[i], false ); - GetInternalApplication( "select" )->SetParameterString( "instats", fileNames.polyStatValidOutputs[i], false ); - GetInternalApplication( "select" )->SetParameterString( "strategy.byclass.in", fileNames.ratesValidOutputs[i], - false ); - // select sample positions - ExecuteInternal( "select" ); - // extract sample descriptors - ExecuteInternal( "extraction" ); + SelectAndExtractSamples( imageList->GetNthElement( i ), validationVectorFileList[i], + fileNames.sampleValidOutputs[i], fileNames.polyStatValidOutputs[i], + fileNames.ratesValidOutputs[i], SamplingStrategy::CLASS ); } - else + } + else + { + for( unsigned int i = 0; i < imageList->Size(); ++i ) { - // Split between training and validation - ogr::DataSource::Pointer source = ogr::DataSource::New( fileNames.sampleOutputs[i], - ogr::DataSource::Modes::Read ); - ogr::DataSource::Pointer destTrain = ogr::DataSource::New( fileNames.sampleTrainOutputs[i], - ogr::DataSource::Modes::Overwrite ); - ogr::DataSource::Pointer destValid = ogr::DataSource::New( fileNames.sampleValidOutputs[i], - ogr::DataSource::Modes::Overwrite ); - // read sampling rates from ratesTrainOutputs[i] - SamplingRateCalculator::Pointer rateCalculator = SamplingRateCalculator::New(); - rateCalculator->Read( fileNames.ratesTrainOutputs[i] ); - // Compute sampling rates for train and valid - const MapRateType &inputRates = rateCalculator->GetRatesByClass(); - MapRateType trainRates; - MapRateType validRates; - otb::SamplingRateCalculator::TripletType tpt; - for( MapRateType::const_iterator it = inputRates.begin(); it != inputRates.end(); ++it ) - { - double vtr = GetParameterFloat( "sample.vtr" ); - unsigned long total = std::min( it->second.Required, it->second.Tot ); - unsigned long neededValid = static_cast<unsigned long>(( double ) total * vtr ); - unsigned long neededTrain = total - neededValid; - tpt.Tot = total; - tpt.Required = neededTrain; - tpt.Rate = ( 1.0 - vtr ); - trainRates[it->first] = tpt; - tpt.Tot = neededValid; - tpt.Required = neededValid; - tpt.Rate = 1.0; - validRates[it->first] = tpt; - } - - // Use an otb::OGRDataToSamplePositionFilter with 2 outputs - PeriodicSamplerType::SamplerParameterType param; - param.Offset = 0; - param.MaxJitter = 0; - PeriodicSamplerType::Pointer splitter = PeriodicSamplerType::New(); - splitter->SetInput( imageList->GetNthElement( i ) ); - splitter->SetOGRData( source ); - splitter->SetOutputPositionContainerAndRates( destTrain, trainRates, 0 ); - splitter->SetOutputPositionContainerAndRates( destValid, validRates, 1 ); - splitter->SetFieldName( this->GetParameterStringList( "sample.vfn" )[0] ); - splitter->SetLayerIndex( 0 ); - splitter->SetOriginFieldName( std::string( "" ) ); - splitter->SetSamplerParameters( param ); - splitter->GetStreamer()->SetAutomaticTiledStreaming( - static_cast<unsigned int>(this->GetParameterInt( "ram" )) ); - AddProcess( splitter->GetStreamer(), "Split samples between training and validation..." ); - splitter->Update(); + SplitTrainingAndValidationSamples( imageList->GetNthElement( i ), fileNames.sampleOutputs[i], + fileNames.sampleTrainOutputs[i], fileNames.sampleValidOutputs[i], + fileNames.ratesTrainOutputs[i] ); } } } - /** - * Train the model with training and validation data samples - * \param fileNames files names used for filters - * \param imageList list of input images - */ - void TrainModel(TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList) +private: + void SplitTrainingAndValidationSamples(FloatVectorImageType *image, std::string sampleFileName, + std::string sampleTrainFileName, std::string sampleValidFileName, + std::string ratesTrainFileName) { - GetInternalApplication( "training" )->SetParameterStringList( "io.vd", fileNames.sampleTrainOutputs, false ); - GetInternalApplication( "training" )->SetParameterStringList( "valid.vd", fileNames.sampleValidOutputs, false ); - UpdateInternalParameters( "training" ); - // set field names - FloatVectorImageType::Pointer image = imageList->GetNthElement( 0 ); - unsigned int nbBands = image->GetNumberOfComponentsPerPixel(); - std::vector<std::string> selectedNames; - for( unsigned int i = 0; i < nbBands; i++ ) + // Split between training and validation + ogr::DataSource::Pointer source = ogr::DataSource::New( sampleFileName, ogr::DataSource::Modes::Read ); + ogr::DataSource::Pointer destTrain = ogr::DataSource::New( sampleTrainFileName, ogr::DataSource::Modes::Overwrite ); + ogr::DataSource::Pointer destValid = ogr::DataSource::New( sampleValidFileName, ogr::DataSource::Modes::Overwrite ); + // read sampling rates from ratesTrainOutputs + SamplingRateCalculator::Pointer rateCalculator = SamplingRateCalculator::New(); + rateCalculator->Read( ratesTrainFileName ); + // Compute sampling rates for train and valid + const MapRateType &inputRates = rateCalculator->GetRatesByClass(); + MapRateType trainRates; + MapRateType validRates; + otb::SamplingRateCalculator::TripletType tpt; + for( MapRateType::const_iterator it = inputRates.begin(); it != inputRates.end(); ++it ) { - std::ostringstream oss; - oss << i; - selectedNames.push_back( "value_" + oss.str() ); + double vtr = GetParameterFloat( "sample.vtr" ); + unsigned long total = std::min( it->second.Required, it->second.Tot ); + unsigned long neededValid = static_cast<unsigned long>(( double ) total * vtr ); + unsigned long neededTrain = total - neededValid; + tpt.Tot = total; + tpt.Required = neededTrain; + tpt.Rate = ( 1.0 - vtr ); + trainRates[it->first] = tpt; + tpt.Tot = neededValid; + tpt.Required = neededValid; + tpt.Rate = 1.0; + validRates[it->first] = tpt; } - GetInternalApplication( "training" )->SetParameterStringList( "feat", selectedNames, false ); - ExecuteInternal( "training" ); + + // Use an otb::OGRDataToSamplePositionFilter with 2 outputs + PeriodicSamplerType::SamplerParameterType param; + param.Offset = 0; + param.MaxJitter = 0; + PeriodicSamplerType::Pointer splitter = PeriodicSamplerType::New(); + splitter->SetInput( image ); + splitter->SetOGRData( source ); + splitter->SetOutputPositionContainerAndRates( destTrain, trainRates, 0 ); + splitter->SetOutputPositionContainerAndRates( destValid, validRates, 1 ); + splitter->SetFieldName( this->GetParameterStringList( "sample.vfn" )[0] ); + splitter->SetLayerIndex( 0 ); + splitter->SetOriginFieldName( std::string( "" ) ); + splitter->SetSamplerParameters( param ); + splitter->GetStreamer()->SetAutomaticTiledStreaming( static_cast<unsigned int>(this->GetParameterInt( "ram" )) ); + AddProcess( splitter->GetStreamer(), "Split samples between training and validation..." ); + splitter->Update(); } -private: +protected: struct SamplingRates { @@ -591,7 +465,7 @@ private: class TrainFileNamesHandler { public : - TrainFileNamesHandler(std::string outModel, size_t nbInputs, bool dedicatedValidation) + void CreateTemporaryFileNames(std::string outModel, size_t nbInputs, bool dedicatedValidation) { if( dedicatedValidation ) @@ -645,6 +519,8 @@ private: RemoveFile( sampleTrainOutputs[i] ); for( unsigned int i = 0; i < sampleValidOutputs.size(); i++ ) RemoveFile( sampleValidOutputs[i] ); + for( unsigned int i = 0; i < tmpVectorFileList.size(); i++ ) + RemoveFile( tmpVectorFileList[i] ); } public: @@ -655,6 +531,7 @@ private: std::vector<std::string> sampleOutputs; std::vector<std::string> sampleTrainOutputs; std::vector<std::string> sampleValidOutputs; + std::vector<std::string> tmpVectorFileList; std::string rateValidOut; std::string rateTrainOut; -- GitLab