From 60d96d0159f6bf324684d811a0b509039d7a289c Mon Sep 17 00:00:00 2001
From: Ludovic Hussonnois <ludovic.hussonnois@c-s.fr>
Date: Fri, 3 Mar 2017 16:41:41 +0100
Subject: [PATCH] BUG: Correct model reading and training.

---
 .../otbSharkKMeansMachineLearningModel.h      | 23 +++++------
 .../otbSharkKMeansMachineLearningModel.txx    | 41 +++++++++++++++----
 2 files changed, 43 insertions(+), 21 deletions(-)

diff --git a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h
index 926d59bc5b..1de147e08f 100644
--- a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h
+++ b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.h
@@ -18,8 +18,6 @@
 #ifndef otbSharkKMeansMachineLearningModel_h
 #define otbSharkKMeansMachineLearningModel_h
 
-
-
 #include "itkLightObject.h"
 #include "otbMachineLearningModel.h"
 
@@ -71,15 +69,16 @@ public:
   typedef itk::SmartPointer<Self> Pointer;
   typedef itk::SmartPointer<const Self> ConstPointer;
 
-  typedef typename Superclass::InputValueType InputValueType;
-  typedef typename Superclass::InputSampleType InputSampleType;
-  typedef typename Superclass::InputListSampleType InputListSampleType;
-  typedef typename Superclass::TargetValueType TargetValueType;
-  typedef typename Superclass::TargetSampleType TargetSampleType;
-  typedef typename Superclass::TargetListSampleType TargetListSampleType;
-  typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
-  typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType;
-  typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
+  typedef typename Superclass::InputValueType             InputValueType;
+  typedef typename Superclass::InputSampleType            InputSampleType;
+  typedef typename Superclass::InputListSampleType        InputListSampleType;
+  typedef typename Superclass::TargetValueType            TargetValueType;
+  typedef typename Superclass::TargetSampleType           TargetSampleType;
+  typedef typename Superclass::TargetListSampleType       TargetListSampleType;
+  typedef typename Superclass::ConfidenceValueType        ConfidenceValueType;
+  typedef typename Superclass::ConfidenceSampleType       ConfidenceSampleType;
+  typedef typename Superclass::ConfidenceListSampleType   ConfidenceListSampleType;
+
 
   typedef HardClusteringModel<RealVector> ClusteringModelType;
   typedef ClusteringModelType::OutputType ClusteringOutputType;
@@ -153,7 +152,7 @@ private:
 
 
   /** Centroids results form kMeans */
-  Centroids centroids;
+  Centroids m_Centroids;
 
 
   /** shark Model could be SoftClusteringModel or HardClusteringModel */
diff --git a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.txx b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.txx
index 7f74215fdb..3428a57acb 100644
--- a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.txx
+++ b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.txx
@@ -51,7 +51,7 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
         m_Normalized( true ), m_K(2), m_MaximumNumberOfIterations( 0 )
 {
   // Default set HardClusteringModel
-  m_ClusteringModel = boost::shared_ptr<ClusteringModelType>(new ClusteringModelType( &centroids ));
+  m_ClusteringModel = boost::shared_ptr<ClusteringModelType>(new ClusteringModelType( &m_Centroids ));
 }
 
 
@@ -77,7 +77,8 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
     data = NormalizeData( data );
 
   // Use a Hard Clustering Model for classification
-  kMeans( data, m_K, centroids, m_MaximumNumberOfIterations );
+  kMeans( data, m_K, m_Centroids, m_MaximumNumberOfIterations );
+  m_ClusteringModel = boost::shared_ptr<ClusteringModelType>(new ClusteringModelType( &m_Centroids ));
 }
 
 template<class TInputValue, class TOutputValue>
@@ -133,9 +134,9 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
 
   // input list sample and target list sample should be initialized and without
   assert( input->Size() == targets->Size() && "Input sample list and target label list do not have the same size." );
-  assert((( quality == ITK_NULLPTR ) || ( quality->Size() == input->Size())) &&
-         "Quality samples list is not null and does not have the same size as input samples list" );
-  if( startIndex + size > input->Size())
+  assert( ( ( quality == ITK_NULLPTR ) || ( quality->Size() == input->Size() ) ) &&
+          "Quality samples list is not null and does not have the same size as input samples list" );
+  if( startIndex + size > input->Size() )
     {
     itkExceptionMacro(
             <<"requested range ["<<startIndex<<", "<<startIndex+size<<"[ partially outside input sample list range.[0,"<<input->Size()<<"[" );
@@ -146,7 +147,17 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
   Shark::ListSampleRangeToSharkVector( input, features, startIndex, size );
   Data<RealVector> inputSamples = shark::createDataFromRange( features );
 
-  Data<ClusteringOutputType> clusters = ( *m_ClusteringModel )( inputSamples );
+  Data<ClusteringOutputType> clusters;
+  try
+    {
+     clusters = ( *m_ClusteringModel )( inputSamples );
+    }
+  catch( ... )
+    {
+    itkExceptionMacro( "Failed to run clustering classification. "
+                               "The number of features of input samples and the model could differ.");
+    }
+
   unsigned int id = startIndex;
   for( const auto &p : clusters.elements() )
     {
@@ -189,13 +200,25 @@ void
 SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
 ::Load(const std::string &filename, const std::string & itkNotUsed( name ))
 {
+  m_CanRead = false;
   std::ifstream ifs( filename.c_str());
+  if(ifs.good())
+    {
+    std::string line;
+    std::getline(ifs, line);
+    m_CanRead = line.find(m_ClusteringModel->name()) != std::string::npos;
+    }
+
+  if(!m_CanRead)
+    return;
+
+  // Go to the start of the file
+  ifs.seekg(0, std::ios::beg);
   shark::TextInArchive ia( ifs );
   std::string name;
-  ia >> name;
-  if(name != m_ClusteringModel->name())
-    m_CanRead = false;
+  ia & name;
   m_ClusteringModel->load( ia, 1 );
+  ifs.close();
 }
 
 template<class TInputValue, class TOutputValue>
-- 
GitLab