From 078b7a5a0b4c25512c9d993251b558d1b1d2e54f Mon Sep 17 00:00:00 2001 From: Ludovic Hussonnois <ludovic.hussonnois@c-s.fr> Date: Thu, 23 Feb 2017 16:55:15 +0100 Subject: [PATCH] ENH: Update TrainImagesClustering and Classifier to use SampleSelection. SampleSelection can now be used without input vector data. --- .../app/otbTrainImagesClassifier.cxx | 2 +- .../app/otbTrainImagesClustering.cxx | 93 +++++-------------- .../include/otbTrainImagesBase.h | 13 +-- 3 files changed, 33 insertions(+), 75 deletions(-) diff --git a/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx b/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx index 3ed942dbc3..bb0f9176b2 100644 --- a/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx @@ -115,7 +115,7 @@ public: ComputePolygonStatistics(imageList, validationVectorFileList, fileNames.polyStatValidOutputs); ComputeSamplingRate(fileNames.polyStatValidOutputs, fileNames.rateValidOut, rates.fmv); } - SelectAndExtractValidationSamples(fileNames, imageList, validationVectorFileList, dedicatedValidation); + SelectAndExtractValidationSamples(fileNames, imageList, validationVectorFileList); // Then train the model with extracted samples diff --git a/Modules/Applications/AppClassification/app/otbTrainImagesClustering.cxx b/Modules/Applications/AppClassification/app/otbTrainImagesClustering.cxx index fdabcd1b08..e4ffd3b5db 100644 --- a/Modules/Applications/AppClassification/app/otbTrainImagesClustering.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainImagesClustering.cxx @@ -34,6 +34,14 @@ public: InitSampling(); InitClassification( false ); + AddParameter( ParameterType_Float, "sample.percent", "Percentage of samples extract in images for " + "training and validation when only images are provided." ); + SetParameterDescription( "sample.percent", "Percentage of samples extract in images for " + "training and validation when only images are provided. This parameter is disable when vector data are provided" ); + SetDefaultParameterFloat( "sample.percent", 100.0 ); + SetMinimumParameterFloatValue( "sample.percent", 0.0 ); + SetMaximumParameterFloatValue( "sample.percent", 100.0 ); + // Doc example parameter settings SetDocExampleParameterValue( "io.il", "QB_1_ortho.tif" ); SetDocExampleParameterValue( "io.vd", "VectorData_QB1.shp" ); @@ -51,8 +59,13 @@ public: { if( HasValue( "io.vd" ) ) { + MandatoryOff( "sample.percent" ); UpdatePolygonClassStatisticsParameters(); } + else + { + MandatoryOn( "sample.percent" ); + } } void DoExecute() ITK_OVERRIDE @@ -60,12 +73,12 @@ public: TrainFileNamesHandler fileNames; FloatVectorImageListType *imageList = GetParameterImageList( "io.il" ); bool HasInputVector = IsParameterEnabled( "io.vd" ) && HasValue( "io.vd" ); - std::vector<std::string> vectorFileList = GetVectorFileList( GetParameterString( "io.out" ), fileNames ); + std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" ); unsigned long nbInputs = imageList->Size(); - if( nbInputs > vectorFileList.size() ) + if( !vectorFileList.empty() && nbInputs > vectorFileList.size() ) { otbAppLogFATAL( "Missing input vector data files to match number of images (" << nbInputs << ")." ); } @@ -90,28 +103,29 @@ public: SamplingRates rates = ComputeFinalMaximumSamplingRates( dedicatedValidation ); if( HasInputVector ) - { + { // Select and Extract samples for training with computed statistics and rates ComputePolygonStatistics( imageList, vectorFileList, fileNames.polyStatTrainOutputs ); ComputeSamplingRate( fileNames.polyStatTrainOutputs, fileNames.rateTrainOut, rates.fmt ); SelectAndExtractTrainSamples( fileNames, imageList, vectorFileList, SamplingStrategy::CLASS ); - } + } else - { + { SelectAndExtractTrainSamples( fileNames, imageList, vectorFileList, SamplingStrategy::GEOMETRIC ); - } + } // Select and Extract samples for validation with computed statistics and rates // Validation samples could be empty if sample.vrt == 0 and if no dedicated validation are provided - if( dedicatedValidation ) { - ComputePolygonStatistics(imageList, validationVectorFileList, fileNames.polyStatValidOutputs); - ComputeSamplingRate(fileNames.polyStatValidOutputs, fileNames.rateValidOut, rates.fmv); + if( dedicatedValidation ) + { + ComputePolygonStatistics( imageList, validationVectorFileList, fileNames.polyStatValidOutputs ); + ComputeSamplingRate( fileNames.polyStatValidOutputs, fileNames.rateValidOut, rates.fmv ); } - SelectAndExtractValidationSamples(fileNames, imageList, validationVectorFileList, dedicatedValidation); + SelectAndExtractValidationSamples( fileNames, imageList, validationVectorFileList ); // Then train the model with extracted samples - TrainModel( imageList, fileNames.sampleTrainOutputs, fileNames.sampleValidOutputs); + TrainModel( imageList, fileNames.sampleTrainOutputs, fileNames.sampleValidOutputs ); // cleanup if( IsParameterEnabled( "cleanup" ) ) @@ -130,63 +144,6 @@ private : UpdateInternalParameters( "polystat" ); } - - /** - * Retrieve input vector data if provided otherwise generate a default vector shape file for each image. - * \param output vector file path - * \param fileNames - * \return list of input vector data file names - */ - std::vector<std::string> GetVectorFileList(std::string output, TrainFileNamesHandler &fileNames) - { - std::vector<std::string> vectorFileList; - bool HasInputVector = IsParameterEnabled( "io.vd" ) && HasValue( "io.vd" ); - - // Retrieve provided input vector data if available. - if( !HasInputVector ) - { - FloatVectorImageListType *imageList = GetParameterImageList( "io.il" ); - unsigned int nbInputs = static_cast<unsigned int>(imageList->Size()); - - for( unsigned int i = 0; i < nbInputs; ++i ) - { - std::string name = output + "_vector_" + std::to_string( i ) + ".shp"; - GenerateVectorDataFile( imageList->GetNthElement( i ), name ); - fileNames.tmpVectorFileList.push_back( name ); - } - vectorFileList = fileNames.tmpVectorFileList; - SetParameterStringList( "io.vd", vectorFileList, false ); - UpdatePolygonClassStatisticsParameters(); - GetInternalApplication( "polystat" )->SetParameterString( "field", "fid" ); - } - else - { - vectorFileList = GetParameterStringList( "io.vd" ); - } - - return vectorFileList; - } - - - - void GenerateVectorDataFile(const FloatVectorImageListType::ObjectPointerType &floatVectorImage, std::string name) - { - typedef otb::ImageToEnvelopeVectorDataFilter<FloatVectorImageType, VectorDataType> ImageToEnvelopeFilterType; - typedef ImageToEnvelopeFilterType::OutputVectorDataType OutputVectorData; - typedef otb::VectorDataFileWriter<OutputVectorData> VectorDataWriter; - - ImageToEnvelopeFilterType::Pointer imageToEnvelopeVectorData = ImageToEnvelopeFilterType::New(); - imageToEnvelopeVectorData->SetInput( floatVectorImage ); - imageToEnvelopeVectorData->SetOutputProjectionRef( floatVectorImage->GetProjectionRef().c_str() ); - OutputVectorData::Pointer vectorData = imageToEnvelopeVectorData->GetOutput(); - - // write temporary generated vector file to disk. - VectorDataWriter::Pointer vectorDataFileWriter = VectorDataWriter::New(); - vectorDataFileWriter->SetInput( vectorData ); - vectorDataFileWriter->SetFileName( name.c_str() ); - vectorDataFileWriter->Write(); - } - }; } diff --git a/Modules/Applications/AppClassification/include/otbTrainImagesBase.h b/Modules/Applications/AppClassification/include/otbTrainImagesBase.h index 4f5dd82d3c..4e9a0cf220 100644 --- a/Modules/Applications/AppClassification/include/otbTrainImagesBase.h +++ b/Modules/Applications/AppClassification/include/otbTrainImagesBase.h @@ -327,7 +327,6 @@ protected: std::string statisticsFileName, std::string ratesFileName, SamplingStrategy strategy) { GetInternalApplication( "select" )->SetParameterInputImage( "in", image ); - GetInternalApplication( "select" )->SetParameterString( "vec", vectorFileName, false ); GetInternalApplication( "select" )->SetParameterString( "out", sampleFileName, false ); GetInternalApplication( "extraction" )->SetParameterString( "outfield", "prefix", false ); @@ -338,10 +337,12 @@ protected: { case GEOMETRIC: GetInternalApplication( "select" )->SetParameterString( "sampler", "random", false ); - GetInternalApplication( "select" )->SetParameterString( "strategy", "all", false ); + GetInternalApplication( "select" )->SetParameterString( "strategy", "percent", false ); + GetInternalApplication( "select" )->SetParameterFloat("strategy.percent.p", GetParameterFloat("sample.percent"), false); break; case CLASS: default: + GetInternalApplication( "select" )->SetParameterString( "vec", vectorFileName, false ); GetInternalApplication( "select" )->SetParameterString( "instats", statisticsFileName, false ); GetInternalApplication( "select" )->SetParameterString( "sampler", "periodic", false ); GetInternalApplication( "select" )->SetParameterInt( "sampler.periodic.jitter", 50 ); @@ -365,19 +366,19 @@ protected: for( unsigned int i = 0; i < imageList->Size(); ++i ) { - SelectAndExtractSamples( imageList->GetNthElement( i ), vectorFileNames[i], fileNames.sampleOutputs[i], + std::string vectorFileName = vectorFileNames.empty() ? "" : vectorFileNames[i]; + SelectAndExtractSamples( imageList->GetNthElement( i ), vectorFileName, fileNames.sampleOutputs[i], fileNames.polyStatTrainOutputs[i], fileNames.ratesTrainOutputs[i], strategy ); } } void SelectAndExtractValidationSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList, - const std::vector<std::string> &validationVectorFileList, - bool dedicatedValidation) + const std::vector<std::string> &validationVectorFileList) { // In dedicated validation mode the by class sampling strategy and statistics are used. // Otherwise simply split training to validation samples corresponding to sample.vtr percentage. - if( dedicatedValidation ) + if( !validationVectorFileList.empty() ) { for( unsigned int i = 0; i < imageList->Size(); ++i ) { -- GitLab