From 47da27464db136301b624098007686111ae182fd Mon Sep 17 00:00:00 2001
From: Ludovic Hussonnois <ludovic.hussonnois@c-s.fr>
Date: Tue, 14 Feb 2017 11:23:29 +0100
Subject: [PATCH] TEST: Add Shark KMeans tests.

---
 .../AppClassification/test/CMakeLists.txt     |  5 ++
 .../test/otbMachineLearningModelCanRead.cxx   | 31 +++++++
 .../test/otbSupervisedTestDriver.cxx          |  5 +-
 .../test/otbTrainMachineLearningModel.cxx     | 81 +++++++++++++++++++
 .../Supervised/test/tests-shark.cmake         | 38 +++++++--
 5 files changed, 154 insertions(+), 6 deletions(-)

diff --git a/Modules/Applications/AppClassification/test/CMakeLists.txt b/Modules/Applications/AppClassification/test/CMakeLists.txt
index aaa2f9cfe1..8455a03eb3 100644
--- a/Modules/Applications/AppClassification/test/CMakeLists.txt
+++ b/Modules/Applications/AppClassification/test/CMakeLists.txt
@@ -76,6 +76,7 @@ set(bayes_output_format ".bayes")
 set(rf_output_format ".rf")
 set(knn_output_format ".knn")
 set(sharkrf_output_format ".txt")
+set(sharkkm_output_format ".txt")
 
 # Training algorithms parameters
 set(libsvm_parameters "-classifier.libsvm.opt" "true" "-classifier.libsvm.prob" "true")
@@ -88,6 +89,8 @@ set(bayes_parameters "")
 set(rf_parameters "")
 set(knn_parameters "")
 set(sharkrf_parameters "")
+set(sharkkm_parameters "")
+
 
 # Validation depending on mode
 set(ascii_comparison --compare-ascii ${EPSILON_6})
@@ -108,6 +111,7 @@ if(OTB_USE_OPENCV)
 endif()
 if(OTB_USE_SHARK)
   list(APPEND classifierList "SHARKRF")
+  list(APPEND classifierList "SHARKKM")
 endif()
 
 set(classifier_with_confmap "LIBSVM" "BOOST" "KNN" "ANN" "RF")
@@ -115,6 +119,7 @@ set(classifier_with_confmap "LIBSVM" "BOOST" "KNN" "ANN" "RF")
 # This is a black list for classifier that can not have a baseline
 # because they are using randomness and seed can not be set
 set(classifier_without_baseline "SHARKRF")
+set(classifier_without_baseline "SHARKKM")
 
 # Loop on classifiers
 foreach(classifier ${classifierList})
diff --git a/Modules/Learning/Supervised/test/otbMachineLearningModelCanRead.cxx b/Modules/Learning/Supervised/test/otbMachineLearningModelCanRead.cxx
index b302542e81..9c03dcf2fb 100644
--- a/Modules/Learning/Supervised/test/otbMachineLearningModelCanRead.cxx
+++ b/Modules/Learning/Supervised/test/otbMachineLearningModelCanRead.cxx
@@ -319,4 +319,35 @@ int otbSharkRFMachineLearningModelCanRead(int argc, char* argv[])
   return EXIT_SUCCESS;
 }
 
+#include "otbSharkKMeansMachineLearningModel.h"
+
+int itbSharkKMeansMachineLearningModelCanRead(int argc, char *argv[])
+{
+  if( argc != 2 )
+    {
+    std::cerr << "Usage: " << argv[0] << "<model>" << std::endl;
+    std::cerr << "Called here with " << argc << " arguments\n";
+    for( int i = 1; i < argc; ++i )
+      {
+      std::cerr << " - " << argv[i] << "\n";
+      }
+    return EXIT_FAILURE;
+    }
+  std::string filename( argv[1] );
+  typedef otb::SharkKMeansMachineLearningModel<InputValueType, TargetValueType> RFType;
+  RFType::Pointer classifier = RFType::New();
+  bool lCanRead = classifier->CanReadFile( filename );
+  if( !lCanRead )
+    {
+    std::cerr << "Error otb::SharkKMeansMachineLearningModel : impossible to open the file " << filename << "."
+              << std::endl;
+    return EXIT_FAILURE;
+    }
+
+  return EXIT_SUCCESS;
+}
+
+
+
+
 #endif
