From 60d96d0159f6bf324684d811a0b509039d7a289c Mon Sep 17 00:00:00 2001 From: Ludovic Hussonnois <ludovic.hussonnois@c-s.fr> Date: Fri, 3 Mar 2017 16:41:41 +0100 Subject: [PATCH] BUG: Correct model reading and training. --- .../otbSharkKMeansMachineLearningModel.h | 23 +++++------ .../otbSharkKMeansMachineLearningModel.txx | 41 +++++++++++++++---- 2 files changed, 43 insertions(+), 21 deletions(-) diff --git a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h index 926d59bc5b..1de147e08f 100644 --- a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h +++ b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h @@ -18,8 +18,6 @@ #ifndef otbSharkKMeansMachineLearningModel_h #define otbSharkKMeansMachineLearningModel_h - - #include "itkLightObject.h" #include "otbMachineLearningModel.h" @@ -71,15 +69,16 @@ public: typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; - typedef typename Superclass::InputValueType InputValueType; - typedef typename Superclass::InputSampleType InputSampleType; - typedef typename Superclass::InputListSampleType InputListSampleType; - typedef typename Superclass::TargetValueType TargetValueType; - typedef typename Superclass::TargetSampleType TargetSampleType; - typedef typename Superclass::TargetListSampleType TargetListSampleType; - typedef typename Superclass::ConfidenceValueType ConfidenceValueType; - typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType; - typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType; + typedef typename Superclass::InputValueType InputValueType; + typedef typename Superclass::InputSampleType InputSampleType; + typedef typename Superclass::InputListSampleType InputListSampleType; + typedef typename Superclass::TargetValueType TargetValueType; + typedef typename Superclass::TargetSampleType TargetSampleType; + typedef typename Superclass::TargetListSampleType TargetListSampleType; + typedef typename Superclass::ConfidenceValueType ConfidenceValueType; + typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType; + typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType; + typedef HardClusteringModel<RealVector> ClusteringModelType; typedef ClusteringModelType::OutputType ClusteringOutputType; @@ -153,7 +152,7 @@ private: /** Centroids results form kMeans */ - Centroids centroids; + Centroids m_Centroids; /** shark Model could be SoftClusteringModel or HardClusteringModel */ diff --git a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.txx b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.txx index 7f74215fdb..3428a57acb 100644 --- a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.txx +++ b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.txx @@ -51,7 +51,7 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> m_Normalized( true ), m_K(2), m_MaximumNumberOfIterations( 0 ) { // Default set HardClusteringModel - m_ClusteringModel = boost::shared_ptr<ClusteringModelType>(new ClusteringModelType( ¢roids )); + m_ClusteringModel = boost::shared_ptr<ClusteringModelType>(new ClusteringModelType( &m_Centroids )); } @@ -77,7 +77,8 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> data = NormalizeData( data ); // Use a Hard Clustering Model for classification - kMeans( data, m_K, centroids, m_MaximumNumberOfIterations ); + kMeans( data, m_K, m_Centroids, m_MaximumNumberOfIterations ); + m_ClusteringModel = boost::shared_ptr<ClusteringModelType>(new ClusteringModelType( &m_Centroids )); } template<class TInputValue, class TOutputValue> @@ -133,9 +134,9 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> // input list sample and target list sample should be initialized and without assert( input->Size() == targets->Size() && "Input sample list and target label list do not have the same size." ); - assert((( quality == ITK_NULLPTR ) || ( quality->Size() == input->Size())) && - "Quality samples list is not null and does not have the same size as input samples list" ); - if( startIndex + size > input->Size()) + assert( ( ( quality == ITK_NULLPTR ) || ( quality->Size() == input->Size() ) ) && + "Quality samples list is not null and does not have the same size as input samples list" ); + if( startIndex + size > input->Size() ) { itkExceptionMacro( <<"requested range ["<<startIndex<<", "<<startIndex+size<<"[ partially outside input sample list range.[0,"<<input->Size()<<"[" ); @@ -146,7 +147,17 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> Shark::ListSampleRangeToSharkVector( input, features, startIndex, size ); Data<RealVector> inputSamples = shark::createDataFromRange( features ); - Data<ClusteringOutputType> clusters = ( *m_ClusteringModel )( inputSamples ); + Data<ClusteringOutputType> clusters; + try + { + clusters = ( *m_ClusteringModel )( inputSamples ); + } + catch( ... ) + { + itkExceptionMacro( "Failed to run clustering classification. " + "The number of features of input samples and the model could differ."); + } + unsigned int id = startIndex; for( const auto &p : clusters.elements() ) { @@ -189,13 +200,25 @@ void SharkKMeansMachineLearningModel<TInputValue, TOutputValue> ::Load(const std::string &filename, const std::string & itkNotUsed( name )) { + m_CanRead = false; std::ifstream ifs( filename.c_str()); + if(ifs.good()) + { + std::string line; + std::getline(ifs, line); + m_CanRead = line.find(m_ClusteringModel->name()) != std::string::npos; + } + + if(!m_CanRead) + return; + + // Go to the start of the file + ifs.seekg(0, std::ios::beg); shark::TextInArchive ia( ifs ); std::string name; - ia >> name; - if(name != m_ClusteringModel->name()) - m_CanRead = false; + ia & name; m_ClusteringModel->load( ia, 1 ); + ifs.close(); } template<class TInputValue, class TOutputValue> -- GitLab