diff --git a/Modules/Learning/Supervised/include/otbMachineLearningModel.h b/Modules/Learning/Supervised/include/otbMachineLearningModel.h index 2cf4850182d5b8d6f34f84695a8390e2bd153cb4..ee84468521b5fc0db1c9d0ac3dcd0d81404b0367 100644 --- a/Modules/Learning/Supervised/include/otbMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbMachineLearningModel.h @@ -57,6 +57,7 @@ namespace otb * \sa NormalBayesMachineLearningModel * \sa NeuralNetworkMachineLearningModel * \sa SharkRandomForestsMachineLearningModel + * \sa SharkKMeansMachineLearningModel * \sa ImageClassificationFilter * * @@ -90,7 +91,7 @@ public: //@} /**\name Confidence value typedef */ - typedef TConfidenceValue ConfidenceValueType; + typedef TConfidenceValue ConfidenceValueType; typedef itk::FixedArray<ConfidenceValueType,1> ConfidenceSampleType; typedef itk::Statistics::ListSample<ConfidenceSampleType> ConfidenceListSampleType; diff --git a/Modules/Learning/Supervised/include/otbMachineLearningModelFactory.txx b/Modules/Learning/Supervised/include/otbMachineLearningModelFactory.txx index a99aa0f78e4d86f128b855299bfd4eacb340948b..5e72ce37dbca81c67ad1d33ae72baff2a86436d9 100644 --- a/Modules/Learning/Supervised/include/otbMachineLearningModelFactory.txx +++ b/Modules/Learning/Supervised/include/otbMachineLearningModelFactory.txx @@ -37,6 +37,7 @@ #ifdef OTB_USE_SHARK #include "otbSharkRandomForestsMachineLearningModelFactory.h" +#include "otbSharkKMeansMachineLearningModelFactory.h" #endif #include "itkMutexLockHolder.h" @@ -104,6 +105,7 @@ MachineLearningModelFactory<TInputValue,TOutputValue> #ifdef OTB_USE_SHARK RegisterFactory(SharkRandomForestsMachineLearningModelFactory<TInputValue,TOutputValue>::New()); + RegisterFactory(SharkKMeansMachineLearningModelFactory<TInputValue,TOutputValue>::New()); #endif #ifdef OTB_USE_OPENCV @@ -160,6 +162,14 @@ MachineLearningModelFactory<TInputValue,TOutputValue> itk::ObjectFactoryBase::UnRegisterFactory(sharkRFFactory); continue; } + + SharkKMeansMachineLearningModelFactory<TInputValue,TOutputValue> *sharkKMeansFactory = + dynamic_cast<SharkKMeansMachineLearningModelFactory<TInputValue,TOutputValue> *>(*itFac); + if (sharkKMeansFactory) + { + itk::ObjectFactoryBase::UnRegisterFactory(sharkKMeansFactory); + continue; + } #endif #ifdef OTB_USE_OPENCV diff --git a/Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModel.h b/Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModel.h new file mode 100644 index 0000000000000000000000000000000000000000..c822b029d16d19b7869b250ea8464873ff88c062 --- /dev/null +++ b/Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModel.h @@ -0,0 +1,174 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef otbSharkKMeansMachineLearningModel_h +#define otbSharkKMeansMachineLearningModel_h + +#include <shark/Models/Clustering/HardClusteringModel.h> +#include <shark/Models/Clustering/SoftClusteringModel.h> +#include "otb_shark.h" + +#include "itkLightObject.h" +#include "otbMachineLearningModel.h" + +#if defined(__GNUC__) || defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wshadow" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Woverloaded-virtual" +#pragma GCC diagnostic ignored "-Wignored-qualifiers" +#pragma GCC diagnostic ignored "-Wsign-compare" +#pragma GCC diagnostic ignored "-Wcast-align" +#pragma GCC diagnostic ignored "-Wunknown-pragmas" +#endif + +#include "shark/Models/Clustering/Centroids.h" +#include "shark/Models/Clustering/ClusteringModel.h" +#include "shark/Algorithms/KMeans.h" + +#if defined(__GNUC__) || defined(__clang__) +#pragma GCC diagnostic pop +#endif + +using namespace shark; + +/** \class SharkKMeansMachineLearningModel + * \brief Shark version of Random Forests algorithm + * + * This is a specialization of MachineLearningModel class allowing to + * use Shark implementation of the Random Forests algorithm. + * + * It is noteworthy that training step is parallel. + * + * For more information, see + * http://image.diku.dk/shark/sphinx_pages/build/html/rest_sources/tutorials/algorithms/kmeans.html + * + * \ingroup OTBSupervised + */ +namespace otb +{ +template<class TInputValue, class TTargetValue> +class ITK_EXPORT SharkKMeansMachineLearningModel : public MachineLearningModel<TInputValue, TTargetValue> +{ +public: + /** Standard class typedefs. */ + typedef SharkKMeansMachineLearningModel Self; + typedef MachineLearningModel<TInputValue, TTargetValue> Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + + typedef typename Superclass::InputValueType InputValueType; + typedef typename Superclass::InputSampleType InputSampleType; + typedef typename Superclass::InputListSampleType InputListSampleType; + typedef typename Superclass::TargetValueType TargetValueType; + typedef typename Superclass::TargetSampleType TargetSampleType; + typedef typename Superclass::TargetListSampleType TargetListSampleType; + typedef typename Superclass::ConfidenceValueType ConfidenceValueType; + typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType; + typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType; + + typedef HardClusteringModel<RealVector> ClusteringModelType; + typedef ClusteringModelType::OutputType ClusteringOutputType; + + /** Run-time type information (and related methods). */ + itkNewMacro( Self ); + itkTypeMacro( SharkKMeansMachineLearningModel, MachineLearningModel ); + + /** Train the machine learning model */ + virtual void Train() ITK_OVERRIDE; + + /** Save the model to file */ + virtual void Save(const std::string &filename, const std::string &name = "") ITK_OVERRIDE; + + /** Load the model from file */ + virtual void Load(const std::string &filename, const std::string &name = "") ITK_OVERRIDE; + + /**\name Classification model file compatibility tests */ + //@{ + /** Is the input model file readable and compatible with the corresponding classifier ? */ + virtual bool CanReadFile(const std::string &) ITK_OVERRIDE; + + /** Is the input model file writable and compatible with the corresponding classifier ? */ + virtual bool CanWriteFile(const std::string &) ITK_OVERRIDE; + //@} + + /** Get the maximum number of iteration for the kMeans algorithm.*/ + itkGetMacro( MaximumNumberOfIterations, unsigned + int ); + /** Set the maximum number of iteration for the kMeans algorithm.*/ + itkSetMacro( MaximumNumberOfIterations, unsigned + int ); + + /** Get the number of class for the kMeans algorithm.*/ + itkGetMacro( K, unsigned + int ); + /** Set the number of class for the kMeans algorithm.*/ + itkSetMacro( K, unsigned + int ); + + /** If true, normalized input data sample list */ + itkGetMacro( Normalized, bool ); + itkSetMacro( Normalized, bool ); + +protected: + /** Constructor */ + SharkKMeansMachineLearningModel(); + + /** Destructor */ + virtual ~SharkKMeansMachineLearningModel(); + + /** Predict values using the model */ + virtual TargetSampleType + DoPredict(const InputSampleType &input, ConfidenceValueType *quality = ITK_NULLPTR) const ITK_OVERRIDE; + + + virtual void DoPredictBatch(const InputListSampleType *, const unsigned int &startIndex, const unsigned int &size, + TargetListSampleType *, ConfidenceListSampleType * = ITK_NULLPTR) const ITK_OVERRIDE; + + template<typename DataType> + DataType NormalizeData(const DataType &data) const; + + /** PrintSelf method */ + void PrintSelf(std::ostream &os, itk::Indent indent) const; + +private: + SharkKMeansMachineLearningModel(const Self &); //purposely not implemented + void operator=(const Self &); //purposely not implemented + + // Parameters set by the user + bool m_Normalized; + unsigned int m_K; + unsigned int m_MaximumNumberOfIterations; + + + /** Centroids results form kMeans */ + Centroids centroids; + + + /** shark Model could be SoftClusteringModel or HardClusteringModel */ + boost::shared_ptr<ClusteringModelType> m_ClusteringModel; + +}; +} // end namespace otb + +#ifndef OTB_MANUAL_INSTANTIATION + +#include "otbSharkKMeansMachineLearningModel.txx" + +#endif + +#endif diff --git a/Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModel.txx new file mode 100644 index 0000000000000000000000000000000000000000..e3492437367f81a726c25326134900bf686a150f --- /dev/null +++ b/Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModel.txx @@ -0,0 +1,230 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef otbSharkKMeansMachineLearningModel_txx +#define otbSharkKMeansMachineLearningModel_txx +#include <fstream> +#include "itkMacro.h" +#include "otbSharkKMeansMachineLearningModel.h" +#if defined(__GNUC__) || defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wshadow" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#pragma GCC diagnostic ignored "-Woverloaded-virtual" +#pragma GCC diagnostic ignored "-Wignored-qualifiers" +#endif +#include <shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h> //normalize +#include <shark/Algorithms/KMeans.h> //k-means algorithm +#include <shark/Models/Clustering/HardClusteringModel.h> +#include <shark/Models/Clustering/SoftClusteringModel.h> +#include <shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h> +#if defined(__GNUC__) || defined(__clang__) +#pragma GCC diagnostic pop +#endif +#include "otbSharkUtils.h" + + +namespace otb +{ +template<class TInputValue, class TOutputValue> +SharkKMeansMachineLearningModel<TInputValue, TOutputValue> +::SharkKMeansMachineLearningModel() : + m_Normalized( true ), m_K(2), m_MaximumNumberOfIterations( 0 ) +{ + // Default set HardClusteringModel + m_ClusteringModel = boost::shared_ptr<ClusteringModelType>(new ClusteringModelType( ¢roids )); +} + + +template<class TInputValue, class TOutputValue> +SharkKMeansMachineLearningModel<TInputValue, TOutputValue> +::~SharkKMeansMachineLearningModel() +{ +} + +/** Train the machine learning model */ +template<class TInputValue, class TOutputValue> +void +SharkKMeansMachineLearningModel<TInputValue, TOutputValue> +::Train() +{ + // Parse input data and convert to Shark Data + std::vector<RealVector> vector_data; + Shark::ListSampleToSharkVector( this->GetInputListSample(), vector_data ); + Data<RealVector> data = createDataFromRange( vector_data ); + + // Normalized input value if necessary + if( m_Normalized ) + data = NormalizeData( data ); + + // Use a Hard Clustering Model for classification + kMeans( data, m_K, centroids, m_MaximumNumberOfIterations ); +} + +template<class TInputValue, class TOutputValue> +template<typename DataType> +DataType +SharkKMeansMachineLearningModel<TInputValue, TOutputValue> +::NormalizeData(const DataType &data) const +{ + shark::Normalizer<> normalizer; + shark::NormalizeComponentsUnitVariance<> normalizingTrainer( true );//zero mean + normalizingTrainer.train( normalizer, data ); + return normalizer( data ); +} + +template<class TInputValue, class TOutputValue> +typename SharkKMeansMachineLearningModel<TInputValue, TOutputValue> +::TargetSampleType +SharkKMeansMachineLearningModel<TInputValue, TOutputValue> +::DoPredict(const InputSampleType &value, ConfidenceValueType *quality) const +{ + RealVector data( value.Size()); + for( size_t i = 0; i < value.Size(); i++ ) + { + data.push_back( value[i] ); + } + + // Change quality measurement only if SoftClustering or other clustering method is used. + if( quality != ITK_NULLPTR ) + { + //unsigned int probas = (*m_ClusteringModel)( data ); + ( *quality ) = ConfidenceValueType( 1.); + } + + TargetSampleType target; + ClusteringOutputType predictedValue = (*m_ClusteringModel)( data ); + target[0] = static_cast<TOutputValue>(predictedValue); + return target; +} + +template<class TInputValue, class TOutputValue> +void +SharkKMeansMachineLearningModel<TInputValue, TOutputValue> +::DoPredictBatch(const InputListSampleType *input, + const unsigned int &startIndex, + const unsigned int &size, + TargetListSampleType *targets, + ConfidenceListSampleType *quality) const +{ + + // Perform check on input values + assert( input != ITK_NULLPTR ); + assert( targets != ITK_NULLPTR ); + + // input list sample and target list sample should be initialized and without + assert( input->Size() == targets->Size() && "Input sample list and target label list do not have the same size." ); + assert((( quality == ITK_NULLPTR ) || ( quality->Size() == input->Size())) && + "Quality samples list is not null and does not have the same size as input samples list" ); + if( startIndex + size > input->Size()) + { + itkExceptionMacro( + <<"requested range ["<<startIndex<<", "<<startIndex+size<<"[ partially outside input sample list range.[0,"<<input->Size()<<"[" ); + } + + // Convert input list of features to shark data format + std::vector<RealVector> features; + Shark::ListSampleRangeToSharkVector( input, features, startIndex, size ); + Data<RealVector> inputSamples = shark::createDataFromRange( features ); + + Data<ClusteringOutputType> clusters = ( *m_ClusteringModel )( inputSamples ); + unsigned int id = startIndex; + for( const auto &p : clusters.elements() ) + { + TargetSampleType target; + target[0] = static_cast<TOutputValue>(p); + targets->SetMeasurementVector( id, target ); + ++id; + } + + // Change quality measurement only if SoftClustering or other clustering method is used. + if( quality != ITK_NULLPTR ) + { + for( unsigned int qid = startIndex; qid < size; ++qid ) + { + quality->SetMeasurementVector( qid, static_cast<ConfidenceValueType>(1.) ); + } + } + +} + + +template<class TInputValue, class TOutputValue> +void +SharkKMeansMachineLearningModel<TInputValue, TOutputValue> +::Save(const std::string &filename, const std::string & itkNotUsed( name )) +{ + std::ofstream ofs( filename.c_str()); + if( !ofs ) + { + itkExceptionMacro( << "Error opening " << filename.c_str()); + } + shark::TextOutArchive oa( ofs ); + std::string name = m_ClusteringModel->name(); + oa << name; + m_ClusteringModel->save( oa, 1 ); +} + +template<class TInputValue, class TOutputValue> +void +SharkKMeansMachineLearningModel<TInputValue, TOutputValue> +::Load(const std::string &filename, const std::string & itkNotUsed( name )) +{ + std::ifstream ifs( filename.c_str()); + shark::TextInArchive ia( ifs ); + std::string name; + ia >> name; + if(name != m_ClusteringModel->name()) + throw new boost::archive::archive_exception(boost::archive::archive_exception::input_stream_error); + m_ClusteringModel->load( ia, 1 ); +} + +template<class TInputValue, class TOutputValue> +bool +SharkKMeansMachineLearningModel<TInputValue, TOutputValue> +::CanReadFile(const std::string &file) +{ + try + { + this->Load( file ); + } + catch( ... ) + { + return false; + } + return true; +} + +template<class TInputValue, class TOutputValue> +bool +SharkKMeansMachineLearningModel<TInputValue, TOutputValue> +::CanWriteFile(const std::string & itkNotUsed( file )) +{ + return true; +} + +template<class TInputValue, class TOutputValue> +void +SharkKMeansMachineLearningModel<TInputValue, TOutputValue> +::PrintSelf(std::ostream &os, itk::Indent indent) const +{ + // Call superclass implementation + Superclass::PrintSelf( os, indent ); +} +} //end namespace otb + +#endif diff --git a/Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModelFactory.h b/Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModelFactory.h new file mode 100644 index 0000000000000000000000000000000000000000..2d439f0b926deb41f5b97077ca89cf1000d7d9eb --- /dev/null +++ b/Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModelFactory.h @@ -0,0 +1,74 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef otbSharkKMeansMachineLearningModelFactory_h +#define otbSharkKMeansMachineLearningModelFactory_h + +#include "itkObjectFactoryBase.h" +#include "itkImageIOBase.h" + +namespace otb +{ +/** \class SharkKMeansMachineLearningModelFactory + * \brief Creation of an instance of a SharkKMeansMachineLearningModel object using the object factory + * + * \ingroup OTBSupervised + */ +template <class TInputValue, class TTargetValue> +class ITK_EXPORT SharkKMeansMachineLearningModelFactory : public itk::ObjectFactoryBase +{ +public: + /** Standard class typedefs. */ + typedef SharkKMeansMachineLearningModelFactory Self; + typedef itk::ObjectFactoryBase Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + + /** Class methods used to interface with the registered factories. */ + virtual const char* GetITKSourceVersion(void) const; + virtual const char* GetDescription(void) const; + + /** Method for class instantiation. */ + itkFactorylessNewMacro(Self); + + /** Run-time type information (and related methods). */ + itkTypeMacro(SharkKMeansMachineLearningModelFactory, itk::ObjectFactoryBase); + + /** Register one factory of this type */ + static void RegisterOneFactory(void) + { + Pointer KMeansFactory = SharkKMeansMachineLearningModelFactory::New(); + itk::ObjectFactoryBase::RegisterFactory(KMeansFactory); + } + +protected: + SharkKMeansMachineLearningModelFactory(); + virtual ~SharkKMeansMachineLearningModelFactory(); + +private: + SharkKMeansMachineLearningModelFactory(const Self &); //purposely not implemented + void operator =(const Self&); //purposely not implemented + +}; + +} // end namespace otb + +#ifndef OTB_MANUAL_INSTANTIATION +#include "otbSharkKMeansMachineLearningModelFactory.txx" +#endif + +#endif diff --git a/Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModelFactory.txx b/Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModelFactory.txx new file mode 100644 index 0000000000000000000000000000000000000000..a698e7b1564fa81c8c5d590c4134914a8bd7f898 --- /dev/null +++ b/Modules/Learning/Supervised/include/otbSharkKMeansMachineLearningModelFactory.txx @@ -0,0 +1,69 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef otbSharkKMeansMachineLearningModelFactory_txx +#define otbSharkKMeansMachineLearningModelFactory_txx + +#include "otbSharkKMeansMachineLearningModelFactory.h" + +#include "itkCreateObjectFunction.h" +#include "otbSharkKMeansMachineLearningModel.h" +#include "itkVersion.h" + +namespace otb +{ + +template <class TInputValue, class TOutputValue> +SharkKMeansMachineLearningModelFactory<TInputValue,TOutputValue> +::SharkKMeansMachineLearningModelFactory() +{ + + std::string classOverride = std::string("otbMachineLearningModel"); + std::string subclass = std::string("otbSharkKMeansMachineLearningModel"); + + this->RegisterOverride(classOverride.c_str(), + subclass.c_str(), + "Shark KMeans Machine Learning Model", + 1, + itk::CreateObjectFunction<SharkKMeansMachineLearningModel<TInputValue,TOutputValue> >::New()); +} + +template <class TInputValue, class TOutputValue> +SharkKMeansMachineLearningModelFactory<TInputValue,TOutputValue> +::~SharkKMeansMachineLearningModelFactory() +{ +} + +template <class TInputValue, class TOutputValue> +const char* +SharkKMeansMachineLearningModelFactory<TInputValue,TOutputValue> +::GetITKSourceVersion(void) const +{ + return ITK_SOURCE_VERSION; +} + +template <class TInputValue, class TOutputValue> +const char* +SharkKMeansMachineLearningModelFactory<TInputValue,TOutputValue> +::GetDescription() const +{ + return "Shark KMeans unsupervised machine learning model factory"; +} + +} // end namespace otb + +#endif