diff --git a/CMake/OTBModuleHeaderTest.cmake b/CMake/OTBModuleHeaderTest.cmake index 6a1c72158a365e323c9b4f84529f03f8ccf38b3e..83d7479c3f92cef4c7a532d10c8e9c5cd6e2277b 100644 --- a/CMake/OTBModuleHeaderTest.cmake +++ b/CMake/OTBModuleHeaderTest.cmake @@ -26,7 +26,7 @@ if(NOT OTB_USE_OPENCV) SET(BANNED_HEADERS "${BANNED_HEADERS} otbDecisionTreeMachineLearningModelFactory.h otbDecisionTreeMachineLearningModel.h otbKNearestNeighborsMachineLearningModelFactory.h otbKNearestNeighborsMachineLearningModel.h otbRandomForestsMachineLearningModelFactory.h otbRandomForestsMachineLearningModel.h otbSVMMachineLearningModelFactory.h otbSVMMachineLearningModel.h otbGradientBoostedTreeMachineLearningModelFactory.h otbGradientBoostedTreeMachineLearningModel.h otbBoostMachineLearningModelFactory.h otbBoostMachineLearningModel.h otbNeuralNetworkMachineLearningModelFactory.h otbNeuralNetworkMachineLearningModel.h otbNormalBayesMachineLearningModelFactory.h otbNormalBayesMachineLearningModel.h otbRequiresOpenCVCheck.h otbOpenCVUtils.h otbCvRTreesWrapper.h") endif() if(NOT OTB_USE_SHARK) - SET(BANNED_HEADERS "${BANNED_HEADERS} otbSharkRandomForestsMachineLearningModel.h otbSharkRandomForestsMachineLearningModel.txx otbSharkUtils.h otbRequiresSharkCheck.h otbSharkRandomForestsMachineLearningModelFactory.h") + SET(BANNED_HEADERS "${BANNED_HEADERS} otbSharkRandomForestsMachineLearningModel.h otbSharkRandomForestsMachineLearningModel.txx otbSharkUtils.h otbRequiresSharkCheck.h otbSharkRandomForestsMachineLearningModelFactory.h otbSharkKMeansMachineLearningModel.h otbSharkKMeansMachineLearningModel.txx otbSharkKMeansMachineLearningModelFactory.h otbSharkKMeansMachineLearningModelFactory.txx") endif() if(NOT OTB_USE_LIBSVM) SET(BANNED_HEADERS "${BANNED_HEADERS} otbLibSVMMachineLearningModel.h otbLibSVMMachineLearningModelFactory.h") @@ -44,7 +44,7 @@ endif() macro( otb_module_headertest _name ) - if( NOT ${_name}_THIRD_PARTY + if( NOT ${_name}_THIRD_PARTY AND EXISTS 
${${_name}_SOURCE_DIR}/include AND PYTHON_EXECUTABLE AND NOT (PYTHON_VERSION_STRING VERSION_LESS 2.6) diff --git a/Modules/Applications/AppClassification/app/CMakeLists.txt b/Modules/Applications/AppClassification/app/CMakeLists.txt index 7da4e25da879e9f8a8f00b806fbd22f789e6f476..79cc0eb6ec558e8b1e69ba007ee300842f95e5dc 100644 --- a/Modules/Applications/AppClassification/app/CMakeLists.txt +++ b/Modules/Applications/AppClassification/app/CMakeLists.txt @@ -50,11 +50,6 @@ otb_create_application( SOURCES otbTrainVectorClassifier.cxx LINK_LIBRARIES ${${otb-module}_LIBRARIES}) -otb_create_application( - NAME TrainVectorClustering - SOURCES otbTrainVectorClustering.cxx - LINK_LIBRARIES ${${otb-module}_LIBRARIES}) - otb_create_application( NAME ComputeConfusionMatrix SOURCES otbComputeConfusionMatrix.cxx @@ -80,11 +75,6 @@ otb_create_application( SOURCES otbTrainImagesClassifier.cxx LINK_LIBRARIES ${${otb-module}_LIBRARIES}) -otb_create_application( - NAME TrainImagesClustering - SOURCES otbTrainImagesClustering.cxx - LINK_LIBRARIES ${${otb-module}_LIBRARIES}) - otb_create_application( NAME TrainRegression SOURCES otbTrainRegression.cxx diff --git a/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx b/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx index bb0f9176b29789139d9007cd63cd0ad27db8bab5..3c1bada305be2ec720d0ee0f32e49d571213fb18 100644 --- a/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx @@ -44,43 +44,133 @@ public: ClearApplications(); InitIO(); InitSampling(); - InitClassification( true ); + InitClassification(); + AddDocTag( Tags::Learning ); // Doc example parameter settings - SetDocExampleParameterValue("io.il", "QB_1_ortho.tif"); - SetDocExampleParameterValue("io.vd", "VectorData_QB1.shp"); - SetDocExampleParameterValue("io.imstat", "EstimateImageStatisticsQB1.xml"); - SetDocExampleParameterValue("sample.mv", 
"100"); - SetDocExampleParameterValue("sample.mt", "100"); - SetDocExampleParameterValue("sample.vtr", "0.5"); - SetDocExampleParameterValue("sample.vfn", "Class"); - SetDocExampleParameterValue("classifier", "libsvm"); - SetDocExampleParameterValue("classifier.libsvm.k", "linear"); - SetDocExampleParameterValue("classifier.libsvm.c", "1"); - SetDocExampleParameterValue("classifier.libsvm.opt", "false"); - SetDocExampleParameterValue("io.out", "svmModelQB1.txt"); - SetDocExampleParameterValue("io.confmatout", "svmConfusionMatrixQB1.csv"); + SetDocExampleParameterValue( "io.il", "QB_1_ortho.tif" ); + SetDocExampleParameterValue( "io.vd", "VectorData_QB1.shp" ); + SetDocExampleParameterValue( "io.imstat", "EstimateImageStatisticsQB1.xml" ); + SetDocExampleParameterValue( "sample.mv", "100" ); + SetDocExampleParameterValue( "sample.mt", "100" ); + SetDocExampleParameterValue( "sample.vtr", "0.5" ); + SetDocExampleParameterValue( "sample.vfn", "Class" ); + SetDocExampleParameterValue( "classifier", "libsvm" ); + SetDocExampleParameterValue( "classifier.libsvm.k", "linear" ); + SetDocExampleParameterValue( "classifier.libsvm.c", "1" ); + SetDocExampleParameterValue( "classifier.libsvm.opt", "false" ); + SetDocExampleParameterValue( "io.out", "svmModelQB1.txt" ); + SetDocExampleParameterValue( "io.confmatout", "svmConfusionMatrixQB1.csv" ); } void DoUpdateParameters() ITK_OVERRIDE { - if( HasValue( "io.vd" ) ) + if( HasValue( "io.vd" ) && IsParameterEnabled( "io.vd" )) + { + UpdatePolygonClassStatisticsParameters(); + } + + + // Change mandatory of input vector depending on supervised and unsupervised mode. 
+ if( HasValue( "classifier" ) ) + { + UpdateInternalParameters( "training" ); + switch( trainVectorBase->GetClassifierCategory() ) + { + case TrainVectorBase::Unsupervised: + MandatoryOff( "io.vd" ); + break; + default: + case TrainVectorBase::Supervised: + MandatoryOn( "io.vd" ); + break; + } + } + + } + + /** + * Select and Extract samples for validation with computed statistics and rates. + * Validation samples could be empty if sample.vtr == 0 and if no dedicated validation data are provided. + * If no dedicated validation is provided the training is split corresponding to the sample.vtr parameter, + * in this case if no vector data have been provided, the training rates and statistics are computed + * on the selection and extraction training result. + * fileNames.sampleOutputs contains training data and after an ExtractValidationData training data will + * be split to fileNames.sampleTrainOutputs. + * \param imageList + * \param fileNames + * \param validationVectorFileList + * \param rates + * \param HasInputVector + */ + void ExtractValidationData(FloatVectorImageListType *imageList, TrainFileNamesHandler& fileNames, + std::vector<std::string> validationVectorFileList, + const SamplingRates& rates, bool HasInputVector ) + { + if( !validationVectorFileList.empty() ) // Compute class statistics and sampling rate of validation data if provided.
+ { + ComputePolygonStatistics( imageList, validationVectorFileList, fileNames.polyStatValidOutputs ); + ComputeSamplingRate( fileNames.polyStatValidOutputs, fileNames.rateValidOut, rates.fmv ); + SelectAndExtractValidationSamples( fileNames, imageList, validationVectorFileList ); + if( HasInputVector ) // if input vector is provided the sampleTrainOutputs is the previously extracted sampleOutputs + fileNames.sampleTrainOutputs = fileNames.sampleOutputs; + } + else if(GetParameterFloat("sample.vtr") != 0.0)// Split training data to validation { - std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" ); - GetInternalApplication( "polystat" )->SetParameterString( "vec", vectorFileList[0], false ); - UpdateInternalParameters( "polystat" ); + if( !HasInputVector ) // Compute one class statistics and sampling rate for the generated vector. + ComputePolygonStatistics( imageList, fileNames.sampleOutputs, fileNames.polyStatTrainOutputs ); + ComputeSamplingRate( fileNames.polyStatTrainOutputs, fileNames.rateTrainOut, rates.fmt ); + SplitTrainingToValidationSamples( fileNames, imageList ); + } + else // nothing to do, except update fileNames + { + fileNames.sampleTrainOutputs = fileNames.sampleOutputs; } } - void DoExecute() ITK_OVERRIDE + /** + * Extract Training data depending if input vector is provided + * \param imageList list of the image + * \param fileNames handler that contain filenames + * \param vectorFileList input vector file list (if provided + * \param rates + */ + void ExtractTrainData(FloatVectorImageListType *imageList, const TrainFileNamesHandler& fileNames, + std::vector<std::string> vectorFileList, + const SamplingRates& rates) + { + if( !vectorFileList.empty() ) // Select and Extract samples for training with computed statistics and rates + { + ComputePolygonStatistics( imageList, vectorFileList, fileNames.polyStatTrainOutputs ); + ComputeSamplingRate( fileNames.polyStatTrainOutputs, fileNames.rateTrainOut, rates.fmt ); + 
SelectAndExtractTrainSamples( fileNames, imageList, vectorFileList, SamplingStrategy::CLASS ); + } + else // Select training samples base on geometric sampling if no input vector is provided + { + SelectAndExtractTrainSamples( fileNames, imageList, vectorFileList, SamplingStrategy::GEOMETRIC, "fid" ); + } + } + + + void DoExecute() { TrainFileNamesHandler fileNames; + std::vector<std::string> vectorFileList; FloatVectorImageListType *imageList = GetParameterImageList( "io.il" ); - std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" ); + bool HasInputVector = IsParameterEnabled( "io.vd" ) && HasValue( "io.vd" ); + if(HasInputVector) + vectorFileList = GetParameterStringList( "io.vd" ); + + unsigned long nbInputs = imageList->Size(); - if( nbInputs > vectorFileList.size() ) + if( !HasInputVector && trainVectorBase->GetClassifierCategory() == TrainVectorBase::Supervised ) + { + otbAppLogFATAL( "Missing input vector data files" ); + } + + if( !vectorFileList.empty() && nbInputs > vectorFileList.size() ) { otbAppLogFATAL( "Missing input vector data files to match number of images (" << nbInputs << ")." 
); } @@ -104,22 +194,11 @@ public: // Compute final maximum sampling rates for both training and validation samples SamplingRates rates = ComputeFinalMaximumSamplingRates( dedicatedValidation ); - // Select and Extract samples for training with computed statistics and rates - ComputePolygonStatistics(imageList, vectorFileList, fileNames.polyStatTrainOutputs); - ComputeSamplingRate(fileNames.polyStatTrainOutputs, fileNames.rateTrainOut, rates.fmt); - SelectAndExtractTrainSamples(fileNames, imageList, vectorFileList, SamplingStrategy::CLASS); - - // Select and Extract samples for validation with computed statistics and rates - // Validation samples could be empty if sample.vrt == 0 and if no dedicated validation are provided - if( dedicatedValidation ) { - ComputePolygonStatistics(imageList, validationVectorFileList, fileNames.polyStatValidOutputs); - ComputeSamplingRate(fileNames.polyStatValidOutputs, fileNames.rateValidOut, rates.fmv); - } - SelectAndExtractValidationSamples(fileNames, imageList, validationVectorFileList); - + ExtractTrainData(imageList, fileNames, vectorFileList, rates); + ExtractValidationData(imageList, fileNames, validationVectorFileList, rates, HasInputVector); // Then train the model with extracted samples - TrainModel( imageList, fileNames.sampleTrainOutputs, fileNames.sampleValidOutputs); + TrainModel( imageList, fileNames.sampleTrainOutputs, fileNames.sampleValidOutputs ); // cleanup if( IsParameterEnabled( "cleanup" ) ) @@ -129,6 +208,15 @@ public: } } +private : + + void UpdatePolygonClassStatisticsParameters() + { + std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" ); + GetInternalApplication( "polystat" )->SetParameterString( "vec", vectorFileList[0], false ); + UpdateInternalParameters( "polystat" ); + } + }; } diff --git a/Modules/Applications/AppClassification/app/otbTrainImagesClustering.cxx b/Modules/Applications/AppClassification/app/otbTrainImagesClustering.cxx deleted file mode 100644 index 
9d819e01ecab06ce4b268110bc6bb53edd974198..0000000000000000000000000000000000000000 --- a/Modules/Applications/AppClassification/app/otbTrainImagesClustering.cxx +++ /dev/null @@ -1,162 +0,0 @@ -#include "otbTrainImagesBase.h" - -namespace otb -{ -namespace Wrapper -{ - -class TrainImagesClustering : public TrainImagesBase -{ -public: - typedef TrainImagesClustering Self; - typedef TrainImagesBase Superclass; - typedef itk::SmartPointer<Self> Pointer; - typedef itk::SmartPointer<const Self> ConstPointer; - itkNewMacro( Self ) - itkTypeMacro( Self, Superclass ) - - void DoInit() ITK_OVERRIDE - { - SetName( "TrainImagesClustering" ); - SetDescription( "Train a classifier from multiple pairs of images and optional input training vector data." ); - - // Documentation - SetDocName( "Train a classifier from multiple images" ); - SetDocLongDescription( "TODO" ); - SetDocLimitations( "None" ); - SetDocAuthors( "OTB-Team" ); - SetDocSeeAlso( "OpenCV documentation for machine learning http://docs.opencv.org/modules/ml/doc/ml.html " ); - - AddDocTag( Tags::Learning ); - - ClearApplications(); - InitIO(); - InitSampling(); - InitClassification( false ); - - AddParameter( ParameterType_Float, "sample.percent", "Percentage of samples extract in images for " - "training and validation when only images are provided." ); - SetParameterDescription( "sample.percent", "Percentage of samples extract in images for " - "training and validation when only images are provided. 
This parameter is disable when vector data are provided" ); - SetDefaultParameterFloat( "sample.percent", 1.0 ); - SetMinimumParameterFloatValue( "sample.percent", 0.0 ); - SetMaximumParameterFloatValue( "sample.percent", 1.0 ); - - // Doc example parameter settings - SetDocExampleParameterValue( "io.il", "QB_1_ortho.tif" ); - SetDocExampleParameterValue( "io.vd", "VectorData_QB1.shp" ); - SetDocExampleParameterValue( "io.imstat", "EstimateImageStatisticsQB1.xml" ); - SetDocExampleParameterValue( "sample.mv", "100" ); - SetDocExampleParameterValue( "sample.mt", "100" ); - SetDocExampleParameterValue( "sample.vtr", "0.5" ); - SetDocExampleParameterValue( "sample.vfn", "Class" ); - SetDocExampleParameterValue( "classifier", "sharkkm" ); - SetDocExampleParameterValue( "classifier.sharkkm.k", "2" ); - SetDocExampleParameterValue( "io.out", "sharkKMModelQB1.txt" ); - } - - void DoUpdateParameters() ITK_OVERRIDE - { - if( HasValue( "io.vd" ) ) - { - MandatoryOff( "sample.percent" ); - UpdatePolygonClassStatisticsParameters(); - } - else - { - MandatoryOn( "sample.percent" ); - } - } - - void DoExecute() ITK_OVERRIDE - { - TrainFileNamesHandler fileNames; - FloatVectorImageListType *imageList = GetParameterImageList( "io.il" ); - bool HasInputVector = IsParameterEnabled( "io.vd" ) && HasValue( "io.vd" ); - std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" ); - - - unsigned long nbInputs = imageList->Size(); - - if( !vectorFileList.empty() && nbInputs > vectorFileList.size() ) - { - otbAppLogFATAL( "Missing input vector data files to match number of images (" << nbInputs << ")." 
); - } - - // check if validation vectors are given - std::vector<std::string> validationVectorFileList; - bool dedicatedValidation = false; - if( IsParameterEnabled( "io.valid" ) && HasValue( "io.valid" ) ) - { - validationVectorFileList = GetParameterStringList( "io.valid" ); - if( nbInputs > validationVectorFileList.size() ) - { - otbAppLogFATAL( "Missing validation vector data files to match number of images (" << nbInputs << ")." ); - } - - dedicatedValidation = true; - } - - fileNames.CreateTemporaryFileNames( GetParameterString( "io.out" ), nbInputs, dedicatedValidation ); - - // Compute final maximum sampling rates for both training and validation samples - SamplingRates rates = ComputeFinalMaximumSamplingRates( dedicatedValidation ); - - if( HasInputVector ) - { - // Select and Extract samples for training with computed statistics and rates - ComputePolygonStatistics( imageList, vectorFileList, fileNames.polyStatTrainOutputs ); - ComputeSamplingRate( fileNames.polyStatTrainOutputs, fileNames.rateTrainOut, rates.fmt ); - SelectAndExtractTrainSamples( fileNames, imageList, vectorFileList, SamplingStrategy::CLASS ); - } - else - { - SelectAndExtractTrainSamples( fileNames, imageList, vectorFileList, SamplingStrategy::GEOMETRIC, "fid" ); - } - - // Select and Extract samples for validation with computed statistics and rates. - // Validation samples could be empty if sample.vrt == 0 and if no dedicated validation are provided - // If no dedicated validation is provided the training is split corresponding to the sample.vtr parameter - // In this case if no vector data have been provided, the training rates and statistics are computed - // on the selection and extraction training result. 
- if( dedicatedValidation ) - { - ComputePolygonStatistics( imageList, validationVectorFileList, fileNames.polyStatValidOutputs ); - ComputeSamplingRate( fileNames.polyStatValidOutputs, fileNames.rateValidOut, rates.fmv ); - } - else if(!HasInputVector) - { - ComputePolygonStatistics( imageList, fileNames.sampleOutputs, fileNames.polyStatTrainOutputs ); - ComputeSamplingRate( fileNames.polyStatTrainOutputs, fileNames.rateTrainOut, rates.fmt ); - } - - - // Extract or split validation vector data. - SelectAndExtractValidationSamples( fileNames, imageList, validationVectorFileList ); - - // Then train the model with extracted samples - TrainModel( imageList, fileNames.sampleTrainOutputs, fileNames.sampleValidOutputs ); - - // cleanup - if( IsParameterEnabled( "cleanup" ) ) - { - otbAppLogINFO( <<"Final clean-up ..." ); - fileNames.clear(); - } - } - -private : - - void UpdatePolygonClassStatisticsParameters() - { - std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" ); - GetInternalApplication( "polystat" )->SetParameterString( "vec", vectorFileList[0], false ); - UpdateInternalParameters( "polystat" ); - } - -}; - -} -} - -OTB_APPLICATION_EXPORT( otb::Wrapper::TrainImagesClustering ) \ No newline at end of file diff --git a/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx b/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx index 9e26e768962bcc06761a0e6a92c49c91353026e0..6be11377009ec2085cf3f211d76696b4a75c578a 100644 --- a/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx @@ -45,37 +45,10 @@ public: typedef ConfusionMatrixCalculatorType::MapOfIndicesType MapOfIndicesType; typedef ConfusionMatrixCalculatorType::ClassLabelType ClassLabelType; -protected : - TrainVectorClassifier() : TrainVectorBase() - { - m_ClassifierCategory = Supervised; - } - private: void DoTrainInit() { - SetName( 
"TrainVectorClassifier" ); - SetDescription( "Train a classifier based on labeled geometries and a list of features to consider." ); - - SetDocName( "Train Vector Classifier" ); - SetDocLongDescription( "This application trains a classifier based on " - "labeled geometries and a list of features to consider for classification." ); - SetDocLimitations( " " ); - SetDocAuthors( "OTB Team" ); - SetDocSeeAlso( " " ); - - // Add a new parameter to compute confusion matrix - AddParameter( ParameterType_OutputFilename, "io.confmatout", "Output confusion matrix" ); - SetParameterDescription( "io.confmatout", "Output file containing the confusion matrix (.csv format)." ); - MandatoryOff( "io.confmatout" ); - - // Doc example parameter settings - SetDocExampleParameterValue( "io.vd", "vectorData.shp" ); - SetDocExampleParameterValue( "io.stats", "meanVar.xml" ); - SetDocExampleParameterValue( "io.out", "svmModel.svm" ); - SetDocExampleParameterValue( "feat", "perimeter area width" ); - SetDocExampleParameterValue( "cfield", "predicted" ); - + // Nothing to do here } void DoTrainUpdateParameters() @@ -86,46 +59,36 @@ private: void DoBeforeTrainExecute() { // Enforce the need of class field name in supervised mode - featuresInfo.SetClassFieldNames( GetChoiceNames( "cfield" ), GetSelectedItems( "cfield" ) ); - - if( featuresInfo.m_SelectedCFieldIdx.empty() && m_ClassifierCategory == Supervised ) + if (GetClassifierCategory() == Supervised) { - otbAppLogFATAL( << "No field has been selected for data labelling!" ); + featuresInfo.SetClassFieldNames( GetChoiceNames( "cfield" ), GetSelectedItems( "cfield" ) ); + + if( featuresInfo.m_SelectedCFieldIdx.empty() ) + { + otbAppLogFATAL( << "No field has been selected for data labelling!" 
); + } } } void DoAfterTrainExecute() { - ConfusionMatrixCalculatorType::Pointer confMatCalc = ComputeConfusionmatrix( predictedList, - classificationListSamples.labeledListSample ); - WriteConfusionMatrix( confMatCalc ); - } - - ListSamples ExtractClassificationListSamples(const StatisticsMeasurement &measurement) - { - ListSamples performanceSample; - ListSamples validationListSamples = ExtractListSamples( "valid.vd", "valid.layer", measurement ); - //Test the input validation set size - if( validationListSamples.labeledListSample->Size() != 0 ) + if (GetClassifierCategory() == Supervised) { - performanceSample.listSample = validationListSamples.listSample; - performanceSample.labeledListSample = validationListSamples.labeledListSample; + ConfusionMatrixCalculatorType::Pointer confMatCalc = ComputeConfusionMatrix( predictedList, + classificationListSamples.labeledListSample ); + WriteConfusionMatrix( confMatCalc ); } else { - otbAppLogWARNING( - "The validation set is empty. The performance estimation is done using the input training set in this case." 
); - performanceSample.listSample = trainingListSamples.listSample; - performanceSample.labeledListSample = trainingListSamples.labeledListSample; + // TODO Compute Contingency Table } - - return performanceSample; } + ConfusionMatrixCalculatorType::Pointer - ComputeConfusionmatrix(const TargetListSampleType::Pointer &predictedListSample, + ComputeConfusionMatrix(const TargetListSampleType::Pointer &predictedListSample, const TargetListSampleType::Pointer &performanceLabeledListSample) { ConfusionMatrixCalculatorType::Pointer confMatCalc = ConfusionMatrixCalculatorType::New(); diff --git a/Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx b/Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx deleted file mode 100644 index 596dbef867bdc828e8a7023f59dc7573139b8e4d..0000000000000000000000000000000000000000 --- a/Modules/Applications/AppClassification/app/otbTrainVectorClustering.cxx +++ /dev/null @@ -1,86 +0,0 @@ -/*========================================================================= - Program: ORFEO Toolbox - Language: C++ - Date: $Date$ - Version: $Revision$ - - - Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. - See OTBCopyright.txt for details. - - - This software is distributed WITHOUT ANY WARRANTY; without even - the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR - PURPOSE. See the above copyright notices for more information. 
- - =========================================================================*/ -#include "otbTrainVectorBase.h" - -namespace otb -{ -namespace Wrapper -{ - -class TrainVectorClustering : public TrainVectorBase -{ -public: - typedef TrainVectorClustering Self; - typedef TrainVectorBase Superclass; - typedef itk::SmartPointer<Self> Pointer; - typedef itk::SmartPointer<const Self> ConstPointer; - itkNewMacro( Self ) - - itkTypeMacro( Self, Superclass ) - - typedef Superclass::SampleType SampleType; - typedef Superclass::ListSampleType ListSampleType; - typedef Superclass::TargetListSampleType TargetListSampleType; - -protected : - TrainVectorClustering() : TrainVectorBase() - { - m_ClassifierCategory = Unsupervised; - } - -private: - void DoTrainInit() - { - SetName( "TrainVectorClustering" ); - SetDescription( "Train a classifier based on labeled or unlabeled geometries and a list of features to consider." ); - - SetDocName( "Train Vector Clustering" ); - SetDocLongDescription( "This application trains a classifier based on " - "labeled or unlabeled geometries and a list of features to consider for classification." 
); - SetDocLimitations( " " ); - SetDocAuthors( "OTB Team" ); - SetDocSeeAlso( " " ); - - // Doc example parameter settings - SetDocExampleParameterValue( "io.vd", "vectorData.shp" ); - SetDocExampleParameterValue( "io.out", "kmeansModel.txt" ); - SetDocExampleParameterValue( "feat", "perimeter width area" ); - - } - - void DoTrainUpdateParameters() - { - // Nothing to do here - } - - void DoBeforeTrainExecute() - { - // Nothing to do here - } - - void DoAfterTrainExecute() - { - // Nothing to do here - } - - - -}; -} -} - -OTB_APPLICATION_EXPORT( otb::Wrapper::TrainVectorClustering ) diff --git a/Modules/Applications/AppClassification/include/otbLearningApplicationBase.h b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.h index 9f8bbcdf7b005982f1ff754c07c4220d5565a053..f7a7c3e65441eb83ed3d0a9063aea7cc68d5a490 100644 --- a/Modules/Applications/AppClassification/include/otbLearningApplicationBase.h +++ b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.h @@ -142,7 +142,23 @@ public: typedef otb::SharkRandomForestsMachineLearningModel<InputValueType, OutputValueType> SharkRandomForestType; typedef otb::SharkKMeansMachineLearningModel<InputValueType, OutputValueType> SharkKMeansType; #endif - + + itkGetConstReferenceMacro(SupervisedClassifier, std::vector<std::string>); + itkGetConstReferenceMacro(UnsupervisedClassifier, std::vector<std::string>); + + + enum ClassifierCategory{ + Supervised, + Unsupervised + }; + + /** + * Retrieve the classifier category (supervised or unsupervised) + * based on the selected algorithm from the classifier choice.
+ * @return ClassifierCategory the classifier category + */ + ClassifierCategory GetClassifierCategory(); + protected: LearningApplicationBase(); @@ -162,28 +178,24 @@ protected: /** Init method that creates all the parameters for machine learning models */ void DoInit() ITK_OVERRIDE; + /** Parameter update method; this class provides an empty implementation */ + void DoUpdateParameters() ITK_OVERRIDE; + /** Flag to switch between classification and regression mode. * False by default, child classes may change it in their constructor */ bool m_RegressionFlag; - /** enum use to selected classifier category */ - enum ClassifierCategory { - Supervised, - Unsupervised - }; - - /** Enum to switch between unsupervised or supervised classification. - * Supervised by default, child classes may change it in their constructor */ - ClassifierCategory m_ClassifierCategory; private: /** Specific Init and Train methods for each machine learning model */ /** Init Parameters for Supervised Classifier */ void InitSupervisedClassifierParams(); + std::vector<std::string> m_SupervisedClassifier; /** Init Parameters for Unsupervised Classifier */ void InitUnsupervisedClassifierParams(); + std::vector<std::string> m_UnsupervisedClassifier; //@{ #ifdef OTB_USE_LIBSVM diff --git a/Modules/Applications/AppClassification/include/otbLearningApplicationBase.txx b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.txx index 1b75305a570de4a3efab5666c6131fe7e52563bf..adbe6db9ca2cdbc2a00b6f9a193e0f6f9c892d2f 100644 --- a/Modules/Applications/AppClassification/include/otbLearningApplicationBase.txx +++ b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.txx @@ -28,7 +28,8 @@ namespace Wrapper template <class TInputValue, class TOutputValue> LearningApplicationBase<TInputValue,TOutputValue> -::LearningApplicationBase() : m_RegressionFlag(false), m_ClassifierCategory(Supervised) +::LearningApplicationBase() : m_RegressionFlag(false) + { } @@ -50,17
+51,36 @@ LearningApplicationBase<TInputValue,TOutputValue> AddParameter(ParameterType_Choice, "classifier", "Classifier to use for the training"); SetParameterDescription("classifier", "Choice of the classifier to use for the training."); - switch(m_ClassifierCategory) - { - case Unsupervised: - InitUnsupervisedClassifierParams(); - break; - case Supervised: - default : - InitSupervisedClassifierParams(); - } + AddParameter(ParameterType_Choice, "category", "Type of classifier use for the training (supervised or unsupervised"); + SetParameterDescription("category", "Choice of the classifier type to use for the training, " + "choice is supervised or unsupervised."); + + InitSupervisedClassifierParams(); + m_SupervisedClassifier = GetChoiceKeys("classifier"); + + InitUnsupervisedClassifierParams(); + std::vector<std::string> allClassifier = GetChoiceKeys("classifier"); + m_UnsupervisedClassifier.assign(allClassifier.begin() + m_SupervisedClassifier.size(), allClassifier.end()); } +template <class TInputValue, class TOutputValue> +typename LearningApplicationBase<TInputValue,TOutputValue>::ClassifierCategory +LearningApplicationBase<TInputValue,TOutputValue> +::GetClassifierCategory() +{ + bool foundUnsupervised = + std::find(m_UnsupervisedClassifier.begin(), m_UnsupervisedClassifier.end(), + GetParameterString("classifier")) != m_UnsupervisedClassifier.end(); + return foundUnsupervised ? 
Unsupervised : Supervised; +} + +template <class TInputValue, class TOutputValue> +void +LearningApplicationBase<TInputValue,TOutputValue> +::DoUpdateParameters() +{ +}; + template <class TInputValue, class TOutputValue> void LearningApplicationBase<TInputValue,TOutputValue> diff --git a/Modules/Applications/AppClassification/include/otbTrainImagesBase.h b/Modules/Applications/AppClassification/include/otbTrainImagesBase.h index 0921696cbb615e6afec3dc01ca6c4b5b46c1a20f..be9ad1425ff83bb2f6c8a7996a95031971212d26 100644 --- a/Modules/Applications/AppClassification/include/otbTrainImagesBase.h +++ b/Modules/Applications/AppClassification/include/otbTrainImagesBase.h @@ -1,23 +1,26 @@ -/*========================================================================= - Program: ORFEO Toolbox - Language: C++ - Date: $Date$ - Version: $Revision$ - - - Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. - See OTBCopyright.txt for details. - - - This software is distributed WITHOUT ANY WARRANTY; without even - the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR - PURPOSE. See the above copyright notices for more information. - - =========================================================================*/ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ #ifndef otbTrainImagesBase_h #define otbTrainImagesBase_h - +#include "otbTrainVectorBase.h" #include "otbVectorDataFileWriter.h" #include "otbWrapperCompositeApplication.h" #include "otbWrapperApplicationFactory.h" @@ -32,6 +35,15 @@ namespace otb namespace Wrapper { +/** \class TrainImagesBase + * \brief Base class for the TrainImagesBaseClassifier and Clustering + * + * This class intends to hold common input/output parameters and + * composite application connection for both supervised and unsupervised + * model training. + * + * \ingroup OTBAppClassification + */ class TrainImagesBase : public CompositeApplication { public: @@ -55,131 +67,24 @@ protected: { CLASS, GEOMETRIC }; - struct SamplingRates; - class TrainFileNamesHandler; - void InitIO() - { - //Group IO - AddParameter( ParameterType_Group, "io", "Input and output data" ); - SetParameterDescription( "io", "This group of parameters allows setting input and output data." ); - - AddParameter( ParameterType_InputImageList, "io.il", "Input Image List" ); - SetParameterDescription( "io.il", "A list of input images." ); - AddParameter( ParameterType_InputVectorDataList, "io.vd", "Input Vector Data List" ); - SetParameterDescription( "io.vd", "A list of vector data to select the training samples." 
); - - AddParameter( ParameterType_Empty, "cleanup", "Temporary files cleaning" ); - EnableParameter( "cleanup" ); - SetParameterDescription( "cleanup", - "If activated, the application will try to clean all temporary files it created" ); - MandatoryOff( "cleanup" ); - } - - void InitSampling() - { - AddApplication( "PolygonClassStatistics", "polystat", "Polygon analysis" ); - AddApplication( "MultiImageSamplingRate", "rates", "Sampling rates" ); - AddApplication( "SampleSelection", "select", "Sample selection" ); - AddApplication( "SampleExtraction", "extraction", "Sample extraction" ); - - // Sampling settings - AddParameter( ParameterType_Group, "sample", "Training and validation samples parameters" ); - SetParameterDescription( "sample", - "This group of parameters allows you to set training and validation sample lists parameters." ); - AddParameter( ParameterType_Int, "sample.mt", "Maximum training sample size per class" ); - SetDefaultParameterInt( "sample.mt", 1000 ); - SetParameterDescription( "sample.mt", "Maximum size per class (in pixels) of " - "the training sample list (default = 1000) (no limit = -1). If equal to -1," - " then the maximal size of the available training sample list per class " - "will be equal to the surface area of the smallest class multiplied by the" - " training sample ratio." ); - AddParameter( ParameterType_Int, "sample.mv", "Maximum validation sample size per class" ); - SetDefaultParameterInt( "sample.mv", 1000 ); - SetParameterDescription( "sample.mv", "Maximum size per class (in pixels) of " - "the validation sample list (default = 1000) (no limit = -1). If equal to -1," - " then the maximal size of the available validation sample list per class " - "will be equal to the surface area of the smallest class multiplied by the " - "validation sample ratio." 
); - AddParameter( ParameterType_Int, "sample.bm", "Bound sample number by minimum" ); - SetDefaultParameterInt( "sample.bm", 1 ); - SetParameterDescription( "sample.bm", "Bound the number of samples for each " - "class by the number of available samples by the smaller class. Proportions " - "between training and validation are respected. Default is true (=1)." ); - AddParameter( ParameterType_Float, "sample.vtr", "Training and validation sample ratio" ); - SetParameterDescription( "sample.vtr", "Ratio between training and validation samples (0.0 = all training, 1.0 = " - "all validation) (default = 0.5)." ); - SetParameterFloat( "sample.vtr", 0.5, false ); - SetMaximumParameterFloatValue( "sample.vtr", 1.0 ); - SetMinimumParameterFloatValue( "sample.vtr", 0.0 ); - - ShareSamplingParameters(); - ConnectSamplingParameters(); - } - - void ShareSamplingParameters() - { - // hide sampling parameters - //ShareParameter("sample.strategy","rates.strategy"); - //ShareParameter("sample.mim","rates.mim"); - ShareParameter( "ram", "polystat.ram" ); - ShareParameter( "elev", "polystat.elev" ); - ShareParameter( "sample.vfn", "polystat.field" ); - } - - void ConnectSamplingParameters() - { - Connect( "extraction.field", "polystat.field" ); - Connect( "extraction.layer", "polystat.layer" ); - - Connect( "select.ram", "polystat.ram" ); - Connect( "extraction.ram", "polystat.ram" ); - - Connect( "select.field", "polystat.field" ); - Connect( "select.layer", "polystat.layer" ); - Connect( "select.elev", "polystat.elev" ); - - Connect( "extraction.in", "select.in" ); - Connect( "extraction.vec", "select.out" ); - } - - void InitClassification(bool supervised) - { - if( supervised ) - AddApplication( "TrainVectorClassifier", "training", "Model training" ); - else - AddApplication( "TrainVectorClustering", "training", "Model training" ); - - AddParameter( ParameterType_InputVectorDataList, "io.valid", "Validation Vector Data List" ); - SetParameterDescription( "io.valid", "A list of 
vector data to select the training samples." ); - MandatoryOff( "io.valid" ); - - if( !supervised ) - MandatoryOff( "io.vd" ); - - ShareClassificationParams( supervised ); - ConnectClassificationParams(); - }; - - void ShareClassificationParams(bool supervised) - { - ShareParameter( "io.imstat", "training.io.stats" ); - ShareParameter( "io.out", "training.io.out" ); - - ShareParameter( "classifier", "training.classifier" ); - ShareParameter( "rand", "training.rand" ); + /** + * Initialize all the input and output parameter used for the train images + */ + void InitIO(); - if( supervised ) - ShareParameter( "io.confmatout", "training.io.confmatout" ); - } + /** + * Initialize sampling related application and parameters + */ + void InitSampling(); - void ConnectClassificationParams() - { - Connect( "training.cfield", "polystat.field" ); - Connect( "select.rand", "training.rand" ); - } + void ShareSamplingParameters(); + void ConnectSamplingParameters(); + void InitClassification(); + void ShareClassificationParams(); + void ConnectClassificationParams(); /** * Compute polygon statistics given provided strategy with PolygonClassStatistics class @@ -188,71 +93,14 @@ protected: * \param statisticsFileNames list of out */ void ComputePolygonStatistics(FloatVectorImageListType *imageList, const std::vector<std::string> &vectorFileNames, - const std::vector<std::string> &statisticsFileNames) - { - unsigned int nbImages = static_cast<unsigned int>(imageList->Size()); - for( unsigned int i = 0; i < nbImages; i++ ) - { - GetInternalApplication( "polystat" )->SetParameterInputImage( "in", imageList->GetNthElement( i ) ); - GetInternalApplication( "polystat" )->SetParameterString( "vec", vectorFileNames[i], false ); - GetInternalApplication( "polystat" )->SetParameterString( "out", statisticsFileNames[i], false ); - ExecuteInternal( "polystat" ); - } - } + const std::vector<std::string> &statisticsFileNames); /** * Compute final maximum training and validation * \param 
dedicatedValidation * \return SamplingRates final maximum training and final maximum validation */ - SamplingRates ComputeFinalMaximumSamplingRates(bool dedicatedValidation) - { - SamplingRates rates; - GetInternalApplication( "rates" )->SetParameterString( "mim", "proportional", false ); - double vtr = GetParameterFloat( "sample.vtr" ); - long mt = GetParameterInt( "sample.mt" ); - long mv = GetParameterInt( "sample.mv" ); - // compute final maximum training and final maximum validation - // By default take all samples (-1 means all samples) - rates.fmt = -1; - rates.fmv = -1; - if( GetParameterInt( "sample.bm" ) == 0 ) - { - if( dedicatedValidation ) - { - // fmt and fmv will be used separately - rates.fmt = mt; - rates.fmv = mv; - if( mt > -1 && mv <= -1 && vtr < 0.99999 ) - { - rates.fmv = static_cast<long>(( double ) mt * vtr / ( 1.0 - vtr )); - } - if( mt <= -1 && mv > -1 && vtr > 0.00001 ) - { - rates.fmt = static_cast<long>(( double ) mv * ( 1.0 - vtr ) / vtr); - } - } - else - { - // only fmt will be used for both training and validation samples - // So we try to compute the total number of samples given input - // parameters mt, mv and vtr. 
- if( mt > -1 && mv > -1 ) - { - rates.fmt = mt + mv; - } - if( mt > -1 && mv <= -1 && vtr < 0.99999 ) - { - rates.fmt = static_cast<long>(( double ) mt / ( 1.0 - vtr )); - } - if( mt <= -1 && mv > -1 && vtr > 0.00001 ) - { - rates.fmt = static_cast<long>(( double ) mv / vtr); - } - } - } - return rates; - } + SamplingRates ComputeFinalMaximumSamplingRates(bool dedicatedValidation); /** @@ -262,31 +110,9 @@ protected: * \param maximum final maximum value computed by ComputeFinalMaximumSamplingRates * \sa ComputeFinalMaximumSamplingRates */ - void ComputeSamplingRate(const std::vector<std::string> &statisticsFileNames, const std::string &ratesFileName, - long maximum) - { - // Sampling rates - GetInternalApplication( "rates" )->SetParameterStringList( "il", statisticsFileNames, false ); - GetInternalApplication( "rates" )->SetParameterString( "out", ratesFileName, false ); - if( GetParameterInt( "sample.bm" ) != 0 ) - { - GetInternalApplication( "rates" )->SetParameterString( "strategy", "smallest", false ); - } - else - { - if( maximum > -1 ) - { - GetInternalApplication( "rates" )->SetParameterString( "strategy", "constant", false ); - GetInternalApplication( "rates" )->SetParameterInt( "strategy.constant.nb", static_cast<int>(maximum), false ); - } - else - { - GetInternalApplication( "rates" )->SetParameterString( "strategy", "all", false ); - } - } - ExecuteInternal( "rates" ); - } - + void ComputeSamplingRate(const std::vector<std::string> &statisticsFileNames, + const std::string &ratesFileName, + long maximum); /** * Train the model with training and optional validation data samples * \param imageList list of input images @@ -294,26 +120,7 @@ protected: * \param sampleValidationFileNames file names of the validation sample */ void TrainModel(FloatVectorImageListType *imageList, const std::vector<std::string> &sampleTrainFileNames, - const std::vector<std::string> &sampleValidationFileNames) - { - GetInternalApplication( "training" )->SetParameterStringList( 
"io.vd", sampleTrainFileNames, false ); - if( !sampleValidationFileNames.empty() ) - GetInternalApplication( "training" )->SetParameterStringList( "valid.vd", sampleValidationFileNames, false ); - - UpdateInternalParameters( "training" ); - // set field names - FloatVectorImageType::Pointer image = imageList->GetNthElement( 0 ); - unsigned int nbBands = image->GetNumberOfComponentsPerPixel(); - std::vector<std::string> selectedNames; - for( unsigned int i = 0; i < nbBands; i++ ) - { - std::ostringstream oss; - oss << i; - selectedNames.push_back( "value_" + oss.str() ); - } - GetInternalApplication( "training" )->SetParameterStringList( "feat", selectedNames, false ); - ExecuteInternal( "training" ); - } + const std::vector<std::string> &sampleValidationFileNames); /** * Select samples by class or by geographic strategy @@ -326,143 +133,60 @@ protected: */ void SelectAndExtractSamples(FloatVectorImageType *image, std::string vectorFileName, std::string sampleFileName, std::string statisticsFileName, std::string ratesFileName, SamplingStrategy strategy, - std::string selectedField = "") - { - GetInternalApplication( "select" )->SetParameterInputImage( "in", image ); - GetInternalApplication( "select" )->SetParameterString( "out", sampleFileName, false ); - - // Change the selection strategy based on selected sampling strategy - switch( strategy ) - { - case GEOMETRIC: - GetInternalApplication( "select" )->SetParameterString( "sampler", "random", false ); - GetInternalApplication( "select" )->SetParameterString( "strategy", "percent", false ); - GetInternalApplication( "select" )->SetParameterFloat( "strategy.percent.p", - GetParameterFloat( "sample.percent" ), false ); - break; - case CLASS: - default: - GetInternalApplication( "select" )->SetParameterString( "vec", vectorFileName, false ); - GetInternalApplication( "select" )->SetParameterString( "instats", statisticsFileName, false ); - GetInternalApplication( "select" )->SetParameterString( "sampler", "periodic", 
false ); - GetInternalApplication( "select" )->SetParameterInt( "sampler.periodic.jitter", 50 ); - GetInternalApplication( "select" )->SetParameterString( "strategy", "byclass", false ); - GetInternalApplication( "select" )->SetParameterString( "strategy.byclass.in", ratesFileName, false ); - break; - } - - // select sample positions - ExecuteInternal( "select" ); - - GetInternalApplication( "extraction" )->SetParameterString( "vec", sampleFileName, false ); - UpdateInternalParameters( "extraction" ); - if( !selectedField.empty() ) - GetInternalApplication( "extraction" )->SetParameterString( "field", selectedField, false ); - - GetInternalApplication( "extraction" )->SetParameterString( "outfield", "prefix", false ); - GetInternalApplication( "extraction" )->SetParameterString( "outfield.prefix.name", "value_", false ); - - // extract sample descriptors - ExecuteInternal( "extraction" ); - } - + std::string selectedField = ""); /** * Select and extract samples with the SampleSelection and SampleExtraction application. + * \param fileNames + * \param imageList + * \param vectorFileNames + * \param strategy the strategy used for selection (by class or with geometry) + * \param selectedFieldName */ void SelectAndExtractTrainSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList, std::vector<std::string> vectorFileNames, SamplingStrategy strategy, - std::string selectedFieldName = "") - { - - for( unsigned int i = 0; i < imageList->Size(); ++i ) - { - std::string vectorFileName = vectorFileNames.empty() ? "" : vectorFileNames[i]; - SelectAndExtractSamples( imageList->GetNthElement( i ), vectorFileName, fileNames.sampleOutputs[i], - fileNames.polyStatTrainOutputs[i], fileNames.ratesTrainOutputs[i], strategy, - selectedFieldName ); - } - } + std::string selectedFieldName = ""); + /** + * Function used to select validation samples based on a defined strategy (geometric in unsupervised mode) + * and extract them. 
With dedicated validation the 'by class' sampling strategy and statistics are used. + * Otherwise this function splits training samples into validation samples according to the sample.vtr percentage, + * or does nothing if this percentage is == 0 + * \param fileNames + * \param imageList + * \param validationVectorFileList optional validation vector file for each image + */ void SelectAndExtractValidationSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList, - const std::vector<std::string> &validationVectorFileList = std::vector<std::string>()) - { - // In dedicated validation mode the by class sampling strategy and statistics are used. - // Otherwise simply split training to validation samples corresponding to sample.vtr percentage. - if( !validationVectorFileList.empty() ) - { - for( unsigned int i = 0; i < imageList->Size(); ++i ) - { - SelectAndExtractSamples( imageList->GetNthElement( i ), validationVectorFileList[i], - fileNames.sampleValidOutputs[i], fileNames.polyStatValidOutputs[i], - fileNames.ratesValidOutputs[i], SamplingStrategy::CLASS ); - } - } - else - { - for( unsigned int i = 0; i < imageList->Size(); ++i ) - { - SplitTrainingAndValidationSamples( imageList->GetNthElement( i ), fileNames.sampleOutputs[i], - fileNames.sampleTrainOutputs[i], fileNames.sampleValidOutputs[i], - fileNames.ratesTrainOutputs[i] ); - } - } - } + const std::vector<std::string> &validationVectorFileList = std::vector<std::string>()); + + /** + * Function used to split all training samples from all images into training and validation sets. + * \param fileNames + * \param imageList + * \sa SplitTrainingAndValidationSamples + */ + void SplitTrainingToValidationSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList); private: + + /** + * Function used to split training samples into training and validation sets.
+ * \param image input image + * \param sampleFileName the input sample file name + * \param sampleTrainFileName the input training file name + * \param sampleValidFileName the input validation file name + * \param ratesTrainFileName the rates file name + */ void SplitTrainingAndValidationSamples(FloatVectorImageType *image, std::string sampleFileName, std::string sampleTrainFileName, std::string sampleValidFileName, - std::string ratesTrainFileName) - { - // Split between training and validation - ogr::DataSource::Pointer source = ogr::DataSource::New( sampleFileName, ogr::DataSource::Modes::Read ); - ogr::DataSource::Pointer destTrain = ogr::DataSource::New( sampleTrainFileName, ogr::DataSource::Modes::Overwrite ); - ogr::DataSource::Pointer destValid = ogr::DataSource::New( sampleValidFileName, ogr::DataSource::Modes::Overwrite ); - // read sampling rates from ratesTrainOutputs - SamplingRateCalculator::Pointer rateCalculator = SamplingRateCalculator::New(); - rateCalculator->Read( ratesTrainFileName ); - // Compute sampling rates for train and valid - const MapRateType &inputRates = rateCalculator->GetRatesByClass(); - MapRateType trainRates; - MapRateType validRates; - otb::SamplingRateCalculator::TripletType tpt; - for( MapRateType::const_iterator it = inputRates.begin(); it != inputRates.end(); ++it ) - { - double vtr = GetParameterFloat( "sample.vtr" ); - unsigned long total = std::min( it->second.Required, it->second.Tot ); - unsigned long neededValid = static_cast<unsigned long>(( double ) total * vtr ); - unsigned long neededTrain = total - neededValid; - tpt.Tot = total; - tpt.Required = neededTrain; - tpt.Rate = ( 1.0 - vtr ); - trainRates[it->first] = tpt; - tpt.Tot = neededValid; - tpt.Required = neededValid; - tpt.Rate = 1.0; - validRates[it->first] = tpt; - } - - // Use an otb::OGRDataToSamplePositionFilter with 2 outputs - PeriodicSamplerType::SamplerParameterType param; - param.Offset = 0; - param.MaxJitter = 0; - PeriodicSamplerType::Pointer 
splitter = PeriodicSamplerType::New(); - splitter->SetInput( image ); - splitter->SetOGRData( source ); - splitter->SetOutputPositionContainerAndRates( destTrain, trainRates, 0 ); - splitter->SetOutputPositionContainerAndRates( destValid, validRates, 1 ); - splitter->SetFieldName( this->GetParameterStringList( "sample.vfn" )[0] ); - splitter->SetLayerIndex( 0 ); - splitter->SetOriginFieldName( std::string( "" ) ); - splitter->SetSamplerParameters( param ); - splitter->GetStreamer()->SetAutomaticTiledStreaming( static_cast<unsigned int>(this->GetParameterInt( "ram" )) ); - AddProcess( splitter->GetStreamer(), "Split samples between training and validation..." ); - splitter->Update(); - } + std::string ratesTrainFileName); protected: + /** Base used for training; this allows knowing whether the chosen classifier is supervised or unsupervised */ + TrainVectorBase* trainVectorBase; + struct SamplingRates { long int fmt; @@ -473,6 +197,7 @@ protected: * \class TrainFileNamesHandler * This class is used to store file names requires for the application's input and output.
* And to clear temporary files generated by the applications + * \ingroup OTBAppClassification */ class TrainFileNamesHandler { @@ -578,5 +303,8 @@ protected: } // end namespace Wrapper } // end namespace otb +#ifndef OTB_MANUAL_INSTANTIATION +#include "otbTrainImagesBase.txx" +#endif #endif //otbTrainImagesBase_h diff --git a/Modules/Applications/AppClassification/include/otbTrainImagesBase.txx b/Modules/Applications/AppClassification/include/otbTrainImagesBase.txx new file mode 100644 index 0000000000000000000000000000000000000000..a0b881ee58c6af146c3bf3b911f0bccfc343e529 --- /dev/null +++ b/Modules/Applications/AppClassification/include/otbTrainImagesBase.txx @@ -0,0 +1,403 @@ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef otbTrainImagesBase_txx +#define otbTrainImagesBase_txx + +#include "otbTrainImagesBase.h" + +namespace otb +{ +namespace Wrapper +{ +void TrainImagesBase::InitIO() +{ + //Group IO + AddParameter( ParameterType_Group, "io", "Input and output data" ); + SetParameterDescription( "io", "This group of parameters allows setting input and output data." ); + + AddParameter( ParameterType_InputImageList, "io.il", "Input Image List" ); + SetParameterDescription( "io.il", "A list of input images." 
); + AddParameter( ParameterType_InputVectorDataList, "io.vd", "Input Vector Data List" ); + SetParameterDescription( "io.vd", "A list of vector data to select the training samples." ); + MandatoryOff( "io.vd" ); + + AddParameter( ParameterType_Empty, "cleanup", "Temporary files cleaning" ); + EnableParameter( "cleanup" ); + SetParameterDescription( "cleanup", + "If activated, the application will try to clean all temporary files it created" ); + MandatoryOff( "cleanup" ); +} + +void TrainImagesBase::InitSampling() +{ + AddApplication( "PolygonClassStatistics", "polystat", "Polygon analysis" ); + AddApplication( "MultiImageSamplingRate", "rates", "Sampling rates" ); + AddApplication( "SampleSelection", "select", "Sample selection" ); + AddApplication( "SampleExtraction", "extraction", "Sample extraction" ); + + // Sampling settings + AddParameter( ParameterType_Group, "sample", "Training and validation samples parameters" ); + SetParameterDescription( "sample", + "This group of parameters allows you to set training and validation sample lists parameters." ); + AddParameter( ParameterType_Int, "sample.mt", "Maximum training sample size per class" ); + SetDefaultParameterInt( "sample.mt", 1000 ); + SetParameterDescription( "sample.mt", "Maximum size per class (in pixels) of " + "the training sample list (default = 1000) (no limit = -1). If equal to -1," + " then the maximal size of the available training sample list per class " + "will be equal to the surface area of the smallest class multiplied by the" + " training sample ratio." ); + AddParameter( ParameterType_Int, "sample.mv", "Maximum validation sample size per class" ); + SetDefaultParameterInt( "sample.mv", 1000 ); + SetParameterDescription( "sample.mv", "Maximum size per class (in pixels) of " + "the validation sample list (default = 1000) (no limit = -1). 
If equal to -1," + " then the maximal size of the available validation sample list per class " + "will be equal to the surface area of the smallest class multiplied by the " + "validation sample ratio." ); + AddParameter( ParameterType_Int, "sample.bm", "Bound sample number by minimum" ); + SetDefaultParameterInt( "sample.bm", 1 ); + SetParameterDescription( "sample.bm", "Bound the number of samples for each " + "class by the number of available samples by the smaller class. Proportions " + "between training and validation are respected. Default is true (=1)." ); + AddParameter( ParameterType_Float, "sample.vtr", "Training and validation sample ratio" ); + SetParameterDescription( "sample.vtr", "Ratio between training and validation samples (0.0 = all training, 1.0 = " + "all validation) (default = 0.5)." ); + SetParameterFloat( "sample.vtr", 0.5, false ); + SetMaximumParameterFloatValue( "sample.vtr", 1.0 ); + SetMinimumParameterFloatValue( "sample.vtr", 0.0 ); + + AddParameter( ParameterType_Float, "sample.percent", "Percentage of sample extract from images" ); + SetParameterDescription( "sample.percent", "Percentage of sample extract from images for " + "training and validation when only images are provided." 
); + SetDefaultParameterFloat( "sample.percent", 1.0 ); + SetMinimumParameterFloatValue( "sample.percent", 0.0 ); + SetMaximumParameterFloatValue( "sample.percent", 1.0 ); + + ShareSamplingParameters(); + ConnectSamplingParameters(); +} + +void TrainImagesBase::ShareSamplingParameters() +{ + // hide sampling parameters + //ShareParameter("sample.strategy","rates.strategy"); + //ShareParameter("sample.mim","rates.mim"); + ShareParameter( "ram", "polystat.ram" ); + ShareParameter( "elev", "polystat.elev" ); + ShareParameter( "sample.vfn", "polystat.field" ); +} + +void TrainImagesBase::ConnectSamplingParameters() +{ + Connect( "extraction.field", "polystat.field" ); + Connect( "extraction.layer", "polystat.layer" ); + + Connect( "select.ram", "polystat.ram" ); + Connect( "extraction.ram", "polystat.ram" ); + + Connect( "select.field", "polystat.field" ); + Connect( "select.layer", "polystat.layer" ); + Connect( "select.elev", "polystat.elev" ); + + Connect( "extraction.in", "select.in" ); + Connect( "extraction.vec", "select.out" ); +} + +void TrainImagesBase::InitClassification() +{ + AddApplication( "TrainVectorClassifier", "training", "Model training" ); + trainVectorBase = dynamic_cast<TrainVectorBase*>(GetInternalApplication("training")); + + AddParameter( ParameterType_InputVectorDataList, "io.valid", "Validation Vector Data List" ); + SetParameterDescription( "io.valid", "A list of vector data to select the training samples." 
); + MandatoryOff( "io.valid" ); + + ShareClassificationParams(); + ConnectClassificationParams(); +}; + +void TrainImagesBase::ShareClassificationParams() +{ + ShareParameter( "io.imstat", "training.io.stats" ); + ShareParameter( "io.out", "training.io.out" ); + + ShareParameter( "classifier", "training.classifier" ); + ShareParameter( "rand", "training.rand" ); + + ShareParameter( "io.confmatout", "training.io.confmatout" ); +} + +void TrainImagesBase::ConnectClassificationParams() +{ + Connect( "training.cfield", "polystat.field" ); + Connect( "select.rand", "training.rand" ); +} + +void TrainImagesBase::ComputePolygonStatistics(FloatVectorImageListType *imageList, + const std::vector<std::string> &vectorFileNames, + const std::vector<std::string> &statisticsFileNames) +{ + unsigned int nbImages = static_cast<unsigned int>(imageList->Size()); + for( unsigned int i = 0; i < nbImages; i++ ) + { + GetInternalApplication( "polystat" )->SetParameterInputImage( "in", imageList->GetNthElement( i ) ); + GetInternalApplication( "polystat" )->SetParameterString( "vec", vectorFileNames[i], false ); + GetInternalApplication( "polystat" )->SetParameterString( "out", statisticsFileNames[i], false ); + ExecuteInternal( "polystat" ); + } +} + + +TrainImagesBase::SamplingRates TrainImagesBase::ComputeFinalMaximumSamplingRates(bool dedicatedValidation) +{ + SamplingRates rates; + GetInternalApplication( "rates" )->SetParameterString( "mim", "proportional", false ); + double vtr = GetParameterFloat( "sample.vtr" ); + long mt = GetParameterInt( "sample.mt" ); + long mv = GetParameterInt( "sample.mv" ); + // compute final maximum training and final maximum validation + // By default take all samples (-1 means all samples) + rates.fmt = -1; + rates.fmv = -1; + if( GetParameterInt( "sample.bm" ) == 0 ) + { + if( dedicatedValidation ) + { + // fmt and fmv will be used separately + rates.fmt = mt; + rates.fmv = mv; + if( mt > -1 && mv <= -1 && vtr < 0.99999 ) + { + rates.fmv = 
static_cast<long>(( double ) mt * vtr / ( 1.0 - vtr )); + } + if( mt <= -1 && mv > -1 && vtr > 0.00001 ) + { + rates.fmt = static_cast<long>(( double ) mv * ( 1.0 - vtr ) / vtr); + } + } + else + { + // only fmt will be used for both training and validation samples + // So we try to compute the total number of samples given input + // parameters mt, mv and vtr. + if( mt > -1 && mv > -1 ) + { + rates.fmt = mt + mv; + } + if( mt > -1 && mv <= -1 && vtr < 0.99999 ) + { + rates.fmt = static_cast<long>(( double ) mt / ( 1.0 - vtr )); + } + if( mt <= -1 && mv > -1 && vtr > 0.00001 ) + { + rates.fmt = static_cast<long>(( double ) mv / vtr); + } + } + } + return rates; +} + + +void TrainImagesBase::ComputeSamplingRate(const std::vector<std::string> &statisticsFileNames, + const std::string &ratesFileName, long maximum) +{ + // Sampling rates + GetInternalApplication( "rates" )->SetParameterStringList( "il", statisticsFileNames, false ); + GetInternalApplication( "rates" )->SetParameterString( "out", ratesFileName, false ); + if( GetParameterInt( "sample.bm" ) != 0 ) + { + GetInternalApplication( "rates" )->SetParameterString( "strategy", "smallest", false ); + } + else + { + if( maximum > -1 ) + { + GetInternalApplication( "rates" )->SetParameterString( "strategy", "constant", false ); + GetInternalApplication( "rates" )->SetParameterInt( "strategy.constant.nb", static_cast<int>(maximum), false ); + } + else + { + GetInternalApplication( "rates" )->SetParameterString( "strategy", "all", false ); + } + } + ExecuteInternal( "rates" ); +} + +void +TrainImagesBase::TrainModel(FloatVectorImageListType *imageList, const std::vector<std::string> &sampleTrainFileNames, + const std::vector<std::string> &sampleValidationFileNames) +{ + GetInternalApplication( "training" )->SetParameterStringList( "io.vd", sampleTrainFileNames, false ); + if( !sampleValidationFileNames.empty() ) + GetInternalApplication( "training" )->SetParameterStringList( "valid.vd", sampleValidationFileNames, 
false ); + + UpdateInternalParameters( "training" ); + // set field names + FloatVectorImageType::Pointer image = imageList->GetNthElement( 0 ); + unsigned int nbBands = image->GetNumberOfComponentsPerPixel(); + std::vector<std::string> selectedNames; + for( unsigned int i = 0; i < nbBands; i++ ) + { + std::ostringstream oss; + oss << i; + selectedNames.push_back( "value_" + oss.str() ); + } + GetInternalApplication( "training" )->SetParameterStringList( "feat", selectedNames, false ); + ExecuteInternal( "training" ); +} + +void TrainImagesBase::SelectAndExtractSamples(FloatVectorImageType *image, std::string vectorFileName, + std::string sampleFileName, std::string statisticsFileName, + std::string ratesFileName, SamplingStrategy strategy, + std::string selectedField) +{ + GetInternalApplication( "select" )->SetParameterInputImage( "in", image ); + GetInternalApplication( "select" )->SetParameterString( "out", sampleFileName, false ); + + // Change the selection strategy based on selected sampling strategy + switch( strategy ) + { + case GEOMETRIC: + GetInternalApplication( "select" )->SetParameterString( "sampler", "random", false ); + GetInternalApplication( "select" )->SetParameterString( "strategy", "percent", false ); + GetInternalApplication( "select" )->SetParameterFloat( "strategy.percent.p", + GetParameterFloat( "sample.percent" ), false ); + break; + case CLASS: + default: + GetInternalApplication( "select" )->SetParameterString( "vec", vectorFileName, false ); + GetInternalApplication( "select" )->SetParameterString( "instats", statisticsFileName, false ); + GetInternalApplication( "select" )->SetParameterString( "sampler", "periodic", false ); + GetInternalApplication( "select" )->SetParameterInt( "sampler.periodic.jitter", 50 ); + GetInternalApplication( "select" )->SetParameterString( "strategy", "byclass", false ); + GetInternalApplication( "select" )->SetParameterString( "strategy.byclass.in", ratesFileName, false ); + break; + } + + // select 
sample positions + ExecuteInternal( "select" ); + + GetInternalApplication( "extraction" )->SetParameterString( "vec", sampleFileName, false ); + UpdateInternalParameters( "extraction" ); + if( !selectedField.empty() ) + GetInternalApplication( "extraction" )->SetParameterString( "field", selectedField, false ); + + GetInternalApplication( "extraction" )->SetParameterString( "outfield", "prefix", false ); + GetInternalApplication( "extraction" )->SetParameterString( "outfield.prefix.name", "value_", false ); + + // extract sample descriptors + ExecuteInternal( "extraction" ); +} + + +void TrainImagesBase::SelectAndExtractTrainSamples(const TrainFileNamesHandler &fileNames, + FloatVectorImageListType *imageList, + std::vector<std::string> vectorFileNames, SamplingStrategy strategy, + std::string selectedFieldName) +{ + + for( unsigned int i = 0; i < imageList->Size(); ++i ) + { + std::string vectorFileName = vectorFileNames.empty() ? "" : vectorFileNames[i]; + SelectAndExtractSamples( imageList->GetNthElement( i ), vectorFileName, fileNames.sampleOutputs[i], + fileNames.polyStatTrainOutputs[i], fileNames.ratesTrainOutputs[i], strategy, + selectedFieldName ); + } +} + + +void TrainImagesBase::SelectAndExtractValidationSamples(const TrainFileNamesHandler &fileNames, + FloatVectorImageListType *imageList, + const std::vector<std::string> &validationVectorFileList) +{ + for( unsigned int i = 0; i < imageList->Size(); ++i ) + { + SelectAndExtractSamples( imageList->GetNthElement( i ), validationVectorFileList[i], + fileNames.sampleValidOutputs[i], fileNames.polyStatValidOutputs[i], + fileNames.ratesValidOutputs[i], SamplingStrategy::CLASS ); + } +} + +void TrainImagesBase::SplitTrainingToValidationSamples(const TrainFileNamesHandler &fileNames, + FloatVectorImageListType *imageList) +{ + for( unsigned int i = 0; i < imageList->Size(); ++i ) + { + SplitTrainingAndValidationSamples( imageList->GetNthElement( i ), fileNames.sampleOutputs[i], + 
fileNames.sampleTrainOutputs[i], fileNames.sampleValidOutputs[i], + fileNames.ratesTrainOutputs[i] ); + } +} + +void TrainImagesBase::SplitTrainingAndValidationSamples(FloatVectorImageType *image, std::string sampleFileName, + std::string sampleTrainFileName, + std::string sampleValidFileName, + std::string ratesTrainFileName) + +{ + // Split between training and validation + ogr::DataSource::Pointer source = ogr::DataSource::New( sampleFileName, ogr::DataSource::Modes::Read ); + ogr::DataSource::Pointer destTrain = ogr::DataSource::New( sampleTrainFileName, ogr::DataSource::Modes::Overwrite ); + ogr::DataSource::Pointer destValid = ogr::DataSource::New( sampleValidFileName, ogr::DataSource::Modes::Overwrite ); + // read sampling rates from ratesTrainOutputs + SamplingRateCalculator::Pointer rateCalculator = SamplingRateCalculator::New(); + rateCalculator->Read( ratesTrainFileName ); + // Compute sampling rates for train and valid + const MapRateType &inputRates = rateCalculator->GetRatesByClass(); + MapRateType trainRates; + MapRateType validRates; + otb::SamplingRateCalculator::TripletType tpt; + for( MapRateType::const_iterator it = inputRates.begin(); it != inputRates.end(); ++it ) + { + double vtr = GetParameterFloat( "sample.vtr" ); + unsigned long total = std::min( it->second.Required, it->second.Tot ); + unsigned long neededValid = static_cast<unsigned long>(( double ) total * vtr ); + unsigned long neededTrain = total - neededValid; + tpt.Tot = total; + tpt.Required = neededTrain; + tpt.Rate = ( 1.0 - vtr ); + trainRates[it->first] = tpt; + tpt.Tot = neededValid; + tpt.Required = neededValid; + tpt.Rate = 1.0; + validRates[it->first] = tpt; + } + + // Use an otb::OGRDataToSamplePositionFilter with 2 outputs + PeriodicSamplerType::SamplerParameterType param; + param.Offset = 0; + param.MaxJitter = 0; + PeriodicSamplerType::Pointer splitter = PeriodicSamplerType::New(); + splitter->SetInput( image ); + splitter->SetOGRData( source ); + 
splitter->SetOutputPositionContainerAndRates( destTrain, trainRates, 0 ); + splitter->SetOutputPositionContainerAndRates( destValid, validRates, 1 ); + splitter->SetFieldName( this->GetParameterStringList( "sample.vfn" )[0] ); + splitter->SetLayerIndex( 0 ); + splitter->SetOriginFieldName( std::string( "" ) ); + splitter->SetSamplerParameters( param ); + splitter->GetStreamer()->SetAutomaticTiledStreaming( static_cast<unsigned int>(this->GetParameterInt( "ram" )) ); + AddProcess( splitter->GetStreamer(), "Split samples between training and validation..." ); + splitter->Update(); +} +} +} + +#endif diff --git a/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.txx b/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.txx index bd96fe35684b517a5b8f7318938441c3727f5478..0dba5c6ab724588a4fcdfc808c7c5bc6934bfc3b 100644 --- a/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.txx +++ b/Modules/Applications/AppClassification/include/otbTrainSharkKMeans.txx @@ -1,19 +1,22 @@ -/*========================================================================= - Program: ORFEO Toolbox - Language: C++ - Date: $Date$ - Version: $Revision$ - - - Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. - See OTBCopyright.txt for details. - - - This software is distributed WITHOUT ANY WARRANTY; without even - the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR - PURPOSE. See the above copyright notices for more information. - - =========================================================================*/ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #ifndef otbTrainSharkKMeans_txx #define otbTrainSharkKMeans_txx @@ -34,10 +37,10 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams() //MaxNumberOfIterations AddParameter( ParameterType_Int, "classifier.sharkkm.nbmaxiter", "Maximum number of iteration for the kmeans algorithm." ); - SetParameterInt( "classifier.sharkkm.nbmaxiter", 0 ); + SetParameterInt( "classifier.sharkkm.nbmaxiter", 10 ); SetMinimumParameterIntValue( "classifier.sharkkm.nbmaxiter", 0 ); SetParameterDescription( "classifier.sharkkm.nbmaxiter", - "The maximum number of iteration for the kmeans algorithm. Default set to unlimited." ); + "The maximum number of iteration for the kmeans algorithm. 0=unlimited" ); //MaxNumberOfIterations AddParameter( ParameterType_Int, "classifier.sharkkm.k", "The number of class used for the kmeans algorithm." ); diff --git a/Modules/Applications/AppClassification/include/otbTrainVectorBase.h b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h index e9d3be1eb5275af9b1cbab7be204d314fb4f6db7..bdd844ca1d3ae63d1d87a6f4764b64daf5ce8d3d 100644 --- a/Modules/Applications/AppClassification/include/otbTrainVectorBase.h +++ b/Modules/Applications/AppClassification/include/otbTrainVectorBase.h @@ -1,19 +1,22 @@ -/*========================================================================= - Program: ORFEO Toolbox - Language: C++ - Date: $Date$ - Version: $Revision$ - - - Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. - See OTBCopyright.txt for details. 
- - - This software is distributed WITHOUT ANY WARRANTY; without even - the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR - PURPOSE. See the above copyright notices for more information. - - =========================================================================*/ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #ifndef otbTrainVectorBase_h #define otbTrainVectorBase_h diff --git a/Modules/Applications/AppClassification/include/otbTrainVectorBase.txx b/Modules/Applications/AppClassification/include/otbTrainVectorBase.txx index ebb651b324597b6b0bd94b5bcfaeafde7b0b1ddd..82d4f7028ac4fe4e60153cc43eb624c527ee6dc5 100644 --- a/Modules/Applications/AppClassification/include/otbTrainVectorBase.txx +++ b/Modules/Applications/AppClassification/include/otbTrainVectorBase.txx @@ -1,19 +1,22 @@ -/*========================================================================= - Program: ORFEO Toolbox - Language: C++ - Date: $Date$ - Version: $Revision$ - - - Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. - See OTBCopyright.txt for details. - - - This software is distributed WITHOUT ANY WARRANTY; without even - the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR - PURPOSE. See the above copyright notices for more information. 
- - =========================================================================*/ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #ifndef otbTrainVectorBase_txx #define otbTrainVectorBase_txx @@ -59,7 +62,7 @@ void TrainVectorBase::DoInit() AddParameter(ParameterType_ListView, "feat", "Field names for training features."); SetParameterDescription("feat","List of field names in the input vector data to be used as features for training."); - // Add validation data used to compute confusion matrix or contingence table + // Add validation data used to compute confusion matrix or contingency table AddParameter( ParameterType_Group, "valid", "Validation data" ); SetParameterDescription( "valid", "This group of parameters defines validation data." ); @@ -74,10 +77,24 @@ void TrainVectorBase::DoInit() SetDefaultParameterInt( "valid.layer", 0 ); // Add class field if we used validation - AddParameter(ParameterType_ListView,"cfield","Field containing the class id for supervision"); - SetParameterDescription("cfield","Field containing the class id for supervision. 
" - "Only geometries with this field available will be taken into account."); - SetListViewSingleSelectionMode("cfield",true); + AddParameter( ParameterType_ListView, "cfield", "Field containing the class id for supervision" ); + SetParameterDescription( "cfield", "Field containing the class id for supervision. " + "Only geometries with this field available will be taken into account." ); + SetListViewSingleSelectionMode( "cfield", true ); + + // Add a new parameter to compute confusion matrix / contingency table + AddParameter( ParameterType_OutputFilename, "io.confmatout", "Output confusion matrix or contingency table" ); + SetParameterDescription( "io.confmatout", "Output file containing the confusion matrix or contingency table (.csv format)." + "The contingency table is ouput when we unsupervised algorithms is used otherwise the confusion matrix is output." ); + MandatoryOff( "io.confmatout" ); + + + // Doc example parameter settings + SetDocExampleParameterValue( "io.vd", "vectorData.shp" ); + SetDocExampleParameterValue( "io.stats", "meanVar.xml" ); + SetDocExampleParameterValue( "io.out", "svmModel.svm" ); + SetDocExampleParameterValue( "feat", "perimeter area width" ); + SetDocExampleParameterValue( "cfield", "predicted" ); // Add parameters for the classifier choice @@ -90,6 +107,9 @@ void TrainVectorBase::DoInit() void TrainVectorBase::DoUpdateParameters() { + LearningApplicationBase::DoUpdateParameters(); + + // if vector data is present and updated then reload fields if( HasValue( "io.vd" ) ) { std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" ); @@ -162,9 +182,32 @@ TrainVectorBase::ExtractTrainingListSamples(const StatisticsMeasurement &measure } TrainVectorBase::ListSamples -TrainVectorBase::ExtractClassificationListSamples(const StatisticsMeasurement &itkNotUsed(measurement)) +TrainVectorBase::ExtractClassificationListSamples(const StatisticsMeasurement &measurement) { - return trainingListSamples; + 
if(GetClassifierCategory() == Supervised) + { + ListSamples tmpListSamples; + ListSamples validationListSamples = ExtractListSamples( "valid.vd", "valid.layer", measurement ); + //Test the input validation set size + if( validationListSamples.labeledListSample->Size() != 0 ) + { + tmpListSamples.listSample = validationListSamples.listSample; + tmpListSamples.labeledListSample = validationListSamples.labeledListSample; + } + else + { + otbAppLogWARNING( + "The validation set is empty. The performance estimation is done using the input training set in this case." ); + tmpListSamples.listSample = trainingListSamples.listSample; + tmpListSamples.labeledListSample = trainingListSamples.labeledListSample; + } + + return tmpListSamples; + } + else + { + return trainingListSamples; + } } diff --git a/Modules/Applications/AppClassification/otb-module.cmake b/Modules/Applications/AppClassification/otb-module.cmake index 322ba7f9c47e9600a73eafee0ade934b66fa9608..8447da5583203782e7028780d80f6cb034e4fad3 100644 --- a/Modules/Applications/AppClassification/otb-module.cmake +++ b/Modules/Applications/AppClassification/otb-module.cmake @@ -11,6 +11,7 @@ otb_module(OTBAppClassification OTBMajorityVoting OTBVectorDataIO OTBSOM + OTBLearningBase OTBSupervised OTBUnsupervised OTBApplicationEngine diff --git a/Modules/Applications/AppClassification/test/CMakeLists.txt b/Modules/Applications/AppClassification/test/CMakeLists.txt index 94e558f6e59ac549e7d84ba6614027b607781661..9dbca30853a9a4a53b1f82630bdc29a9a2b7a273 100644 --- a/Modules/Applications/AppClassification/test/CMakeLists.txt +++ b/Modules/Applications/AppClassification/test/CMakeLists.txt @@ -76,6 +76,7 @@ set(bayes_output_format ".bayes") set(rf_output_format ".rf") set(knn_output_format ".knn") set(sharkrf_output_format ".txt") +set(sharkkm_output_format ".txt") # Training algorithms parameters set(libsvm_parameters "-classifier.libsvm.opt" "true" "-classifier.libsvm.prob" "true") @@ -88,7 +89,7 @@ set(bayes_parameters "") 
set(rf_parameters "") set(knn_parameters "") set(sharkrf_parameters "") - +set(sharkkm_parameters "") # Validation depending on mode set(ascii_comparison --compare-ascii ${EPSILON_6}) @@ -108,7 +109,7 @@ if(OTB_USE_OPENCV) list(APPEND classifierList "BOOST" "DT" "GBT" "ANN" "BAYES" "RF" "KNN") endif() if(OTB_USE_SHARK) - list(APPEND classifierList "SHARKRF") + list(APPEND classifierList "SHARKRF" "SHARKKM") endif() set(classifier_with_confmap "LIBSVM" "BOOST" "KNN" "ANN" "RF") @@ -224,124 +225,6 @@ foreach(classifier ${classifierList}) endforeach() - -#----------- TrainImagesClustering TESTS ---------------- - -set(sharkkm_output_format ".txt") -set(sharkkm_parameters "") - -if(OTB_USE_SHARK) - list(APPEND clusteringList "SHARKKM") -endif() - -list(APPEND classifier_without_baseline "SHARKKM") - -# Loop on classifiers -foreach(classifier ${clusteringList}) - string(TOLOWER ${classifier} lclassifier) - - # Derive output file name - set(OUTMODELFILE cl${classifier}_ModelQB1${${lclassifier}_output_format}) - set(OUTRASTER cl${classifier}LabeledImageQB1${raster_output_format}) - set(OUTCONFMAP cl${classifier}ConfidenceMapQB1${raster_output_format}) - - list(FIND classifier_without_baseline ${classifier} _classifier_has_baseline) - if(${_classifier_has_baseline} EQUAL -1) - set(valid ${ascii_comparison} ${ascii_ref_path}/${OUTMODELFILE} ${TEMP}/${OUTMODELFILE}) - else() - set(valid "") - endif() - - otb_test_application( - NAME apTvClTrainMethod${classifier}ImagesClassifierQB1 - APP TrainImagesClustering - OPTIONS -io.il ${INPUTDATA}/Classification/QB_1_ortho${raster_input_format} - -io.vd ${INPUTDATA}/Classification/VectorData_${${lclassifier}_input}QB1${vector_input_format} - -io.imstat ${INPUTDATA}/Classification/clImageStatisticsQB1${stat_input_format} - -classifier ${lclassifier} - ${${lclassifier}_parameters} - -io.out ${TEMP}/${OUTMODELFILE} - -sample.vfn Class - -rand 121212 - - VALID ${valid} - ) - - if(${_classifier_has_baseline} EQUAL -1) - set(valid 
${ascii_comparison} ${ascii_ref_path}/${OUTMODELFILE} ${TEMP}/OutXML1_${OUTMODELFILE}) - else() - set(valid "") - endif() - - otb_test_application( - NAME apTvClTrainMethod${classifier}ImagesClassifierQB1_OutXML1 - APP TrainImagesClustering - OPTIONS -io.il ${INPUTDATA}/Classification/QB_1_ortho${raster_input_format} - -io.vd ${INPUTDATA}/Classification/VectorData_${${lclassifier}_input}QB1${vector_input_format} - -io.imstat ${INPUTDATA}/Classification/clImageStatisticsQB1${stat_input_format} - -classifier ${lclassifier} - ${${lclassifier}_parameters} - -io.out ${TEMP}/OutXML1_${OUTMODELFILE} - -rand 121212 - -sample.vfn Class - -outxml ${TEMP}/cl${classifier}_OutXML1.xml - - VALID ${valid} - ) - - if(${_classifier_has_baseline} EQUAL -1) - set(valid ${ascii_comparison} ${ascii_ref_path}/${OUTMODELFILE} ${TEMP}/OutXML2_${OUTMODELFILE}) - else() - set(valid "") - endif() - - otb_test_application( - NAME apTvClTrainMethod${classifier}ImagesClassifierQB1_InXML1 - APP TrainImagesClustering - OPTIONS -inxml ${INPUTDATA}/cl${classifier}_OutXML1.xml - -io.il ${INPUTDATA}/Classification/QB_1_ortho${raster_input_format} - -io.vd ${INPUTDATA}/Classification/VectorData_${${lclassifier}_input}QB1${vector_input_format} - -io.imstat ${INPUTDATA}/Classification/clImageStatisticsQB1${stat_input_format} - -io.out ${TEMP}/OutXML2_${OUTMODELFILE} - -sample.vfn Class - VALID ${valid} - ) - - list(FIND classifier_with_confmap ${classifier} _classifier_has_confmap) - if(${_classifier_has_confmap} EQUAL -1) - otb_test_application( - NAME apTvClMethod${classifier}ImageClassifierQB1 - APP ImageClassifier - OPTIONS -in ${INPUTDATA}/Classification/QB_1_ortho${raster_input_format} - -model ${INPUTDATA}/Classification/${OUTMODELFILE} - -imstat ${INPUTDATA}/Classification/clImageStatisticsQB1${stat_input_format} - -out ${TEMP}/${OUTRASTER} ${raster_output_option} - - VALID ${raster_comparison} - ${raster_ref_path}/${OUTRASTER} - ${TEMP}/${OUTRASTER} - ) - else() - otb_test_application( - NAME 
apTvClMethod${classifier}ImageClassifierQB1 - APP ImageClassifier - OPTIONS -in ${INPUTDATA}/Classification/QB_1_ortho${raster_input_format} - -model ${INPUTDATA}/Classification/${OUTMODELFILE} - -imstat ${INPUTDATA}/Classification/clImageStatisticsQB1${stat_input_format} - -out ${TEMP}/${OUTRASTER} ${raster_output_option} - -confmap ${TEMP}/${OUTCONFMAP} - - VALID ${raster_comparison_two} - ${raster_ref_path}/${OUTRASTER} - ${TEMP}/${OUTRASTER} - ${raster_ref_path}/${OUTCONFMAP} - ${TEMP}/${OUTCONFMAP} - ) - endif() - -endforeach() - - #----------- LIBSVM Classifier TESTS ---------------- if(OTB_USE_LIBSVM) @@ -1063,17 +946,17 @@ if(OTB_USE_OPENCV) ${TEMP}/apTvClTrainVectorClassifierModel.rf) endif() -#----------- TrainVectorClustering TESTS ---------------- +#----------- TrainVectorClassifier unsupervised TESTS ---------------- if(OTB_USE_SHARK) - otb_test_application(NAME apTvClTrainVectorClustering - APP TrainVectorClustering + otb_test_application(NAME apTvClTrainVectorUnsupervised + APP TrainVectorClassifier OPTIONS -io.vd ${INPUTDATA}/Classification/apTvClSampleExtractionOut.sqlite -feat value_0 value_1 value_2 value_3 -classifier sharkkm -io.out ${TEMP}/apTvClTrainVectorClusteringModel.txt) - otb_test_application(NAME apTvClTrainVectorClusteringWithClass - APP TrainVectorClustering + otb_test_application(NAME apTvClTrainVectorUnsupervisedWithClass + APP TrainVectorClassifier OPTIONS -io.vd ${INPUTDATA}/Classification/apTvClSampleExtractionOut.sqlite -feat value_0 value_1 value_2 value_3 -cfield class diff --git a/Modules/Learning/Supervised/include/otbImageClassificationFilter.h b/Modules/Learning/LearningBase/include/otbImageClassificationFilter.h similarity index 85% rename from Modules/Learning/Supervised/include/otbImageClassificationFilter.h rename to Modules/Learning/LearningBase/include/otbImageClassificationFilter.h index 7d4cf8a9766bf21e8cc1a7c2e77c16150720b218..3dbcbfa938af52dfe2dc1ae154120c71f548a40b 100644 --- 
a/Modules/Learning/Supervised/include/otbImageClassificationFilter.h +++ b/Modules/Learning/LearningBase/include/otbImageClassificationFilter.h @@ -1,20 +1,22 @@ -/*========================================================================= - - Program: ORFEO Toolbox - Language: C++ - Date: $Date$ - Version: $Revision$ - - - Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. - See OTBCopyright.txt for details. - - - This software is distributed WITHOUT ANY WARRANTY; without even - the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR - PURPOSE. See the above copyright notices for more information. - -=========================================================================*/ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #ifndef otbImageClassificationFilter_h #define otbImageClassificationFilter_h @@ -34,7 +36,7 @@ namespace otb * \ingroup Streamed * \ingroup Threaded * - * \ingroup OTBSupervised + * \ingroup OTBLearningBase */ template <class TInputImage, class TOutputImage, class TMaskImage = TOutputImage> class ITK_EXPORT ImageClassificationFilter @@ -87,7 +89,7 @@ public: itkSetMacro(BatchMode, bool); itkGetMacro(BatchMode, bool); itkBooleanMacro(BatchMode); - + /** * If set, only pixels within the mask will be classified. 
* All pixels with a value greater than 0 in the mask, will be classified. diff --git a/Modules/Learning/Supervised/include/otbImageClassificationFilter.txx b/Modules/Learning/LearningBase/include/otbImageClassificationFilter.txx similarity index 93% rename from Modules/Learning/Supervised/include/otbImageClassificationFilter.txx rename to Modules/Learning/LearningBase/include/otbImageClassificationFilter.txx index 54656b1948811c9685eaa01d729bc706ab15e98b..85e531cd5548f5fbf9b49540ff4a91b068faa7f8 100644 --- a/Modules/Learning/Supervised/include/otbImageClassificationFilter.txx +++ b/Modules/Learning/LearningBase/include/otbImageClassificationFilter.txx @@ -1,20 +1,22 @@ -/*========================================================================= - - Program: ORFEO Toolbox - Language: C++ - Date: $Date$ - Version: $Revision$ - - - Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. - See OTBCopyright.txt for details. - - - This software is distributed WITHOUT ANY WARRANTY; without even - the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR - PURPOSE. See the above copyright notices for more information. - -=========================================================================*/ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ #ifndef otbImageClassificationFilter_txx #define otbImageClassificationFilter_txx diff --git a/Modules/Learning/LearningBase/include/otbMachineLearningModel.h b/Modules/Learning/LearningBase/include/otbMachineLearningModel.h index ee84468521b5fc0db1c9d0ac3dcd0d81404b0367..db57e1fdf35f9fddc023597244f013d139377068 100644 --- a/Modules/Learning/LearningBase/include/otbMachineLearningModel.h +++ b/Modules/Learning/LearningBase/include/otbMachineLearningModel.h @@ -1,20 +1,22 @@ -/*========================================================================= - - Program: ORFEO Toolbox - Language: C++ - Date: $Date$ - Version: $Revision$ - - - Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. - See OTBCopyright.txt for details. - - - This software is distributed WITHOUT ANY WARRANTY; without even - the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR - PURPOSE. See the above copyright notices for more information. - -=========================================================================*/ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ #ifndef otbMachineLearningModel_h #define otbMachineLearningModel_h @@ -61,7 +63,7 @@ namespace otb * \sa ImageClassificationFilter * * - * \ingroup OTBSupervised + * \ingroup OTBLearningBase */ template <class TInputValue, class TTargetValue, class TConfidenceValue = double > class ITK_EXPORT MachineLearningModel diff --git a/Modules/Learning/LearningBase/include/otbMachineLearningModel.txx b/Modules/Learning/LearningBase/include/otbMachineLearningModel.txx index 3d67a0a3a2a68000a60225d9b822a25157277f14..8265ef9c58a68c267ec765b0fc7c84640c998217 100644 --- a/Modules/Learning/LearningBase/include/otbMachineLearningModel.txx +++ b/Modules/Learning/LearningBase/include/otbMachineLearningModel.txx @@ -1,20 +1,22 @@ -/*========================================================================= - - Program: ORFEO Toolbox - Language: C++ - Date: $Date$ - Version: $Revision$ - - - Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. - See OTBCopyright.txt for details. - - - This software is distributed WITHOUT ANY WARRANTY; without even - the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR - PURPOSE. See the above copyright notices for more information. - -=========================================================================*/ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ #ifndef otbMachineLearningModel_txx #define otbMachineLearningModel_txx diff --git a/Modules/Learning/LearningBase/include/otbMachineLearningModelFactoryBase.h b/Modules/Learning/LearningBase/include/otbMachineLearningModelFactoryBase.h index 81f771ac5304e7eb58f8f4cc33c49f40f7a550ed..daed6acb7b02b3f6b4733ffe6e869f434664e21d 100644 --- a/Modules/Learning/LearningBase/include/otbMachineLearningModelFactoryBase.h +++ b/Modules/Learning/LearningBase/include/otbMachineLearningModelFactoryBase.h @@ -1,20 +1,22 @@ -/*========================================================================= - - Program: ORFEO Toolbox - Language: C++ - Date: $Date$ - Version: $Revision$ - - - Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. - See OTBCopyright.txt for details. - - - This software is distributed WITHOUT ANY WARRANTY; without even - the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR - PURPOSE. See the above copyright notices for more information. - -=========================================================================*/ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ #ifndef otbMachineLearningModelFactoryBase_h #define otbMachineLearningModelFactoryBase_h @@ -29,7 +31,7 @@ namespace otb * This class intends to hold the static attributes that can not be * part of a template class (ld error). * - * + * \ingroup OTBLearningBase */ class OTBSupervised_EXPORT MachineLearningModelFactoryBase : public itk::Object { diff --git a/Modules/Learning/LearningBase/include/otbSharkUtils.h b/Modules/Learning/LearningBase/include/otbSharkUtils.h index 08165a114fdf73b67e463ec9e658aefe138932ce..c872dc9f0956bb4b6e08945e8c2503211697cb66 100644 --- a/Modules/Learning/LearningBase/include/otbSharkUtils.h +++ b/Modules/Learning/LearningBase/include/otbSharkUtils.h @@ -1,31 +1,35 @@ -/*========================================================================= - - Program: ORFEO Toolbox - Language: C++ - Date: $Date$ - Version: $Revision$ - - - Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. - See OTBCopyright.txt for details. - - - This software is distributed WITHOUT ANY WARRANTY; without even - the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR - PURPOSE. See the above copyright notices for more information. - - =========================================================================*/ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ #ifndef otbSharkUtils_h #define otbSharkUtils_h -#include "otb_shark.h" #include "itkMacro.h" #if defined(__GNUC__) || defined(__clang__) #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wshadow" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Wsign-compare" #endif -#include <shark/Data/Dataset.h> +#include "otb_shark.h" +#include "shark/Data/Dataset.h" #if defined(__GNUC__) || defined(__clang__) #pragma GCC diagnostic pop #endif diff --git a/Modules/Learning/LearningBase/otb-module.cmake b/Modules/Learning/LearningBase/otb-module.cmake index 78f6c010994a94ab2c80baf3c477a10c7cb166b6..b73df789bac68fdeb069b523d57eb1a4f1be91ed 100644 --- a/Modules/Learning/LearningBase/otb-module.cmake +++ b/Modules/Learning/LearningBase/otb-module.cmake @@ -5,6 +5,8 @@ otb_module(OTBLearningBase DEPENDS OTBCommon OTBITK + OTBImageIO + OTBImageBase OPTIONAL_DEPENDS OTBShark diff --git a/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h b/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h index b362373e8eb7cad3aca821cbc219444f79b0b1d2..7452193a952fec42e808af8f9cb6da1aca12ae93 100644 --- a/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h @@ -18,8 +18,6 @@ #ifndef otbSharkRandomForestsMachineLearningModel_h #define otbSharkRandomForestsMachineLearningModel_h -#include "otb_shark.h" - #include "itkLightObject.h" #include "otbMachineLearningModel.h" @@ -33,6 +31,7 @@ #pragma GCC diagnostic ignored "-Wcast-align" #pragma GCC diagnostic ignored "-Wunknown-pragmas" #endif +#include "otb_shark.h" #include "shark/Algorithms/Trainers/RFTrainer.h" #if defined(__GNUC__) || defined(__clang__) #pragma GCC diagnostic pop diff --git a/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx 
b/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx index 0e0b08f3e26dd2d4cb8a0e4bddd7e036d9aae444..b8486999dc6f7d456e5ec6cdade49ba3bc3f91fc 100644 --- a/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx +++ b/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx @@ -195,6 +195,8 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> { itkExceptionMacro(<< "Error opening " << filename.c_str() ); } + // Add comment with model file name + ofs << "#" << m_RFModel.name() << std::endl; shark::TextOutArchive oa(ofs); m_RFModel.save(oa,0); } @@ -205,8 +207,25 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> ::Load(const std::string & filename, const std::string & itkNotUsed(name)) { std::ifstream ifs(filename.c_str()); - shark::TextInArchive ia(ifs); - m_RFModel.load(ia,0); + if( ifs.good() ) + { + // Check if the first line is a comment and verify the name of the model in this case. + std::string line; + getline( ifs, line ); + if( line.at( 0 ) == '#' ) + { + if( line.find( m_RFModel.name() ) == std::string::npos ) + itkExceptionMacro( "The model file : " + filename + " cannot be read." 
); + } + else + { + // rewind if first line is not a comment + ifs.clear(); + ifs.seekg( 0, std::ios::beg ); + } + shark::TextInArchive ia( ifs ); + m_RFModel.load( ia, 0 ); + } } template <class TInputValue, class TOutputValue> diff --git a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h index a31891e5d1d633776c67500ceba2eeb136a9b140..3084b2503e8680fe36b60c0cf70b979bc7656cc9 100644 --- a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h +++ b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h @@ -1,25 +1,25 @@ -/*========================================================================= - - Program: ORFEO Toolbox - Language: C++ - Date: $Date$ - Version: $Revision$ - - - Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. - See OTBCopyright.txt for details. - - - This software is distributed WITHOUT ANY WARRANTY; without even - the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR - PURPOSE. See the above copyright notices for more information. - -=========================================================================*/ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ #ifndef otbSharkKMeansMachineLearningModel_h #define otbSharkKMeansMachineLearningModel_h - - #include "itkLightObject.h" #include "otbMachineLearningModel.h" @@ -45,8 +45,6 @@ #pragma GCC diagnostic pop #endif -using namespace shark; - /** \class SharkKMeansMachineLearningModel * \brief Shark version of Random Forests algorithm * @@ -67,23 +65,24 @@ class ITK_EXPORT SharkKMeansMachineLearningModel : public MachineLearningModel<T { public: /** Standard class typedefs. */ - typedef SharkKMeansMachineLearningModel Self; + typedef SharkKMeansMachineLearningModel Self; typedef MachineLearningModel<TInputValue, TTargetValue> Superclass; - typedef itk::SmartPointer<Self> Pointer; - typedef itk::SmartPointer<const Self> ConstPointer; - - typedef typename Superclass::InputValueType InputValueType; - typedef typename Superclass::InputSampleType InputSampleType; - typedef typename Superclass::InputListSampleType InputListSampleType; - typedef typename Superclass::TargetValueType TargetValueType; - typedef typename Superclass::TargetSampleType TargetSampleType; - typedef typename Superclass::TargetListSampleType TargetListSampleType; - typedef typename Superclass::ConfidenceValueType ConfidenceValueType; - typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType; - typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType; - - typedef HardClusteringModel<RealVector> ClusteringModelType; - typedef ClusteringModelType::OutputType ClusteringOutputType; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + + typedef typename Superclass::InputValueType InputValueType; + typedef typename Superclass::InputSampleType InputSampleType; + typedef typename Superclass::InputListSampleType InputListSampleType; + typedef typename Superclass::TargetValueType TargetValueType; + typedef typename Superclass::TargetSampleType TargetSampleType; + typedef typename Superclass::TargetListSampleType 
TargetListSampleType; + typedef typename Superclass::ConfidenceValueType ConfidenceValueType; + typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType; + typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType; + + + typedef shark::HardClusteringModel<shark::RealVector> ClusteringModelType; + typedef ClusteringModelType::OutputType ClusteringOutputType; /** Run-time type information (and related methods). */ itkNewMacro( Self ); @@ -108,18 +107,14 @@ public: //@} /** Get the maximum number of iteration for the kMeans algorithm.*/ - itkGetMacro( MaximumNumberOfIterations, unsigned - int ); + itkGetMacro( MaximumNumberOfIterations, unsigned ); /** Set the maximum number of iteration for the kMeans algorithm.*/ - itkSetMacro( MaximumNumberOfIterations, unsigned - int ); + itkSetMacro( MaximumNumberOfIterations, unsigned ); /** Get the number of class for the kMeans algorithm.*/ - itkGetMacro( K, unsigned - int ); + itkGetMacro( K, unsigned ); /** Set the number of class for the kMeans algorithm.*/ - itkSetMacro( K, unsigned - int ); + itkSetMacro( K, unsigned ); /** If true, normalized input data sample list */ itkGetMacro( Normalized, bool ); @@ -154,10 +149,11 @@ private: bool m_Normalized; unsigned int m_K; unsigned int m_MaximumNumberOfIterations; + bool m_CanRead; /** Centroids results form kMeans */ - Centroids centroids; + shark::Centroids m_Centroids; /** shark Model could be SoftClusteringModel or HardClusteringModel */ diff --git a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.txx b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.txx index e3492437367f81a726c25326134900bf686a150f..267a676ae668a863928be6087146a1ca7dd61d6e 100644 --- a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.txx +++ b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.txx @@ -1,25 +1,29 @@ 
-/*========================================================================= - - Program: ORFEO Toolbox - Language: C++ - Date: $Date$ - Version: $Revision$ - - - Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. - See OTBCopyright.txt for details. - - - This software is distributed WITHOUT ANY WARRANTY; without even - the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR - PURPOSE. See the above copyright notices for more information. - -=========================================================================*/ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ #ifndef otbSharkKMeansMachineLearningModel_txx #define otbSharkKMeansMachineLearningModel_txx + #include <fstream> #include "itkMacro.h" #include "otbSharkKMeansMachineLearningModel.h" + #if defined(__GNUC__) || defined(__clang__) #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wshadow" @@ -27,15 +31,19 @@ #pragma GCC diagnostic ignored "-Woverloaded-virtual" #pragma GCC diagnostic ignored "-Wignored-qualifiers" #endif -#include <shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h> //normalize -#include <shark/Algorithms/KMeans.h> //k-means algorithm -#include <shark/Models/Clustering/HardClusteringModel.h> -#include <shark/Models/Clustering/SoftClusteringModel.h> -#include <shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h> + +#include "otb_shark.h" +#include "otbSharkUtils.h" +#include "shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h" //normalize +#include "shark/Algorithms/KMeans.h" //k-means algorithm +#include "shark/Models/Clustering/HardClusteringModel.h" +#include "shark/Models/Clustering/SoftClusteringModel.h" +#include "shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h" + #if defined(__GNUC__) || defined(__clang__) #pragma GCC diagnostic pop #endif -#include "otbSharkUtils.h" + namespace otb @@ -43,10 +51,10 @@ namespace otb template<class TInputValue, class TOutputValue> SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::SharkKMeansMachineLearningModel() : - m_Normalized( true ), m_K(2), m_MaximumNumberOfIterations( 0 ) + m_Normalized( false ), m_K(2), m_MaximumNumberOfIterations( 10 ) { // Default set HardClusteringModel - m_ClusteringModel = boost::shared_ptr<ClusteringModelType>(new ClusteringModelType( &centroids )); + m_ClusteringModel = boost::shared_ptr<ClusteringModelType>(new ClusteringModelType( &m_Centroids )); } @@ -63,16 +71,17 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::Train() { // Parse input data and convert to Shark Data - std::vector<RealVector>
vector_data; - Shark::ListSampleToSharkVector( this->GetInputListSample(), vector_data ); - Data<RealVector> data = createDataFromRange( vector_data ); + std::vector<shark::RealVector> vector_data; + otb::Shark::ListSampleToSharkVector( this->GetInputListSample(), vector_data ); + shark::Data<shark::RealVector> data = shark::createDataFromRange( vector_data ); // Normalized input value if necessary if( m_Normalized ) data = NormalizeData( data ); // Use a Hard Clustering Model for classification - kMeans( data, m_K, centroids, m_MaximumNumberOfIterations ); + shark::kMeans( data, m_K, m_Centroids, m_MaximumNumberOfIterations ); + m_ClusteringModel = boost::shared_ptr<ClusteringModelType>(new ClusteringModelType( &m_Centroids )); } template<class TInputValue, class TOutputValue> @@ -93,7 +102,7 @@ typename SharkKMeansMachineLearningModel<TInputValue, TOutputValue> SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::DoPredict(const InputSampleType &value, ConfidenceValueType *quality) const { - RealVector data( value.Size()); + shark::RealVector data( value.Size()); for( size_t i = 0; i < value.Size(); i++ ) { data.push_back( value[i] ); @@ -128,20 +137,30 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> // input list sample and target list sample should be initialized and without assert( input->Size() == targets->Size() && "Input sample list and target label list do not have the same size." 
); - assert((( quality == ITK_NULLPTR ) || ( quality->Size() == input->Size())) && - "Quality samples list is not null and does not have the same size as input samples list" ); - if( startIndex + size > input->Size()) + assert( ( ( quality == ITK_NULLPTR ) || ( quality->Size() == input->Size() ) ) && + "Quality samples list is not null and does not have the same size as input samples list" ); + if( startIndex + size > input->Size() ) { itkExceptionMacro( <<"requested range ["<<startIndex<<", "<<startIndex+size<<"[ partially outside input sample list range.[0,"<<input->Size()<<"[" ); } // Convert input list of features to shark data format - std::vector<RealVector> features; - Shark::ListSampleRangeToSharkVector( input, features, startIndex, size ); - Data<RealVector> inputSamples = shark::createDataFromRange( features ); + std::vector<shark::RealVector> features; + otb::Shark::ListSampleRangeToSharkVector( input, features, startIndex, size ); + shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange( features ); + + shark::Data<ClusteringOutputType> clusters; + try + { + clusters = ( *m_ClusteringModel )( inputSamples ); + } + catch( ... ) + { + itkExceptionMacro( "Failed to run clustering classification. " + "The number of features of input samples and the model could differ."); + } - Data<ClusteringOutputType> clusters = ( *m_ClusteringModel )( inputSamples ); unsigned int id = startIndex; for( const auto &p : clusters.elements() ) { @@ -159,7 +178,6 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> quality->SetMeasurementVector( qid, static_cast<ConfidenceValueType>(1.) 
); } } - } @@ -173,9 +191,8 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> { itkExceptionMacro( << "Error opening " << filename.c_str()); } + ofs << "#" << m_ClusteringModel->name() << std::endl; shark::TextOutArchive oa( ofs ); - std::string name = m_ClusteringModel->name(); - oa << name; m_ClusteringModel->save( oa, 1 ); } @@ -184,13 +201,22 @@ void SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::Load(const std::string &filename, const std::string & itkNotUsed( name )) { + m_CanRead = false; std::ifstream ifs( filename.c_str()); + if(ifs.good()) + { + // Check if first line contains model name + std::string line; + std::getline(ifs, line); + m_CanRead = line.find(m_ClusteringModel->name()) != std::string::npos; + } + + if(!m_CanRead) + return; + shark::TextInArchive ia( ifs ); - std::string name; - ia >> name; - if(name != m_ClusteringModel->name()) - throw new boost::archive::archive_exception(boost::archive::archive_exception::input_stream_error); - m_ClusteringModel->load( ia, 1 ); + m_ClusteringModel->load( ia, 0 ); + ifs.close(); } template<class TInputValue, class TOutputValue> @@ -200,13 +226,14 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> { try { + m_CanRead = true; this->Load( file ); } catch( ... 
) { return false; } - return true; + return m_CanRead; } template<class TInputValue, class TOutputValue> diff --git a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModelFactory.h b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModelFactory.h index cf0c033eb5a487fbfcda93a4f2d67677d449fec8..a072d5d71921f24cc483e0193d3f0d04a4ab41d2 100644 --- a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModelFactory.h +++ b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModelFactory.h @@ -1,20 +1,22 @@ -/*========================================================================= - - Program: ORFEO Toolbox - Language: C++ - Date: $Date$ - Version: $Revision$ - - - Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. - See OTBCopyright.txt for details. - - - This software is distributed WITHOUT ANY WARRANTY; without even - the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR - PURPOSE. See the above copyright notices for more information. - -=========================================================================*/ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ #ifndef otbSharkKMeansMachineLearningModelFactory_h #define otbSharkKMeansMachineLearningModelFactory_h diff --git a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModelFactory.txx b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModelFactory.txx index a698e7b1564fa81c8c5d590c4134914a8bd7f898..a4fe5584e04249875de1193e365fd888a43cfa07 100644 --- a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModelFactory.txx +++ b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModelFactory.txx @@ -1,20 +1,22 @@ -/*========================================================================= - - Program: ORFEO Toolbox - Language: C++ - Date: $Date$ - Version: $Revision$ - - - Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. - See OTBCopyright.txt for details. - - - This software is distributed WITHOUT ANY WARRANTY; without even - the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR - PURPOSE. See the above copyright notices for more information. - -=========================================================================*/ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ #ifndef otbSharkKMeansMachineLearningModelFactory_txx #define otbSharkKMeansMachineLearningModelFactory_txx diff --git a/Modules/Learning/Unsupervised/test/CMakeLists.txt b/Modules/Learning/Unsupervised/test/CMakeLists.txt index 9411a183d46a984ea3d259fdfb2b1d983c34de3f..8407413cd8020676daea852688a42273077e9f33 100644 --- a/Modules/Learning/Unsupervised/test/CMakeLists.txt +++ b/Modules/Learning/Unsupervised/test/CMakeLists.txt @@ -1,16 +1,11 @@ otb_module_test() set(OTBUnsupervisedTests otbUnsupervisedTestDriver.cxx - otbMachineLearningClusteringModelCanRead.cxx - otbTrainMachineLearningClusteringModel.cxx + otbMachineLearningUnsupervisedModelCanRead.cxx + otbTrainMachineLearningUnsupervisedModel.cxx otbContingencyTableCalculatorTest.cxx ) - -add_executable(otbUnsupervisedTestDriver ${OTBUnsupervisedTests}) -target_link_libraries(otbUnsupervisedTestDriver ${OTBUnsupervised-Test_LIBRARIES}) -otb_module_target_label(otbUnsupervisedTestDriver) - # Tests Declaration otb_add_test(NAME leTuContingencyTableCalculatorNew COMMAND otbUnsupervisedTestDriver @@ -27,5 +22,10 @@ otb_add_test(NAME leTvContingencyTableCalculatorUpdateWithBaseline COMMAND otbUn if(OTB_USE_SHARK) + set(OTBUnsupervisedTests ${OTBUnsupervisedTests} otbSharkUnsupervisedImageClassificationFilter.cxx) include(tests-shark.cmake) endif() + +add_executable(otbUnsupervisedTestDriver ${OTBUnsupervisedTests}) +target_link_libraries(otbUnsupervisedTestDriver ${OTBUnsupervised-Test_LIBRARIES}) +otb_module_target_label(otbUnsupervisedTestDriver) diff --git a/Modules/Learning/Unsupervised/test/otbMachineLearningClusteringModelCanRead.cxx b/Modules/Learning/Unsupervised/test/otbMachineLearningUnsupervisedModelCanRead.cxx similarity index 64% rename from Modules/Learning/Unsupervised/test/otbMachineLearningClusteringModelCanRead.cxx rename to Modules/Learning/Unsupervised/test/otbMachineLearningUnsupervisedModelCanRead.cxx index 761c2716f5743ecaafd46e0cfddb0f189075e6f1..6856f0d0728327c26df9873af3afa016789c6d10 
100644 --- a/Modules/Learning/Unsupervised/test/otbMachineLearningClusteringModelCanRead.cxx +++ b/Modules/Learning/Unsupervised/test/otbMachineLearningUnsupervisedModelCanRead.cxx @@ -1,20 +1,22 @@ -/*========================================================================= - - Program: ORFEO Toolbox - Language: C++ - Date: $Date$ - Version: $Revision$ - - - Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. - See OTBCopyright.txt for details. - - - This software is distributed WITHOUT ANY WARRANTY; without even - the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR - PURPOSE. See the above copyright notices for more information. - -=========================================================================*/ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ #include <iostream> @@ -46,8 +48,8 @@ int otbSharkKMeansMachineLearningModelCanRead(int argc, char *argv[]) return EXIT_FAILURE; } std::string filename( argv[1] ); - typedef otb::SharkKMeansMachineLearningModel<InputValueType, TargetValueType> RFType; - RFType::Pointer classifier = RFType::New(); + typedef otb::SharkKMeansMachineLearningModel<InputValueType, TargetValueType> KMType; + KMType::Pointer classifier = KMType::New(); bool lCanRead = classifier->CanReadFile( filename ); if( !lCanRead ) { diff --git a/Modules/Learning/Unsupervised/test/otbSharkUnsupervisedImageClassificationFilter.cxx b/Modules/Learning/Unsupervised/test/otbSharkUnsupervisedImageClassificationFilter.cxx new file mode 100644 index 0000000000000000000000000000000000000000..7c101f7e2eb4ab7b6ec1d083c793b84523d4b25e --- /dev/null +++ b/Modules/Learning/Unsupervised/test/otbSharkUnsupervisedImageClassificationFilter.cxx @@ -0,0 +1,168 @@ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "otbVectorImage.h" +#include "otbImageFileReader.h" +#include "otbImageFileWriter.h" +#include "otbImageClassificationFilter.h" +#include "otbSharkKMeansMachineLearningModelFactory.h" + +#include <random> +#include <chrono> + + +const unsigned int Dimension = 2; +typedef float PixelType; +typedef unsigned short LabeledPixelType; + +typedef otb::VectorImage<PixelType, Dimension> ImageType; +typedef otb::Image<LabeledPixelType, Dimension> LabeledImageType; +typedef otb::ImageClassificationFilter<ImageType, LabeledImageType> ClassificationFilterType; +typedef ClassificationFilterType::ModelType ModelType; +typedef ClassificationFilterType::ValueType ValueType; +typedef ClassificationFilterType::LabelType LabelType; +typedef otb::SharkKMeansMachineLearningModelFactory<ValueType, LabelType> MachineLearningModelFactoryType; +typedef otb::ImageFileReader<ImageType> ReaderType; +typedef otb::ImageFileReader<LabeledImageType> MaskReaderType; +typedef otb::ImageFileWriter<LabeledImageType> WriterType; + +typedef otb::SharkKMeansMachineLearningModel<PixelType,short unsigned int> MachineLearningModelType; +typedef MachineLearningModelType::InputValueType LocalInputValueType; +typedef MachineLearningModelType::InputSampleType LocalInputSampleType; +typedef MachineLearningModelType::InputListSampleType LocalInputListSampleType; +typedef MachineLearningModelType::TargetValueType LocalTargetValueType; +typedef MachineLearningModelType::TargetSampleType LocalTargetSampleType; +typedef MachineLearningModelType::TargetListSampleType LocalTargetListSampleType; + +void generateSamples(unsigned int num_classes, unsigned int num_samples, + unsigned int num_features, + LocalInputListSampleType * samples, + LocalTargetListSampleType * labels) +{ + std::default_random_engine randomEngine; + std::uniform_int_distribution<int> label_distribution(1,num_classes); + std::uniform_int_distribution<int> feat_distribution(0,256); + for(size_t scount=0; scount<num_samples; ++scount) + 
{ + LabeledPixelType label = label_distribution(randomEngine); + LocalInputSampleType sample(num_features); + for(unsigned int i=0; i<num_features; ++i) + sample[i]= feat_distribution(randomEngine); + samples->SetMeasurementVectorSize(num_features); + samples->PushBack(sample); + labels->PushBack(label); + } +} + +void buildModel(unsigned int num_classes, unsigned int num_samples, + unsigned int num_features, std::string modelfname) +{ + LocalInputListSampleType::Pointer samples = LocalInputListSampleType::New(); + LocalTargetListSampleType::Pointer labels = LocalTargetListSampleType::New(); + + std::cout << "Sample generation\n"; + generateSamples(num_classes, num_samples, num_features, samples, labels); + + MachineLearningModelType::Pointer classifier = MachineLearningModelType::New(); + classifier->SetInputListSample(samples); + classifier->SetTargetListSample(labels); + classifier->SetRegressionMode(false); + classifier->SetK(3); + + std::cout << "Training\n"; + using TimeT = std::chrono::milliseconds; + auto start = std::chrono::system_clock::now(); + classifier->Train(); + auto duration = std::chrono::duration_cast< TimeT> + (std::chrono::system_clock::now() - start); + auto elapsed = duration.count(); + std::cout << "Training took " << elapsed << " ms\n"; + classifier->Save(modelfname); +} + +int otbSharkUnsupervisedImageClassificationFilter(int argc, char * argv[]) +{ + if(argc<5 || argc>7) + { + std::cout << "Usage: input_image output_image output_confidence batchmode [in_model_name] [mask_name]\n"; + } + std::string imfname = argv[1]; + std::string outfname = argv[2]; + std::string conffname = argv[3]; + bool batch = (std::string(argv[4])=="1"); + std::string modelfname = "/tmp/rf_model.txt"; + std::string maskfname{}; + + MaskReaderType::Pointer mask_reader = MaskReaderType::New(); + ReaderType::Pointer reader = ReaderType::New(); + reader->SetFileName(imfname); + reader->UpdateOutputInformation(); + + auto num_features = 
reader->GetOutput()->GetNumberOfComponentsPerPixel(); + + std::cout << "Image has " << num_features << " bands\n"; + + if(argc>5) + { + modelfname = argv[5]; + } + else + { + buildModel(3, 1000, num_features, modelfname); + } + + ClassificationFilterType::Pointer filter = ClassificationFilterType::New(); + + MachineLearningModelType::Pointer model = MachineLearningModelType::New(); + if(!model->CanReadFile(modelfname)) + { + std::cerr << "Unable to read the model : " << modelfname << "\n"; + return EXIT_FAILURE; + } + + filter->SetModel(model); + filter->SetInput(reader->GetOutput()); + if(argc==7) + { + maskfname = argv[6]; + mask_reader->SetFileName(maskfname); + filter->SetInputMask(mask_reader->GetOutput()); + } + + WriterType::Pointer writer = WriterType::New(); + writer->SetInput(filter->GetOutput()); + writer->SetFileName(outfname); + std::cout << "Classification\n"; + filter->SetBatchMode(batch); + filter->SetUseConfidenceMap(true); + using TimeT = std::chrono::milliseconds; + auto start = std::chrono::system_clock::now(); + writer->Update(); + auto duration = std::chrono::duration_cast< TimeT> + (std::chrono::system_clock::now() - start); + auto elapsed = duration.count(); + std::cout << "Classification took " << elapsed << " ms\n"; + + auto confWriter = otb::ImageFileWriter<ClassificationFilterType::ConfidenceImageType>::New(); + confWriter->SetInput(filter->GetOutputConfidence()); + confWriter->SetFileName(conffname); + confWriter->Update(); + + return EXIT_SUCCESS; +} diff --git a/Modules/Learning/Unsupervised/test/otbTrainMachineLearningClusteringModel.cxx b/Modules/Learning/Unsupervised/test/otbTrainMachineLearningUnsupervisedModel.cxx similarity index 86% rename from Modules/Learning/Unsupervised/test/otbTrainMachineLearningClusteringModel.cxx rename to Modules/Learning/Unsupervised/test/otbTrainMachineLearningUnsupervisedModel.cxx index 7eae88216b1d3da0c7defe2cb2f2fefb642fb4e4..a0805c29fb38defb049a9588a26e10167cb5920b 100644 --- 
a/Modules/Learning/Unsupervised/test/otbTrainMachineLearningClusteringModel.cxx +++ b/Modules/Learning/Unsupervised/test/otbTrainMachineLearningUnsupervisedModel.cxx @@ -1,3 +1,22 @@ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include <iostream> #include <otbConfigure.h> @@ -154,6 +173,11 @@ int otbSharkKMeansMachineLearningModelPredict(int argc, char *argv[]) KMeansType::Pointer classifier = KMeansType::New(); std::cout << "Load\n"; + if(!classifier->CanReadFile(argv[2])) + { + std::cerr << "Unable to read model file : " << argv[2] << std::endl; + return EXIT_FAILURE; + } classifier->Load( argv[2] ); auto start = std::chrono::system_clock::now(); classifier->SetInputListSample( samples ); diff --git a/Modules/Learning/Unsupervised/test/otbUnsupervisedTestDriver.cxx b/Modules/Learning/Unsupervised/test/otbUnsupervisedTestDriver.cxx index 30b0b20f1904a20aad0865ceae4e944d9c391dc8..876f613508404c96aaa4027978d1e8a7baabc11c 100644 --- a/Modules/Learning/Unsupervised/test/otbUnsupervisedTestDriver.cxx +++ b/Modules/Learning/Unsupervised/test/otbUnsupervisedTestDriver.cxx @@ -1,3 +1,22 @@ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include "otbTestMain.h" void RegisterTests() { @@ -12,5 +31,6 @@ void RegisterTests() REGISTER_TEST(otbSharkKMeansMachineLearningModelNew); REGISTER_TEST(otbSharkKMeansMachineLearningModelTrain); REGISTER_TEST(otbSharkKMeansMachineLearningModelPredict); + REGISTER_TEST(otbSharkUnsupervisedImageClassificationFilter); #endif } diff --git a/Modules/Learning/Unsupervised/test/tests-shark.cmake b/Modules/Learning/Unsupervised/test/tests-shark.cmake index 0635d94ec2a85587e6250958f98c39070b6a7442..df61cb46479f9922e3355441458a1e261879c40c 100644 --- a/Modules/Learning/Unsupervised/test/tests-shark.cmake +++ b/Modules/Learning/Unsupervised/test/tests-shark.cmake @@ -4,22 +4,55 @@ otb_add_test(NAME leTvSharkKMeansMachineLearningModelNew COMMAND otbUnsupervised otbSharkKMeansMachineLearningModelNew ) + otb_add_test(NAME leTvSharkKMeansMachineLearningModel COMMAND otbUnsupervisedTestDriver otbSharkKMeansMachineLearningModelTrain ${INPUTDATA}/letter.scale ${TEMP}/shark_km_model.txt ) -otb_add_test(NAME leTvSharkKMeansMachineLearningModelCanRead COMMAND otbUnsupervisedTestDriver +otb_add_test(NAME otbSharkKMeansMachineLearningModelPredict COMMAND otbUnsupervisedTestDriver otbSharkKMeansMachineLearningModelPredict ${INPUTDATA}/letter.scale + ${INPUTDATA}/Classification/shark_km_model.txt + ) + +otb_add_test(NAME leTvSharkKMeansMachineLearningModelCanRead COMMAND otbUnsupervisedTestDriver + otbSharkKMeansMachineLearningModelCanRead 
${INPUTDATA}/Classification/otbSharkImageClassificationFilter_KMeansmodel.txt ) otb_add_test(NAME leTvSharkKMeansMachineLearningModelCanReadFail COMMAND otbUnsupervisedTestDriver - otbSharkKMeansMachineLearningModelPredict - ${INPUTDATA}/letter.scale + otbSharkKMeansMachineLearningModelCanRead ${INPUTDATA}/Classification/otbSharkImageClassificationFilter_RFmodel.txt ) set_property(TEST leTvSharkKMeansMachineLearningModelCanReadFail PROPERTY WILL_FAIL true) + + + +otb_add_test(NAME leTvImageClassificationFilterSharkKMeans COMMAND otbUnsupervisedTestDriver + --compare-n-images ${NOTOL} 1 + ${BASELINE}/leSharkUnsupervisedImageClassificationFilterOutput.tif + ${TEMP}/leSharkUnsupervisedImageClassificationFilterOutput.tif + otbSharkUnsupervisedImageClassificationFilter + ${INPUTDATA}/Classification/QB_1_ortho.tif + ${TEMP}/leSharkUnsupervisedImageClassificationFilterOutput.tif + ${TEMP}/leSharkUnsupervisedImageClassificationFilterConfidence.tif + 1 + ${INPUTDATA}/Classification/otbSharkImageClassificationFilter_KMeansmodel.txt + ) + + +otb_add_test(NAME leTvImageClassificationFilterSharkKMeansMask COMMAND otbUnsupervisedTestDriver + --compare-n-images ${NOTOL} 1 + ${BASELINE}/leSharkUnsupervisedImageClassificationFilterWithMaskOutput.tif + ${TEMP}/leSharkUnsupervisedImageClassificationFilterWithMaskOutput.tif + otbSharkUnsupervisedImageClassificationFilter + ${INPUTDATA}/Classification/QB_1_ortho.tif + ${TEMP}/leSharkUnsupervisedImageClassificationFilterWithMaskOutput.tif + ${TEMP}/leSharkUnsupervisedImageClassificationFilterWithMaskConfidence.tif + 1 + ${INPUTDATA}/Classification/otbSharkImageClassificationFilter_KMeansmodel.txt + ${INPUTDATA}/Classification/QB_1_ortho_mask.tif + )