diff --git a/Modules/Learning/Supervised/test/otbSupervisedTestDriver.cxx b/Modules/Learning/Supervised/test/otbSupervisedTestDriver.cxx
index 9bc337bf9d..6aba3c2fc3 100644
--- a/Modules/Learning/Supervised/test/otbSupervisedTestDriver.cxx
+++ b/Modules/Learning/Supervised/test/otbSupervisedTestDriver.cxx
@@ -62,8 +62,11 @@ void RegisterTests()
   REGISTER_TEST(otbSharkRFMachineLearningModel);
   REGISTER_TEST(otbSharkRFMachineLearningModelCanRead);
   REGISTER_TEST(otbSharkImageClassificationFilter);
+  REGISTER_TEST(otbSharkKMeansMachineLearningModelNew);
+  REGISTER_TEST(otbSharkKMeansMachineLearningModelTrain);
+  REGISTER_TEST(otbSharkKMeansMachineLearningModelPredict);
 #endif
-  
+
   REGISTER_TEST(otbImageClassificationFilterNew);
   REGISTER_TEST(otbImageClassificationFilter);
 }
diff --git a/Modules/Learning/Supervised/test/otbTrainMachineLearningModel.cxx b/Modules/Learning/Supervised/test/otbTrainMachineLearningModel.cxx
index 518a301e4e..6221a7dcdc 100644
--- a/Modules/Learning/Supervised/test/otbTrainMachineLearningModel.cxx
+++ b/Modules/Learning/Supervised/test/otbTrainMachineLearningModel.cxx
@@ -1286,4 +1286,85 @@ int otbSharkRFMachineLearningModel(int argc, char * argv[])
    return EXIT_SUCCESS;
 }
 
+
+#include "otbSharkKMeansMachineLearningModel.h"
+
+int otbSharkKMeansMachineLearningModelNew(int itkNotUsed( argc ), char *itkNotUsed( argv )[])
+{
+  typedef otb::SharkKMeansMachineLearningModel<InputValueType, TargetValueType> SharkRFType;
+  SharkRFType::Pointer classifier = SharkRFType::New();
+  return EXIT_SUCCESS;
+}
+
+int otbSharkKMeansMachineLearningModelTrain(int argc, char *argv[])
+{
+  if( argc != 3 )
+    {
+    std::cout << "Wrong number of arguments " << std::endl;
+    std::cout << "Usage : sample file, output file " << std::endl;
+    return EXIT_FAILURE;
+    }
+
+  typedef otb::SharkKMeansMachineLearningModel<InputValueType, TargetValueType> KMeansType;
+  InputListSampleType::Pointer samples = InputListSampleType::New();
+  TargetListSampleType::Pointer labels = TargetListSampleType::New();
+
+  if( !SharkReadDataFile( argv[1], samples, labels ) )
+    {
+    std::cout << "Failed to read samples file " << argv[1] << std::endl;
+    return EXIT_FAILURE;
+    }
+
+  KMeansType::Pointer classifier = KMeansType::New();
+  classifier->SetInputListSample( samples );
+  classifier->SetTargetListSample( labels );
+  classifier->SetRegressionMode( false );
+  classifier->SetK( 3 );
+  classifier->SetMaximumNumberOfIterations( 0 );
+  std::cout << "Train\n";
+  classifier->Train();
+  std::cout << "Save\n";
+  classifier->Save( argv[2] );
+
+  return EXIT_SUCCESS;
+}
+
+
+int otbSharkKMeansMachineLearningModelPredict(int argc, char *argv[])
+{
+  if( argc != 3 )
+    {
+    std::cout << "Wrong number of arguments " << std::endl;
+    std::cout << "Usage : sample file, input model file " << std::endl;
+    return EXIT_FAILURE;
+    }
+
+
+  typedef otb::SharkKMeansMachineLearningModel<InputValueType, TargetValueType> KMeansType;
+  InputListSampleType::Pointer samples = InputListSampleType::New();
+  TargetListSampleType::Pointer labels = TargetListSampleType::New();
+
+  if( !SharkReadDataFile( argv[1], samples, labels ) )
+    {
+    std::cout << "Failed to read samples file " << argv[1] << std::endl;
+    return EXIT_FAILURE;
+    }
+
+  KMeansType::Pointer classifier = KMeansType::New();
+  std::cout << "Load\n";
+  classifier->Load( argv[2] );
+  auto start = std::chrono::system_clock::now();
+  classifier->SetInputListSample( samples );
+  classifier->SetTargetListSample( labels );
+  std::cout << "Predict loaded\n";
+  classifier->PredictBatch( samples, NULL );
+  using TimeT = std::chrono::milliseconds;
+  auto duration = std::chrono::duration_cast<TimeT>( std::chrono::system_clock::now() - start );
+  auto elapsed = duration.count();
+  std::cout << "PredictAll took " << elapsed << " ms\n";
+
+  return EXIT_SUCCESS;
+}
+
+
 #endif
