Commit 679ed258 authored by Jordi Inglada's avatar Jordi Inglada

ENH: compatibility for serialised models without class dictionnary

parent 046afb48
......@@ -135,6 +135,10 @@ public:
/** If true, margin confidence value will be computed */
itkSetMacro(ComputeMargin, bool);
/** If true, class labels will be normalised in [0 ... nbClasses] */
itkGetMacro(NormalizeClassLabels, bool);
itkSetMacro(NormalizeClassLabels, bool);
protected:
/** Constructor */
SharkRandomForestsMachineLearningModel();
......@@ -158,6 +162,7 @@ private:
shark::RFClassifier<unsigned int> m_RFModel;
shark::RFTrainer<unsigned int> m_RFTrainer;
std::vector<unsigned int> m_ClassDictionary;
bool m_NormalizeClassLabels;
unsigned int m_NumberOfTrees;
unsigned int m_MTry;
......
......@@ -51,6 +51,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
this->m_ConfidenceIndex = true;
this->m_IsRegressionSupported = false;
this->m_IsDoPredictBatchMultiThreaded = true;
this->m_NormalizeClassLabels = true;
}
......@@ -75,7 +76,10 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
Shark::ListSampleToSharkVector(this->GetInputListSample(), features);
Shark::ListSampleToSharkVector(this->GetTargetListSample(), class_labels);
Shark::NormalizeLabelsAndGetDictionary(class_labels, m_ClassDictionary);
if(m_NormalizeClassLabels)
{
Shark::NormalizeLabelsAndGetDictionary(class_labels, m_ClassDictionary);
}
shark::ClassificationDataset TrainSamples = shark::createLabeledDataFromRange(features,class_labels);
//Set parameters
......@@ -131,7 +135,14 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
unsigned int res{0};
m_RFModel.eval(samples, res);
TargetSampleType target;
target[0] = m_ClassDictionary[static_cast<TOutputValue>(res)];
if(m_NormalizeClassLabels)
{
target[0] = m_ClassDictionary[static_cast<TOutputValue>(res)];
}
else
{
target[0] = static_cast<TOutputValue>(res);
}
return target;
}
......@@ -178,7 +189,14 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
for(const auto& p : prediction.elements())
{
TargetSampleType target;
target[0] = m_ClassDictionary[static_cast<TOutputValue>(p)];
if(m_NormalizeClassLabels)
{
target[0] = m_ClassDictionary[static_cast<TOutputValue>(p)];
}
else
{
target[0] = static_cast<TOutputValue>(p);
}
targets->SetMeasurementVector(id,target);
++id;
}
......@@ -195,13 +213,18 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
itkExceptionMacro(<< "Error opening " << filename.c_str() );
}
// Add comment with model file name
ofs << "#" << m_RFModel.name() << std::endl;
ofs << m_ClassDictionary.size() << " ";
for(const auto& l : m_ClassDictionary)
ofs << "#" << m_RFModel.name();
if(m_NormalizeClassLabels) ofs << " with_dictionary";
ofs << std::endl;
if(m_NormalizeClassLabels)
{
ofs << l << " ";
ofs << m_ClassDictionary.size() << " ";
for(const auto& l : m_ClassDictionary)
{
ofs << l << " ";
}
ofs << std::endl;
}
ofs << std::endl;
shark::TextOutArchive oa(ofs);
m_RFModel.save(oa,0);
}
......@@ -221,6 +244,10 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
{
if( line.find( m_RFModel.name() ) == std::string::npos )
itkExceptionMacro( "The model file : " + filename + " cannot be read." );
if( line.find( "with_dictionary" ) == std::string::npos )
{
m_NormalizeClassLabels=false;
}
}
else
{
......@@ -228,14 +255,17 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
ifs.clear();
ifs.seekg( 0, std::ios::beg );
}
size_t nbLabels{0};
ifs >> nbLabels;
m_ClassDictionary.resize(nbLabels);
for(size_t i=0; i<nbLabels; ++i)
if(m_NormalizeClassLabels)
{
unsigned int label;
ifs >> label;
m_ClassDictionary[i]=label;
size_t nbLabels{0};
ifs >> nbLabels;
m_ClassDictionary.resize(nbLabels);
for(size_t i=0; i<nbLabels; ++i)
{
unsigned int label;
ifs >> label;
m_ClassDictionary[i]=label;
}
}
shark::TextInArchive ia( ifs );
m_RFModel.load( ia, 0 );
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment