diff --git a/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h b/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h index 484de95428927295fe5ca2dc936ab58ee4b28671..7f782d73c8cf1166547cbb9bd329a43d17a36051 100644 --- a/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h @@ -157,6 +157,7 @@ private: shark::RFClassifier m_RFModel; shark::RFTrainer m_RFTrainer; + std::vector m_ClassDictionary; unsigned int m_NumberOfTrees; unsigned int m_MTry; diff --git a/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx index 4d7d9e83e30bd0d2ee6601ae851b35f00ef7a0ab..eae8ca2ccc32b3653e031c0247917bd28488b3f7 100644 --- a/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx +++ b/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx @@ -75,6 +75,7 @@ SharkRandomForestsMachineLearningModel Shark::ListSampleToSharkVector(this->GetInputListSample(), features); Shark::ListSampleToSharkVector(this->GetTargetListSample(), class_labels); + Shark::NormalizeLabelsAndGetDictionary(class_labels, m_ClassDictionary); shark::ClassificationDataset TrainSamples = shark::createLabeledDataFromRange(features,class_labels); //Set parameters @@ -130,7 +131,7 @@ SharkRandomForestsMachineLearningModel unsigned int res{0}; m_RFModel.eval(samples, res); TargetSampleType target; - target[0] = static_cast(res); + target[0] = m_ClassDictionary[static_cast(res)]; return target; } @@ -154,9 +155,9 @@ SharkRandomForestsMachineLearningModel Shark::ListSampleRangeToSharkVector(input, features,startIndex,size); shark::Data inputSamples = shark::createDataFromRange(features); - #ifdef _OPENMP +#ifdef _OPENMP omp_set_num_threads(itk::MultiThreader::GetGlobalDefaultNumberOfThreads()); - #endif +#endif if(quality != ITK_NULLPTR) { @@ -177,7 +178,7 @@ SharkRandomForestsMachineLearningModel for(const auto& p : prediction.elements()) { TargetSampleType target; - target[0] = static_cast(p); + target[0] = m_ClassDictionary[static_cast(p)]; targets->SetMeasurementVector(id,target); ++id; } @@ -195,6 +196,12 @@ SharkRandomForestsMachineLearningModel } // Add comment with model file name ofs << "#" << m_RFModel.name() << std::endl; + ofs << m_ClassDictionary.size() << " "; + for(const auto& l : m_ClassDictionary) + { + ofs << l << " "; + } + ofs << std::endl; shark::TextOutArchive oa(ofs); m_RFModel.save(oa,0); } @@ -221,6 +228,15 @@ SharkRandomForestsMachineLearningModel ifs.clear(); ifs.seekg( 0, std::ios::beg ); } + size_t nbLabels{0}; + ifs >> nbLabels; + m_ClassDictionary.resize(nbLabels); + for(size_t i=0; i> label; + m_ClassDictionary[i]=label; + } shark::TextInArchive ia( ifs ); m_RFModel.load( ia, 0 ); }