diff --git a/Modules/Learning/Supervised/test/tests-shark.cmake b/Modules/Learning/Supervised/test/tests-shark.cmake
index 265fe840ae..73706dda29 100644
--- a/Modules/Learning/Supervised/test/tests-shark.cmake
+++ b/Modules/Learning/Supervised/test/tests-shark.cmake
@@ -14,17 +14,17 @@ otb_add_test(NAME leTvSharkRFMachineLearningModelCanRead COMMAND otbSupervisedTe
 
 otb_add_test(NAME leTvSharkRFMachineLearningModelCanReadFail COMMAND otbSupervisedTestDriver
   otbSharkRFMachineLearningModelCanRead
-  ${INPUTDATA}/ROI_QB_MUL_4_svmModel.txt
+  ${INPUTDATA}/Classification/otbSharkImageClassificationFilter_KMeansmodel.txt
   )
 set_property(TEST leTvSharkRFMachineLearningModelCanReadFail PROPERTY WILL_FAIL true)
 
 
 otb_add_test(NAME leTvImageClassificationFilterSharkFast COMMAND  otbSupervisedTestDriver
-  --compare-n-images ${NOTOL} 2 
+  --compare-n-images ${NOTOL} 2
   ${BASELINE}/leSharkImageClassificationFilterOutput.tif
   ${TEMP}/leSharkImageClassificationFilterOutput.tif
   ${BASELINE}/leSharkImageClassificationFilterConfidence.tif
-  ${TEMP}/leSharkImageClassificationFilterConfidence.tif   
+  ${TEMP}/leSharkImageClassificationFilterConfidence.tif
   otbSharkImageClassificationFilter
   ${INPUTDATA}/Classification/QB_1_ortho.tif
   ${TEMP}/leSharkImageClassificationFilterOutput.tif
@@ -46,11 +46,11 @@ otb_add_test(NAME leTvImageClassificationFilterSharkFast COMMAND  otbSupervisedT
 #   )
 
 otb_add_test(NAME leTvImageClassificationFilterSharkFastMask COMMAND  otbSupervisedTestDriver
-  --compare-n-images ${NOTOL} 2 
+  --compare-n-images ${NOTOL} 2
   ${BASELINE}/leSharkImageClassificationFilterWithMaskOutput.tif
   ${TEMP}/leSharkImageClassificationFilterWithMaskOutput.tif
   ${BASELINE}/leSharkImageClassificationFilterWithMaskConfidence.tif
-  ${TEMP}/leSharkImageClassificationFilterWithMaskConfidence.tif   
+  ${TEMP}/leSharkImageClassificationFilterWithMaskConfidence.tif
   otbSharkImageClassificationFilter
   ${INPUTDATA}/Classification/QB_1_ortho.tif
   ${TEMP}/leSharkImageClassificationFilterWithMaskOutput.tif
@@ -59,3 +59,31 @@ otb_add_test(NAME leTvImageClassificationFilterSharkFastMask COMMAND  otbSupervi
   ${INPUTDATA}/Classification/otbSharkImageClassificationFilter_RFmodel.txt
   ${INPUTDATA}/Classification/QB_1_ortho_mask.tif
   )
+
+
+
+# kMeans Shark related tests
+
+otb_add_test(NAME leTvSharkKMeansMachineLearningModelNew COMMAND otbSupervisedTestDriver
+  otbSharkKMeansMachineLearningModelNew
+  )
+
+otb_add_test(NAME leTvSharkKMeansMachineLearningModel COMMAND otbSupervisedTestDriver
+  otbSharkKMeansMachineLearningModelTrain
+  ${INPUTDATA}/letter.scale
+  ${TEMP}/shark_km_model.txt
+  )
+
+otb_add_test(NAME leTvSharkKMeansMachineLearningModelCanRead COMMAND otbSupervisedTestDriver
+  otbSharkKMeansMachineLearningModelPredict
+  ${INPUTDATA}/letter.scale
+  ${INPUTDATA}/Classification/otbSharkImageClassificationFilter_KMeansmodel.txt
+  )
+
+otb_add_test(NAME leTvSharkKMeansMachineLearningModelCanReadFail COMMAND otbSupervisedTestDriver
+  otbSharkKMeansMachineLearningModelPredict
+  ${INPUTDATA}/letter.scale
+  ${INPUTDATA}/Classification/otbSharkImageClassificationFilter_RFmodel.txt
+  )
+
+set_property(TEST leTvSharkKMeansMachineLearningModelCanReadFail PROPERTY WILL_FAIL true)
-- 
GitLab