diff --git a/Modules/Applications/AppClassification/otb-module.cmake b/Modules/Applications/AppClassification/otb-module.cmake index 9c31eeede55e6be11364b288244d7d6a1128ada7..322ba7f9c47e9600a73eafee0ade934b66fa9608 100644 --- a/Modules/Applications/AppClassification/otb-module.cmake +++ b/Modules/Applications/AppClassification/otb-module.cmake @@ -12,6 +12,7 @@ otb_module(OTBAppClassification OTBVectorDataIO OTBSOM OTBSupervised + OTBUnsupervised OTBApplicationEngine OTBIndices OTBMathParser diff --git a/Modules/Learning/Supervised/include/otbMachineLearningModel.h b/Modules/Learning/LearningBase/include/otbMachineLearningModel.h similarity index 100% rename from Modules/Learning/Supervised/include/otbMachineLearningModel.h rename to Modules/Learning/LearningBase/include/otbMachineLearningModel.h diff --git a/Modules/Learning/Supervised/include/otbMachineLearningModel.txx b/Modules/Learning/LearningBase/include/otbMachineLearningModel.txx similarity index 100% rename from Modules/Learning/Supervised/include/otbMachineLearningModel.txx rename to Modules/Learning/LearningBase/include/otbMachineLearningModel.txx diff --git a/Modules/Learning/Supervised/include/otbMachineLearningModelFactoryBase.h b/Modules/Learning/LearningBase/include/otbMachineLearningModelFactoryBase.h similarity index 100% rename from Modules/Learning/Supervised/include/otbMachineLearningModelFactoryBase.h rename to Modules/Learning/LearningBase/include/otbMachineLearningModelFactoryBase.h diff --git a/Modules/Learning/Supervised/include/otbMachineLearningModelFactory.txx b/Modules/Learning/Supervised/include/otbMachineLearningModelFactory.txx index 5e72ce37dbca81c67ad1d33ae72baff2a86436d9..a99aa0f78e4d86f128b855299bfd4eacb340948b 100644 --- a/Modules/Learning/Supervised/include/otbMachineLearningModelFactory.txx +++ b/Modules/Learning/Supervised/include/otbMachineLearningModelFactory.txx @@ -37,7 +37,6 @@ #ifdef OTB_USE_SHARK #include "otbSharkRandomForestsMachineLearningModelFactory.h" -#include "otbSharkKMeansMachineLearningModelFactory.h" #endif #include "itkMutexLockHolder.h" @@ -105,7 +104,6 @@ MachineLearningModelFactory<TInputValue,TOutputValue> #ifdef OTB_USE_SHARK RegisterFactory(SharkRandomForestsMachineLearningModelFactory<TInputValue,TOutputValue>::New()); - RegisterFactory(SharkKMeansMachineLearningModelFactory<TInputValue,TOutputValue>::New()); #endif #ifdef OTB_USE_OPENCV @@ -162,14 +160,6 @@ MachineLearningModelFactory<TInputValue,TOutputValue> itk::ObjectFactoryBase::UnRegisterFactory(sharkRFFactory); continue; } - - SharkKMeansMachineLearningModelFactory<TInputValue,TOutputValue> *sharkKMeansFactory = - dynamic_cast<SharkKMeansMachineLearningModelFactory<TInputValue,TOutputValue> *>(*itFac); - if (sharkKMeansFactory) - { - itk::ObjectFactoryBase::UnRegisterFactory(sharkKMeansFactory); - continue; - } #endif #ifdef OTB_USE_OPENCV diff --git a/Modules/Learning/Supervised/otb-module.cmake b/Modules/Learning/Supervised/otb-module.cmake index ebce0f334e6156c859dbb8d715c6287896e3b7ca..b46c75f574f8b9bb68eabb14e3a704af1e5b8df5 100644 --- a/Modules/Learning/Supervised/otb-module.cmake +++ b/Modules/Learning/Supervised/otb-module.cmake @@ -9,6 +9,7 @@ ENABLE_SHARED OTBCommon OTBITK OTBImageBase + OTBLearningBase OPTIONAL_DEPENDS OTBOpenCV @@ -19,6 +20,7 @@ ENABLE_SHARED OTBTestKernel OTBImageIO OTBImageBase + OTBLearningBase OTBBoost DESCRIPTION diff --git a/Modules/Learning/Supervised/test/otbMachineLearningModelCanRead.cxx b/Modules/Learning/Supervised/test/otbMachineLearningModelCanRead.cxx index 9c03dcf2fb2b2754ab940a00c54264a167228ae8..e2358fcd1582ab77868e17cc27980928205d9649 100644 --- a/Modules/Learning/Supervised/test/otbMachineLearningModelCanRead.cxx +++ b/Modules/Learning/Supervised/test/otbMachineLearningModelCanRead.cxx @@ -319,35 +319,5 @@ int otbSharkRFMachineLearningModelCanRead(int argc, char* argv[]) return EXIT_SUCCESS; } -#include "otbSharkKMeansMachineLearningModel.h" - -int itbSharkKMeansMachineLearningModelCanRead(int argc, char *argv[]) -{ - if( argc != 2 ) - { - std::cerr << "Usage: " << argv[0] << "<model>" << std::endl; - std::cerr << "Called here with " << argc << " arguments\n"; - for( int i = 1; i < argc; ++i ) - { - std::cerr << " - " << argv[i] << "\n"; - } - return EXIT_FAILURE; - } - std::string filename( argv[1] ); - typedef otb::SharkKMeansMachineLearningModel<InputValueType, TargetValueType> RFType; - RFType::Pointer classifier = RFType::New(); - bool lCanRead = classifier->CanReadFile( filename ); - if( !lCanRead ) - { - std::cerr << "Error otb::SharkKMeansMachineLearningModel : impossible to open the file " << filename << "." - << std::endl; - return EXIT_FAILURE; - } - - return EXIT_SUCCESS; -} - - - #endif diff --git a/Modules/Learning/Supervised/test/otbSupervisedTestDriver.cxx b/Modules/Learning/Supervised/test/otbSupervisedTestDriver.cxx index 6aba3c2fc354978a6174c70b255ab96373cd6db8..825341047de890a018c04f90f224d207f5b5d05d 100644 --- a/Modules/Learning/Supervised/test/otbSupervisedTestDriver.cxx +++ b/Modules/Learning/Supervised/test/otbSupervisedTestDriver.cxx @@ -62,9 +62,6 @@ void RegisterTests() REGISTER_TEST(otbSharkRFMachineLearningModel); REGISTER_TEST(otbSharkRFMachineLearningModelCanRead); REGISTER_TEST(otbSharkImageClassificationFilter); - REGISTER_TEST(otbSharkKMeansMachineLearningModelNew); - REGISTER_TEST(otbSharkKMeansMachineLearningModelTrain); - REGISTER_TEST(otbSharkKMeansMachineLearningModelPredict); #endif REGISTER_TEST(otbImageClassificationFilterNew); diff --git a/Modules/Learning/Supervised/test/otbTrainMachineLearningModel.cxx b/Modules/Learning/Supervised/test/otbTrainMachineLearningModel.cxx index 6221a7dcdc0674c1263c2ae3a1f0276d47116d08..fc1597c3f4128e8b620c9060270a93fc62b444bb 100644 --- a/Modules/Learning/Supervised/test/otbTrainMachineLearningModel.cxx +++ b/Modules/Learning/Supervised/test/otbTrainMachineLearningModel.cxx @@ -1287,84 +1287,4 @@ int otbSharkRFMachineLearningModel(int argc, char * argv[]) } -#include "otbSharkKMeansMachineLearningModel.h" - -int otbSharkKMeansMachineLearningModelNew(int itkNotUsed( argc ), char *itkNotUsed( argv )[]) -{ - typedef otb::SharkKMeansMachineLearningModel<InputValueType, TargetValueType> SharkRFType; - SharkRFType::Pointer classifier = SharkRFType::New(); - return EXIT_SUCCESS; -} - -int otbSharkKMeansMachineLearningModelTrain(int argc, char *argv[]) -{ - if( argc != 3 ) - { - std::cout << "Wrong number of arguments " << std::endl; - std::cout << "Usage : sample file, output file " << std::endl; - return EXIT_FAILURE; - } - - typedef otb::SharkKMeansMachineLearningModel<InputValueType, TargetValueType> KMeansType; - InputListSampleType::Pointer samples = InputListSampleType::New(); - TargetListSampleType::Pointer labels = TargetListSampleType::New(); - - if( !SharkReadDataFile( argv[1], samples, labels ) ) - { - std::cout << "Failed to read samples file " << argv[1] << std::endl; - return EXIT_FAILURE; - } - - KMeansType::Pointer classifier = KMeansType::New(); - classifier->SetInputListSample( samples ); - classifier->SetTargetListSample( labels ); - classifier->SetRegressionMode( false ); - classifier->SetK( 3 ); - classifier->SetMaximumNumberOfIterations( 0 ); - std::cout << "Train\n"; - classifier->Train(); - std::cout << "Save\n"; - classifier->Save( argv[2] ); - - return EXIT_SUCCESS; -} - - -int otbSharkKMeansMachineLearningModelPredict(int argc, char *argv[]) -{ - if( argc != 3 ) - { - std::cout << "Wrong number of arguments " << std::endl; - std::cout << "Usage : sample file, input model file " << std::endl; - return EXIT_FAILURE; - } - - - typedef otb::SharkKMeansMachineLearningModel<InputValueType, TargetValueType> KMeansType; - InputListSampleType::Pointer samples = InputListSampleType::New(); - TargetListSampleType::Pointer labels = TargetListSampleType::New(); - - if( !SharkReadDataFile( argv[1], samples, labels ) ) - { - std::cout << "Failed to read samples file " << argv[1] << std::endl; - return EXIT_FAILURE; - } - - KMeansType::Pointer classifier = KMeansType::New(); - std::cout << "Load\n"; - classifier->Load( argv[2] ); - auto start = std::chrono::system_clock::now(); - classifier->SetInputListSample( samples ); - classifier->SetTargetListSample( labels ); - std::cout << "Predict loaded\n"; - classifier->PredictBatch( samples, NULL ); - using TimeT = std::chrono::milliseconds; - auto duration = std::chrono::duration_cast<TimeT>( std::chrono::system_clock::now() - start ); - auto elapsed = duration.count(); - std::cout << "PredictAll took " << elapsed << " ms\n"; - - return EXIT_SUCCESS; -} - - #endif diff --git a/Modules/Learning/Supervised/test/tests-shark.cmake b/Modules/Learning/Supervised/test/tests-shark.cmake index 73706dda293b073e91680b1cbbdcd154895949eb..49ac03632b94d7d15d9f7dcadbd2d92158d47947 100644 --- a/Modules/Learning/Supervised/test/tests-shark.cmake +++ b/Modules/Learning/Supervised/test/tests-shark.cmake @@ -60,30 +60,3 @@ otb_add_test(NAME leTvImageClassificationFilterSharkFastMask COMMAND otbSupervi ${INPUTDATA}/Classification/QB_1_ortho_mask.tif ) - - -# kMeans Shark related tests - -otb_add_test(NAME leTvSharkKMeansMachineLearningModelNew COMMAND otbSupervisedTestDriver - otbSharkKMeansMachineLearningModelNew - ) - -otb_add_test(NAME leTvSharkKMeansMachineLearningModel COMMAND otbSupervisedTestDriver - otbSharkKMeansMachineLearningModelTrain - ${INPUTDATA}/letter.scale - ${TEMP}/shark_km_model.txt - ) - -otb_add_test(NAME leTvSharkKMeansMachineLearningModelCanRead COMMAND otbSupervisedTestDriver - otbSharkKMeansMachineLearningModelPredict - ${INPUTDATA}/letter.scale - ${INPUTDATA}/Classification/otbSharkImageClassificationFilter_KMeansmodel.txt - ) - -otb_add_test(NAME leTvSharkKMeansMachineLearningModelCanReadFail COMMAND otbSupervisedTestDriver - otbSharkKMeansMachineLearningModelPredict - ${INPUTDATA}/letter.scale - ${INPUTDATA}/Classification/otbSharkImageClassificationFilter_RFmodel.txt - ) - -set_property(TEST leTvSharkKMeansMachineLearningModelCanReadFail PROPERTY WILL_FAIL true) diff --git a/Modules/Learning/Unsupervised/CMakeLists.txt b/Modules/Learning/Unsupervised/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..e83c272cd617b09367dce57c8ed038757904d14c --- /dev/null +++ b/Modules/Learning/Unsupervised/CMakeLists.txt @@ -0,0 +1,4 @@ +project(OTBUnsupervised) + + +otb_module_impl() diff --git a/Modules/Learning/Unsupervised/include/otbMachineLearningClusteringModelFactory.h b/Modules/Learning/Unsupervised/include/otbMachineLearningClusteringModelFactory.h new file mode 100644 index 0000000000000000000000000000000000000000..62b0c6add6797f1be9854f7a8157ff82f1d50ce8 --- /dev/null +++ b/Modules/Learning/Unsupervised/include/otbMachineLearningClusteringModelFactory.h @@ -0,0 +1,81 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef otbMachineLearningClusteringModelFactory_h +#define otbMachineLearningClusteringModelFactory_h + +#include "otbMachineLearningModel.h" +#include "otbMachineLearningModelFactoryBase.h" + +namespace otb +{ +/** \class MachineLearningModelFactory + * \brief Creation of object instance using object factory. + * + * \ingroup OTBUnsupervised + */ +template <class TInputValue, class TOutputValue> +class MachineLearningModelFactory : public MachineLearningModelFactoryBase +{ +public: + /** Standard class typedefs. */ + typedef MachineLearningModelFactory Self; + typedef itk::Object Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + + /** Class Methods used to interface with the registered factories */ + + /** Run-time type information (and related methods). */ + itkTypeMacro(MachineLearningModelFactory, itk::Object); + + /** Convenient typedefs. */ + typedef otb::MachineLearningModel<TInputValue,TOutputValue> MachineLearningModelType; + typedef typename MachineLearningModelType::Pointer MachineLearningModelTypePointer; + + /** Mode in which the files is intended to be used */ + typedef enum { ReadMode, WriteMode } FileModeType; + + /** Create the appropriate MachineLearningModel depending on the particulars of the file. */ + static MachineLearningModelTypePointer CreateMachineLearningModel(const std::string& path, FileModeType mode); + + static void CleanFactories(); + +protected: + MachineLearningModelFactory(); + ~MachineLearningModelFactory() ITK_OVERRIDE; + +private: + MachineLearningModelFactory(const Self &); //purposely not implemented + void operator =(const Self&); //purposely not implemented + + /** Register Built-in factories */ + static void RegisterBuiltInFactories(); + + /** Register a single factory, ensuring it has not been registered + * twice */ + static void RegisterFactory(itk::ObjectFactoryBase * factory); + +}; + +} // end namespace otb + +#ifndef OTB_MANUAL_INSTANTIATION +#include "otbMachineLearningClusteringModelFactory.txx" +#endif + +#endif //otbMachineLearningClusteringModelFactory_h diff --git a/Modules/Learning/Unsupervised/include/otbMachineLearningClusteringModelFactory.txx b/Modules/Learning/Unsupervised/include/otbMachineLearningClusteringModelFactory.txx new file mode 100644 index 0000000000000000000000000000000000000000..9e06486f78376dda92a1cf64afa9992fb3c7e6e5 --- /dev/null +++ b/Modules/Learning/Unsupervised/include/otbMachineLearningClusteringModelFactory.txx @@ -0,0 +1,134 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef otbMachineLearningModelFactory_txx +#define otbMachineLearningModelFactory_txx + +#include "otbMachineLearningClusteringModelFactory.h" +#include "otbConfigure.h" + +#ifdef OTB_USE_SHARK +#include "otbSharkKMeansMachineLearningModelFactory.h" +#endif + +#include "itkMutexLockHolder.h" + + +namespace otb +{ +template <class TInputValue, class TOutputValue> +typename MachineLearningModel<TInputValue,TOutputValue>::Pointer +MachineLearningModelFactory<TInputValue,TOutputValue> +::CreateMachineLearningModel(const std::string& path, FileModeType mode) +{ + RegisterBuiltInFactories(); + + std::list<MachineLearningModelTypePointer> possibleMachineLearningModel; + std::list<LightObject::Pointer> allobjects = + itk::ObjectFactoryBase::CreateAllInstance("otbMachineLearningModel"); + for(std::list<LightObject::Pointer>::iterator i = allobjects.begin(); + i != allobjects.end(); ++i) + { + MachineLearningModel<TInputValue,TOutputValue> * io = dynamic_cast<MachineLearningModel<TInputValue,TOutputValue>*>(i->GetPointer()); + if(io) + { + possibleMachineLearningModel.push_back(io); + } + else + { + std::cerr << "Error MachineLearningModel Factory did not return an MachineLearningModel: " + << (*i)->GetNameOfClass() + << std::endl; + } + } + for(typename std::list<MachineLearningModelTypePointer>::iterator k = possibleMachineLearningModel.begin(); + k != possibleMachineLearningModel.end(); ++k) + { + if( mode == ReadMode ) + { + if((*k)->CanReadFile(path)) + { + return *k; + } + } + else if( mode == WriteMode ) + { + if((*k)->CanWriteFile(path)) + { + return *k; + } + + } + } + return ITK_NULLPTR; +} + +template <class TInputValue, class TOutputValue> +void +MachineLearningModelFactory<TInputValue,TOutputValue> +::RegisterBuiltInFactories() +{ + itk::MutexLockHolder<itk::SimpleMutexLock> lockHolder(mutex); + +#ifdef OTB_USE_SHARK + RegisterFactory(SharkKMeansMachineLearningModelFactory<TInputValue,TOutputValue>::New()); +#endif + +} + +template <class TInputValue, class TOutputValue> +void +MachineLearningModelFactory<TInputValue,TOutputValue> +::RegisterFactory(itk::ObjectFactoryBase * factory) +{ + // Unregister any previously registered factory of the same class + // Might be more intensive but static bool is not an option due to + // ld error. + itk::ObjectFactoryBase::UnRegisterFactory(factory); + itk::ObjectFactoryBase::RegisterFactory(factory); +} + +template <class TInputValue, class TOutputValue> +void +MachineLearningModelFactory<TInputValue,TOutputValue> +::CleanFactories() +{ + itk::MutexLockHolder<itk::SimpleMutexLock> lockHolder(mutex); + + std::list<itk::ObjectFactoryBase*> factories = itk::ObjectFactoryBase::GetRegisteredFactories(); + std::list<itk::ObjectFactoryBase*>::iterator itFac; + + for (itFac = factories.begin(); itFac != factories.end() ; ++itFac) + { + +#ifdef OTB_USE_SHARK + SharkKMeansMachineLearningModelFactory<TInputValue,TOutputValue> *sharkKMeansFactory = + dynamic_cast<SharkKMeansMachineLearningModelFactory<TInputValue,TOutputValue> *>(*itFac); + if (sharkKMeansFactory) + { + itk::ObjectFactoryBase::UnRegisterFactory(sharkKMeansFactory); + continue; + } +#endif + + } + +} + +} // end namespace otb + +#endif diff --git a/Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModel.h b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h similarity index 97% rename from Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModel.h rename to Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h index c822b029d16d19b7869b250ea8464873ff88c062..a31891e5d1d633776c67500ceba2eeb136a9b140 100644 --- a/Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModel.h +++ b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h @@ -18,9 +18,7 @@ #ifndef otbSharkKMeansMachineLearningModel_h #define otbSharkKMeansMachineLearningModel_h -#include <shark/Models/Clustering/HardClusteringModel.h> -#include <shark/Models/Clustering/SoftClusteringModel.h> -#include "otb_shark.h" + #include "itkLightObject.h" #include "otbMachineLearningModel.h" @@ -36,6 +34,9 @@ #pragma GCC diagnostic ignored "-Wunknown-pragmas" #endif +#include "otb_shark.h" +#include "shark/Models/Clustering/HardClusteringModel.h" +#include "shark/Models/Clustering/SoftClusteringModel.h" #include "shark/Models/Clustering/Centroids.h" #include "shark/Models/Clustering/ClusteringModel.h" #include "shark/Algorithms/KMeans.h" @@ -57,7 +58,7 @@ using namespace shark; * For more information, see * http://image.diku.dk/shark/sphinx_pages/build/html/rest_sources/tutorials/algorithms/kmeans.html * - * \ingroup OTBSupervised + * \ingroup OTBUnsupervised */ namespace otb { diff --git a/Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModel.txx b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.txx similarity index 100% rename from Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModel.txx rename to Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.txx diff --git a/Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModelFactory.h b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModelFactory.h similarity index 98% rename from Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModelFactory.h rename to Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModelFactory.h index 2d439f0b926deb41f5b97077ca89cf1000d7d9eb..cf0c033eb5a487fbfcda93a4f2d67677d449fec8 100644 --- a/Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModelFactory.h +++ b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModelFactory.h @@ -26,7 +26,7 @@ namespace otb /** \class SharkKMeansMachineLearningModelFactory * \brief Creation of an instance of a SharkKMeansMachineLearningModel object using the object factory * - * \ingroup OTBSupervised + * \ingroup OTBUnsupervised */ template <class TInputValue, class TTargetValue> class ITK_EXPORT SharkKMeansMachineLearningModelFactory : public itk::ObjectFactoryBase diff --git a/Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModelFactory.txx b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModelFactory.txx similarity index 100% rename from Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModelFactory.txx rename to Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModelFactory.txx diff --git a/Modules/Learning/Unsupervised/otb-module.cmake b/Modules/Learning/Unsupervised/otb-module.cmake new file mode 100644 index 0000000000000000000000000000000000000000..d849a226116fc612935e97002430f69e3b263028 --- /dev/null +++ b/Modules/Learning/Unsupervised/otb-module.cmake @@ -0,0 +1,24 @@ +set(DOCUMENTATION "This module provides the Orfeo Toolbox unsupervised +classification and regression framework, currently based on Shark") + +otb_module(OTBUnsupervised + DEPENDS + OTBCommon + OTBITK + OTBImageBase + OTBLearningBase + OTBSupervised + + OPTIONAL_DEPENDS + OTBShark + + TEST_DEPENDS + OTBTestKernel + OTBImageIO + OTBImageBase + OTBLearningBase + OTBSupervised + + DESCRIPTION + "${DOCUMENTATION}" + ) diff --git a/Modules/Learning/Unsupervised/test/CMakeLists.txt b/Modules/Learning/Unsupervised/test/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..1880a72167918927ac8b5cbc02686f769fa3dc32 --- /dev/null +++ b/Modules/Learning/Unsupervised/test/CMakeLists.txt @@ -0,0 +1,17 @@ +otb_module_test() +set(OTBUnsupervisedTests + otbUnsupervisedTestDriver.cxx + otbMachineLearningClusteringModelCanRead.cxx + otbTrainMachineLearningClusteringModel.cxx + ) + + +add_executable(otbUnsupervisedTestDriver ${OTBUnsupervisedTests}) +target_link_libraries(otbUnsupervisedTestDriver ${OTBUnsupervised-Test_LIBRARIES}) +otb_module_target_label(otbUnsupervisedTestDriver) + +# Tests Declaration + +if(OTB_USE_SHARK) + include(tests-shark.cmake) +endif() diff --git a/Modules/Learning/Unsupervised/test/otbMachineLearningClusteringModelCanRead.cxx b/Modules/Learning/Unsupervised/test/otbMachineLearningClusteringModelCanRead.cxx new file mode 100644 index 0000000000000000000000000000000000000000..761c2716f5743ecaafd46e0cfddb0f189075e6f1 --- /dev/null +++ b/Modules/Learning/Unsupervised/test/otbMachineLearningClusteringModelCanRead.cxx @@ -0,0 +1,62 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ + +#include <iostream> + +#include <otbConfigure.h> +#include <otbMachineLearningModel.h> + +typedef otb::MachineLearningModel<float,short> MachineLearningModelType; +typedef MachineLearningModelType::InputValueType InputValueType; +typedef MachineLearningModelType::InputSampleType InputSampleType; +typedef MachineLearningModelType::InputListSampleType InputListSampleType; +typedef MachineLearningModelType::TargetValueType TargetValueType; +typedef MachineLearningModelType::TargetSampleType TargetSampleType; +typedef MachineLearningModelType::TargetListSampleType TargetListSampleType; + +#ifdef OTB_USE_SHARK + +#include "otbSharkKMeansMachineLearningModel.h" + +int otbSharkKMeansMachineLearningModelCanRead(int argc, char *argv[]) +{ + if( argc != 2 ) + { + std::cerr << "Usage: " << argv[0] << "<model>" << std::endl; + std::cerr << "Called here with " << argc << " arguments\n"; + for( int i = 1; i < argc; ++i ) + { + std::cerr << " - " << argv[i] << "\n"; + } + return EXIT_FAILURE; + } + std::string filename( argv[1] ); + typedef otb::SharkKMeansMachineLearningModel<InputValueType, TargetValueType> RFType; + RFType::Pointer classifier = RFType::New(); + bool lCanRead = classifier->CanReadFile( filename ); + if( !lCanRead ) + { + std::cerr << "Error otb::SharkKMeansMachineLearningModel : impossible to open the file " << filename << "." + << std::endl; + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; +} + +#endif diff --git a/Modules/Learning/Unsupervised/test/otbTrainMachineLearningClusteringModel.cxx b/Modules/Learning/Unsupervised/test/otbTrainMachineLearningClusteringModel.cxx new file mode 100644 index 0000000000000000000000000000000000000000..7eae88216b1d3da0c7defe2cb2f2fefb642fb4e4 --- /dev/null +++ b/Modules/Learning/Unsupervised/test/otbTrainMachineLearningClusteringModel.cxx @@ -0,0 +1,170 @@ +#include <iostream> + +#include <otbConfigure.h> +#include <otbMachineLearningModel.h> + +typedef otb::MachineLearningModel<float,short> MachineLearningModelType; +typedef MachineLearningModelType::InputValueType InputValueType; +typedef MachineLearningModelType::InputSampleType InputSampleType; +typedef MachineLearningModelType::InputListSampleType InputListSampleType; +typedef MachineLearningModelType::TargetValueType TargetValueType; +typedef MachineLearningModelType::TargetSampleType TargetSampleType; +typedef MachineLearningModelType::TargetListSampleType TargetListSampleType; + +typedef otb::MachineLearningModel<float,float> MachineLearningModelRegressionType; +typedef MachineLearningModelRegressionType::InputValueType InputValueRegressionType; +typedef MachineLearningModelRegressionType::InputSampleType InputSampleRegressionType; +typedef MachineLearningModelRegressionType::InputListSampleType InputListSampleRegressionType; +typedef MachineLearningModelRegressionType::TargetValueType TargetValueRegressionType; +typedef MachineLearningModelRegressionType::TargetSampleType TargetSampleRegressionType; +typedef MachineLearningModelRegressionType::TargetListSampleType TargetListSampleRegressionType; + + +#ifdef OTB_USE_SHARK +#include "otbSharkKMeansMachineLearningModel.h" +#include "otb_boost_string_header.h" +#include <chrono> + +bool SharkReadDataFile(const std::string & infname, InputListSampleType * samples, TargetListSampleType * labels) +{ + std::ifstream ifs(infname.c_str()); + + if(!ifs) + { + std::cout<<"Could not read file "<<infname<<std::endl; + return false; + } + + unsigned int nbfeatures = 0; + + std::string line; + while (std::getline(ifs, line)) + { + boost::algorithm::trim(line); + + if(nbfeatures == 0) + { + nbfeatures = std::count(line.begin(),line.end(),' '); + } + + if(line.size()>1) + { + InputSampleType sample(nbfeatures); + sample.Fill(0); + + std::string::size_type pos = line.find_first_of(" ", 0); + + // Parse label + TargetSampleType label; + label[0] = std::stoi(line.substr(0, pos).c_str()); + + bool endOfLine = false; + unsigned int id = 0; + + while(!endOfLine) + { + std::string::size_type nextpos = line.find_first_of(" ", pos+1); + + if(pos == std::string::npos) + { + endOfLine = true; + nextpos = line.size()-1; + } + else + { + std::string feature = line.substr(pos,nextpos-pos); + std::string::size_type semicolonpos = feature.find_first_of(":"); + id = std::stoi(feature.substr(0,semicolonpos).c_str()); + sample[id - 1] = atof(feature.substr(semicolonpos+1,feature.size()-semicolonpos).c_str()); + pos = nextpos; + } + + } + samples->SetMeasurementVectorSize(itk::NumericTraits<InputSampleType>::GetLength(sample)); + samples->PushBack(sample); + labels->PushBack(label); + } + } + + //std::cout<<"Retrieved "<<samples->Size()<<" samples"<<std::endl; + ifs.close(); + return true; +} + + +int otbSharkKMeansMachineLearningModelNew(int itkNotUsed( argc ), char *itkNotUsed( argv )[]) +{ + typedef otb::SharkKMeansMachineLearningModel<InputValueType, TargetValueType> SharkRFType; + SharkRFType::Pointer classifier = SharkRFType::New(); + return EXIT_SUCCESS; +} + +int otbSharkKMeansMachineLearningModelTrain(int argc, char *argv[]) +{ + if( argc != 3 ) + { + std::cout << "Wrong number of arguments " << std::endl; + std::cout << "Usage : sample file, output file " << std::endl; + return EXIT_FAILURE; + } + + typedef otb::SharkKMeansMachineLearningModel<InputValueType, TargetValueType> KMeansType; + InputListSampleType::Pointer samples = InputListSampleType::New(); + TargetListSampleType::Pointer labels = TargetListSampleType::New(); + + if( !SharkReadDataFile( argv[1], samples, labels ) ) + { + std::cout << "Failed to read samples file " << argv[1] << std::endl; + return EXIT_FAILURE; + } + + KMeansType::Pointer classifier = KMeansType::New(); + classifier->SetInputListSample( samples ); + classifier->SetTargetListSample( labels ); + classifier->SetRegressionMode( false ); + classifier->SetK( 3 ); + classifier->SetMaximumNumberOfIterations( 0 ); + std::cout << "Train\n"; + classifier->Train(); + std::cout << "Save\n"; + classifier->Save( argv[2] ); + + return EXIT_SUCCESS; +} + + +int otbSharkKMeansMachineLearningModelPredict(int argc, char *argv[]) +{ + if( argc != 3 ) + { + std::cout << "Wrong number of arguments " << std::endl; + std::cout << "Usage : sample file, input model file " << std::endl; + return EXIT_FAILURE; + } + + typedef otb::SharkKMeansMachineLearningModel<InputValueType, TargetValueType> KMeansType; + InputListSampleType::Pointer samples = InputListSampleType::New(); + TargetListSampleType::Pointer labels = TargetListSampleType::New(); + + if( !SharkReadDataFile( argv[1], samples, labels ) ) + { + std::cout << "Failed to read samples file " << argv[1] << std::endl; + return EXIT_FAILURE; + } + + KMeansType::Pointer classifier = KMeansType::New(); + std::cout << "Load\n"; + classifier->Load( argv[2] ); + auto start = std::chrono::system_clock::now(); + classifier->SetInputListSample( samples ); + classifier->SetTargetListSample( labels ); + std::cout << "Predict loaded\n"; + classifier->PredictBatch( samples, NULL ); + using TimeT = std::chrono::milliseconds; + auto duration = std::chrono::duration_cast<TimeT>( std::chrono::system_clock::now() - start ); + auto elapsed = duration.count(); + std::cout << "PredictAll took " << elapsed << " ms\n"; + + return EXIT_SUCCESS; +} +#endif diff --git a/Modules/Learning/Unsupervised/test/otbUnsupervisedTestDriver.cxx b/Modules/Learning/Unsupervised/test/otbUnsupervisedTestDriver.cxx new file mode 100644 index 0000000000000000000000000000000000000000..14ca633f1cfaecbd96a348b26c7c37df1fc46f7c --- /dev/null +++ b/Modules/Learning/Unsupervised/test/otbUnsupervisedTestDriver.cxx @@ -0,0 +1,10 @@ +#include "otbTestMain.h" +void RegisterTests() +{ +#ifdef OTB_USE_SHARK + REGISTER_TEST(otbSharkKMeansMachineLearningModelCanRead); + REGISTER_TEST(otbSharkKMeansMachineLearningModelNew); + REGISTER_TEST(otbSharkKMeansMachineLearningModelTrain); + REGISTER_TEST(otbSharkKMeansMachineLearningModelPredict); +#endif +} diff --git a/Modules/Learning/Unsupervised/test/tests-shark.cmake b/Modules/Learning/Unsupervised/test/tests-shark.cmake new file mode 100644 index 0000000000000000000000000000000000000000..0635d94ec2a85587e6250958f98c39070b6a7442 --- /dev/null +++ b/Modules/Learning/Unsupervised/test/tests-shark.cmake @@ -0,0 +1,25 @@ +# kMeans Shark related tests + +otb_add_test(NAME leTvSharkKMeansMachineLearningModelNew COMMAND otbUnsupervisedTestDriver + otbSharkKMeansMachineLearningModelNew + ) + +otb_add_test(NAME leTvSharkKMeansMachineLearningModel COMMAND otbUnsupervisedTestDriver + otbSharkKMeansMachineLearningModelTrain + ${INPUTDATA}/letter.scale + ${TEMP}/shark_km_model.txt + ) + +otb_add_test(NAME leTvSharkKMeansMachineLearningModelCanRead COMMAND otbUnsupervisedTestDriver + otbSharkKMeansMachineLearningModelPredict + ${INPUTDATA}/letter.scale + ${INPUTDATA}/Classification/otbSharkImageClassificationFilter_KMeansmodel.txt + ) + +otb_add_test(NAME leTvSharkKMeansMachineLearningModelCanReadFail COMMAND otbUnsupervisedTestDriver + otbSharkKMeansMachineLearningModelPredict + ${INPUTDATA}/letter.scale + ${INPUTDATA}/Classification/otbSharkImageClassificationFilter_RFmodel.txt + ) + +set_property(TEST leTvSharkKMeansMachineLearningModelCanReadFail PROPERTY WILL_FAIL true)