diff --git a/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx index 0e0b08f3e26dd2d4cb8a0e4bddd7e036d9aae444..b8486999dc6f7d456e5ec6cdade49ba3bc3f91fc 100644 --- a/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx +++ b/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx @@ -195,6 +195,8 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> { itkExceptionMacro(<< "Error opening " << filename.c_str() ); } + // Add comment with model file name + ofs << "#" << m_RFModel.name() << std::endl; shark::TextOutArchive oa(ofs); m_RFModel.save(oa,0); } @@ -205,8 +207,25 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> ::Load(const std::string & filename, const std::string & itkNotUsed(name)) { std::ifstream ifs(filename.c_str()); - shark::TextInArchive ia(ifs); - m_RFModel.load(ia,0); + if( ifs.good() ) + { + // Check if the first line is a comment and verify the name of the model in this case. + std::string line; + getline( ifs, line ); + if( line.at( 0 ) == '#' ) + { + if( line.find( m_RFModel.name() ) == std::string::npos ) + itkExceptionMacro( "The model file : " + filename + " cannot be read." ); + } + else + { + // rewind if first line is not a comment + ifs.clear(); + ifs.seekg( 0, std::ios::beg ); + } + shark::TextInArchive ia( ifs ); + m_RFModel.load( ia, 0 ); + } } template <class TInputValue, class TOutputValue> diff --git a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.txx b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.txx index 5629ba45db9b47530e06a7bb6eed58c59df42e1d..31560a2dd0fbcdd5591f0518bc87b70c7fb77ac9 100644 --- a/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.txx +++ b/Modules/Learning/Unsupervised/include/otbSharkKMeansMachineLearningModel.txx @@ -30,6 +30,7 @@ #pragma GCC diagnostic ignored "-Wignored-qualifiers" #endif +#include "otb_shark.h" #include "otbSharkUtils.h" #include "shark/Algorithms/Trainers/NormalizeComponentsUnitVariance.h" //normalize #include "shark/Algorithms/KMeans.h" //k-means algorithm @@ -188,9 +189,8 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> { itkExceptionMacro( << "Error opening " << filename.c_str()); } + ofs << "#" << m_ClusteringModel->name() << std::endl; shark::TextOutArchive oa( ofs ); - std::string name = m_ClusteringModel->name(); - oa << name; m_ClusteringModel->save( oa, 1 ); } @@ -203,6 +203,7 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> std::ifstream ifs( filename.c_str()); if(ifs.good()) { + // Check if first line contains model name std::string line; std::getline(ifs, line); m_CanRead = line.find(m_ClusteringModel->name()) != std::string::npos; @@ -211,12 +212,8 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue> if(!m_CanRead) return; - // Go to the start of the file - ifs.seekg(0, std::ios::beg); shark::TextInArchive ia( ifs ); - std::string name; - ia & name; - m_ClusteringModel->load( ia, 1 ); + m_ClusteringModel->load( ia, 0 ); ifs.close(); }