From 47da27464db136301b624098007686111ae182fd Mon Sep 17 00:00:00 2001 From: Ludovic Hussonnois <ludovic.hussonnois@c-s.fr> Date: Tue, 14 Feb 2017 11:23:29 +0100 Subject: [PATCH] TEST: Add Shark KMeans tests. --- .../AppClassification/test/CMakeLists.txt | 5 ++ .../test/otbMachineLearningModelCanRead.cxx | 31 +++++++ .../test/otbSupervisedTestDriver.cxx | 5 +- .../test/otbTrainMachineLearningModel.cxx | 81 +++++++++++++++++++ .../Supervised/test/tests-shark.cmake | 38 +++++++-- 5 files changed, 154 insertions(+), 6 deletions(-) diff --git a/Modules/Applications/AppClassification/test/CMakeLists.txt b/Modules/Applications/AppClassification/test/CMakeLists.txt index aaa2f9cfe1..8455a03eb3 100644 --- a/Modules/Applications/AppClassification/test/CMakeLists.txt +++ b/Modules/Applications/AppClassification/test/CMakeLists.txt @@ -76,6 +76,7 @@ set(bayes_output_format ".bayes") set(rf_output_format ".rf") set(knn_output_format ".knn") set(sharkrf_output_format ".txt") +set(sharkkm_output_format ".txt") # Training algorithms parameters set(libsvm_parameters "-classifier.libsvm.opt" "true" "-classifier.libsvm.prob" "true") @@ -88,6 +89,8 @@ set(bayes_parameters "") set(rf_parameters "") set(knn_parameters "") set(sharkrf_parameters "") +set(sharkkm_parameters "") + # Validation depending on mode set(ascii_comparison --compare-ascii ${EPSILON_6}) @@ -108,6 +111,7 @@ if(OTB_USE_OPENCV) endif() if(OTB_USE_SHARK) list(APPEND classifierList "SHARKRF") + list(APPEND classifierList "SHARKKM") endif() set(classifier_with_confmap "LIBSVM" "BOOST" "KNN" "ANN" "RF") @@ -115,6 +119,7 @@ set(classifier_with_confmap "LIBSVM" "BOOST" "KNN" "ANN" "RF") # This is a black list for classifier that can not have a baseline # because they are using randomness and seed can not be set set(classifier_without_baseline "SHARKRF") +set(classifier_without_baseline "SHARKKM") # Loop on classifiers foreach(classifier ${classifierList}) diff --git a/Modules/Learning/Supervised/test/otbMachineLearningModelCanRead.cxx b/Modules/Learning/Supervised/test/otbMachineLearningModelCanRead.cxx index b302542e81..9c03dcf2fb 100644 --- a/Modules/Learning/Supervised/test/otbMachineLearningModelCanRead.cxx +++ b/Modules/Learning/Supervised/test/otbMachineLearningModelCanRead.cxx @@ -319,4 +319,35 @@ int otbSharkRFMachineLearningModelCanRead(int argc, char* argv[]) return EXIT_SUCCESS; } +#include "otbSharkKMeansMachineLearningModel.h" + +int itbSharkKMeansMachineLearningModelCanRead(int argc, char *argv[]) +{ + if( argc != 2 ) + { + std::cerr << "Usage: " << argv[0] << "<model>" << std::endl; + std::cerr << "Called here with " << argc << " arguments\n"; + for( int i = 1; i < argc; ++i ) + { + std::cerr << " - " << argv[i] << "\n"; + } + return EXIT_FAILURE; + } + std::string filename( argv[1] ); + typedef otb::SharkKMeansMachineLearningModel<InputValueType, TargetValueType> RFType; + RFType::Pointer classifier = RFType::New(); + bool lCanRead = classifier->CanReadFile( filename ); + if( !lCanRead ) + { + std::cerr << "Error otb::SharkKMeansMachineLearningModel : impossible to open the file " << filename << "." + << std::endl; + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; +} + + + + #endif diff --git a/Modules/Learning/Supervised/test/otbSupervisedTestDriver.cxx b/Modules/Learning/Supervised/test/otbSupervisedTestDriver.cxx index 9bc337bf9d..6aba3c2fc3 100644 --- a/Modules/Learning/Supervised/test/otbSupervisedTestDriver.cxx +++ b/Modules/Learning/Supervised/test/otbSupervisedTestDriver.cxx @@ -62,8 +62,11 @@ void RegisterTests() REGISTER_TEST(otbSharkRFMachineLearningModel); REGISTER_TEST(otbSharkRFMachineLearningModelCanRead); REGISTER_TEST(otbSharkImageClassificationFilter); + REGISTER_TEST(otbSharkKMeansMachineLearningModelNew); + REGISTER_TEST(otbSharkKMeansMachineLearningModelTrain); + REGISTER_TEST(otbSharkKMeansMachineLearningModelPredict); #endif - + REGISTER_TEST(otbImageClassificationFilterNew); REGISTER_TEST(otbImageClassificationFilter); } diff --git a/Modules/Learning/Supervised/test/otbTrainMachineLearningModel.cxx b/Modules/Learning/Supervised/test/otbTrainMachineLearningModel.cxx index 518a301e4e..6221a7dcdc 100644 --- a/Modules/Learning/Supervised/test/otbTrainMachineLearningModel.cxx +++ b/Modules/Learning/Supervised/test/otbTrainMachineLearningModel.cxx @@ -1286,4 +1286,85 @@ int otbSharkRFMachineLearningModel(int argc, char * argv[]) return EXIT_SUCCESS; } + +#include "otbSharkKMeansMachineLearningModel.h" + +int otbSharkKMeansMachineLearningModelNew(int itkNotUsed( argc ), char *itkNotUsed( argv )[]) +{ + typedef otb::SharkKMeansMachineLearningModel<InputValueType, TargetValueType> SharkRFType; + SharkRFType::Pointer classifier = SharkRFType::New(); + return EXIT_SUCCESS; +} + +int otbSharkKMeansMachineLearningModelTrain(int argc, char *argv[]) +{ + if( argc != 3 ) + { + std::cout << "Wrong number of arguments " << std::endl; + std::cout << "Usage : sample file, output file " << std::endl; + return EXIT_FAILURE; + } + + typedef otb::SharkKMeansMachineLearningModel<InputValueType, TargetValueType> KMeansType; + InputListSampleType::Pointer samples = InputListSampleType::New(); + TargetListSampleType::Pointer labels = TargetListSampleType::New(); + + if( !SharkReadDataFile( argv[1], samples, labels ) ) + { + std::cout << "Failed to read samples file " << argv[1] << std::endl; + return EXIT_FAILURE; + } + + KMeansType::Pointer classifier = KMeansType::New(); + classifier->SetInputListSample( samples ); + classifier->SetTargetListSample( labels ); + classifier->SetRegressionMode( false ); + classifier->SetK( 3 ); + classifier->SetMaximumNumberOfIterations( 0 ); + std::cout << "Train\n"; + classifier->Train(); + std::cout << "Save\n"; + classifier->Save( argv[2] ); + + return EXIT_SUCCESS; +} + + +int otbSharkKMeansMachineLearningModelPredict(int argc, char *argv[]) +{ + if( argc != 3 ) + { + std::cout << "Wrong number of arguments " << std::endl; + std::cout << "Usage : sample file, input model file " << std::endl; + return EXIT_FAILURE; + } + + + typedef otb::SharkKMeansMachineLearningModel<InputValueType, TargetValueType> KMeansType; + InputListSampleType::Pointer samples = InputListSampleType::New(); + TargetListSampleType::Pointer labels = TargetListSampleType::New(); + + if( !SharkReadDataFile( argv[1], samples, labels ) ) + { + std::cout << "Failed to read samples file " << argv[1] << std::endl; + return EXIT_FAILURE; + } + + KMeansType::Pointer classifier = KMeansType::New(); + std::cout << "Load\n"; + classifier->Load( argv[2] ); + auto start = std::chrono::system_clock::now(); + classifier->SetInputListSample( samples ); + classifier->SetTargetListSample( labels ); + std::cout << "Predict loaded\n"; + classifier->PredictBatch( samples, NULL ); + using TimeT = std::chrono::milliseconds; + auto duration = std::chrono::duration_cast<TimeT>( std::chrono::system_clock::now() - start ); + auto elapsed = duration.count(); + std::cout << "PredictAll took " << elapsed << " ms\n"; + + return EXIT_SUCCESS; +} + + #endif diff --git a/Modules/Learning/Supervised/test/tests-shark.cmake b/Modules/Learning/Supervised/test/tests-shark.cmake index 265fe840ae..73706dda29 100644 --- a/Modules/Learning/Supervised/test/tests-shark.cmake +++ b/Modules/Learning/Supervised/test/tests-shark.cmake @@ -14,17 +14,17 @@ otb_add_test(NAME leTvSharkRFMachineLearningModelCanRead COMMAND otbSupervisedTe otb_add_test(NAME leTvSharkRFMachineLearningModelCanReadFail COMMAND otbSupervisedTestDriver otbSharkRFMachineLearningModelCanRead - ${INPUTDATA}/ROI_QB_MUL_4_svmModel.txt + ${INPUTDATA}/Classification/otbSharkImageClassificationFilter_KMeansmodel.txt ) set_property(TEST leTvSharkRFMachineLearningModelCanReadFail PROPERTY WILL_FAIL true) otb_add_test(NAME leTvImageClassificationFilterSharkFast COMMAND otbSupervisedTestDriver - --compare-n-images ${NOTOL} 2 + --compare-n-images ${NOTOL} 2 ${BASELINE}/leSharkImageClassificationFilterOutput.tif ${TEMP}/leSharkImageClassificationFilterOutput.tif ${BASELINE}/leSharkImageClassificationFilterConfidence.tif - ${TEMP}/leSharkImageClassificationFilterConfidence.tif + ${TEMP}/leSharkImageClassificationFilterConfidence.tif otbSharkImageClassificationFilter ${INPUTDATA}/Classification/QB_1_ortho.tif ${TEMP}/leSharkImageClassificationFilterOutput.tif @@ -46,11 +46,11 @@ otb_add_test(NAME leTvImageClassificationFilterSharkFast COMMAND otbSupervisedT # ) otb_add_test(NAME leTvImageClassificationFilterSharkFastMask COMMAND otbSupervisedTestDriver - --compare-n-images ${NOTOL} 2 + --compare-n-images ${NOTOL} 2 ${BASELINE}/leSharkImageClassificationFilterWithMaskOutput.tif ${TEMP}/leSharkImageClassificationFilterWithMaskOutput.tif ${BASELINE}/leSharkImageClassificationFilterWithMaskConfidence.tif - ${TEMP}/leSharkImageClassificationFilterWithMaskConfidence.tif + ${TEMP}/leSharkImageClassificationFilterWithMaskConfidence.tif otbSharkImageClassificationFilter ${INPUTDATA}/Classification/QB_1_ortho.tif ${TEMP}/leSharkImageClassificationFilterWithMaskOutput.tif @@ -59,3 +59,31 @@ otb_add_test(NAME leTvImageClassificationFilterSharkFastMask COMMAND otbSupervi ${INPUTDATA}/Classification/otbSharkImageClassificationFilter_RFmodel.txt ${INPUTDATA}/Classification/QB_1_ortho_mask.tif ) + + + +# kMeans Shark related tests + +otb_add_test(NAME leTvSharkKMeansMachineLearningModelNew COMMAND otbSupervisedTestDriver + otbSharkKMeansMachineLearningModelNew + ) + +otb_add_test(NAME leTvSharkKMeansMachineLearningModel COMMAND otbSupervisedTestDriver + otbSharkKMeansMachineLearningModelTrain + ${INPUTDATA}/letter.scale + ${TEMP}/shark_km_model.txt + ) + +otb_add_test(NAME leTvSharkKMeansMachineLearningModelCanRead COMMAND otbSupervisedTestDriver + otbSharkKMeansMachineLearningModelPredict + ${INPUTDATA}/letter.scale + ${INPUTDATA}/Classification/otbSharkImageClassificationFilter_KMeansmodel.txt + ) + +otb_add_test(NAME leTvSharkKMeansMachineLearningModelCanReadFail COMMAND otbSupervisedTestDriver + otbSharkKMeansMachineLearningModelPredict + ${INPUTDATA}/letter.scale + ${INPUTDATA}/Classification/otbSharkImageClassificationFilter_RFmodel.txt + ) + +set_property(TEST leTvSharkKMeansMachineLearningModelCanReadFail PROPERTY WILL_FAIL true) -- GitLab