Commit 046afb48 authored by Jordi Inglada's avatar Jordi Inglada

ENH: use the class dictionary in Shark RF and store it in the serialised model

parent f96785af
......@@ -157,6 +157,7 @@ private:
shark::RFClassifier<unsigned int> m_RFModel;
shark::RFTrainer<unsigned int> m_RFTrainer;
std::vector<unsigned int> m_ClassDictionary;
unsigned int m_NumberOfTrees;
unsigned int m_MTry;
......
......@@ -75,6 +75,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
Shark::ListSampleToSharkVector(this->GetInputListSample(), features);
Shark::ListSampleToSharkVector(this->GetTargetListSample(), class_labels);
Shark::NormalizeLabelsAndGetDictionary(class_labels, m_ClassDictionary);
shark::ClassificationDataset TrainSamples = shark::createLabeledDataFromRange(features,class_labels);
//Set parameters
......@@ -130,7 +131,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
unsigned int res{0};
m_RFModel.eval(samples, res);
TargetSampleType target;
target[0] = static_cast<TOutputValue>(res);
target[0] = m_ClassDictionary[static_cast<TOutputValue>(res)];
return target;
}
......@@ -154,9 +155,9 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
Shark::ListSampleRangeToSharkVector(input, features,startIndex,size);
shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange(features);
#ifdef _OPENMP
#ifdef _OPENMP
omp_set_num_threads(itk::MultiThreader::GetGlobalDefaultNumberOfThreads());
#endif
#endif
if(quality != ITK_NULLPTR)
{
......@@ -177,7 +178,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
for(const auto& p : prediction.elements())
{
TargetSampleType target;
target[0] = static_cast<TOutputValue>(p);
target[0] = m_ClassDictionary[static_cast<TOutputValue>(p)];
targets->SetMeasurementVector(id,target);
++id;
}
......@@ -195,6 +196,12 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
}
// Add comment with model file name
ofs << "#" << m_RFModel.name() << std::endl;
ofs << m_ClassDictionary.size() << " ";
for(const auto& l : m_ClassDictionary)
{
ofs << l << " ";
}
ofs << std::endl;
shark::TextOutArchive oa(ofs);
m_RFModel.save(oa,0);
}
......@@ -221,6 +228,15 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
ifs.clear();
ifs.seekg( 0, std::ios::beg );
}
size_t nbLabels{0};
ifs >> nbLabels;
m_ClassDictionary.resize(nbLabels);
for(size_t i=0; i<nbLabels; ++i)
{
unsigned int label;
ifs >> label;
m_ClassDictionary[i]=label;
}
shark::TextInArchive ia( ifs );
m_RFModel.load( ia, 0 );
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment