Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
What's new
10
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Open sidebar
Main Repositories
otb
Commits
046afb48
Commit
046afb48
authored
Mar 06, 2018
by
Jordi Inglada
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
ENH: use the class dictionary in Shark RF and store it in the serialised model
parent
f96785af
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
4 deletions
+21
-4
Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h
...vised/include/otbSharkRandomForestsMachineLearningModel.h
+1
-0
Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx
...sed/include/otbSharkRandomForestsMachineLearningModel.txx
+20
-4
No files found.
Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h
View file @
046afb48
...
...
@@ -157,6 +157,7 @@ private:
shark
::
RFClassifier
<
unsigned
int
>
m_RFModel
;
shark
::
RFTrainer
<
unsigned
int
>
m_RFTrainer
;
std
::
vector
<
unsigned
int
>
m_ClassDictionary
;
unsigned
int
m_NumberOfTrees
;
unsigned
int
m_MTry
;
...
...
Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx
View file @
046afb48
...
...
@@ -75,6 +75,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
Shark::ListSampleToSharkVector(this->GetInputListSample(), features);
Shark::ListSampleToSharkVector(this->GetTargetListSample(), class_labels);
Shark::NormalizeLabelsAndGetDictionary(class_labels, m_ClassDictionary);
shark::ClassificationDataset TrainSamples = shark::createLabeledDataFromRange(features,class_labels);
//Set parameters
...
...
@@ -130,7 +131,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
unsigned int res{0};
m_RFModel.eval(samples, res);
TargetSampleType target;
target[0] = static_cast<TOutputValue>(res);
target[0] =
m_ClassDictionary[
static_cast<TOutputValue>(res)
]
;
return target;
}
...
...
@@ -154,9 +155,9 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
Shark::ListSampleRangeToSharkVector(input, features,startIndex,size);
shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange(features);
#ifdef _OPENMP
#ifdef _OPENMP
omp_set_num_threads(itk::MultiThreader::GetGlobalDefaultNumberOfThreads());
#endif
#endif
if(quality != ITK_NULLPTR)
{
...
...
@@ -177,7 +178,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
for(const auto& p : prediction.elements())
{
TargetSampleType target;
target[0] = static_cast<TOutputValue>(p);
target[0] =
m_ClassDictionary[
static_cast<TOutputValue>(p)
]
;
targets->SetMeasurementVector(id,target);
++id;
}
...
...
@@ -195,6 +196,12 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
}
// Add comment with model file name
ofs << "#" << m_RFModel.name() << std::endl;
ofs << m_ClassDictionary.size() << " ";
for(const auto& l : m_ClassDictionary)
{
ofs << l << " ";
}
ofs << std::endl;
shark::TextOutArchive oa(ofs);
m_RFModel.save(oa,0);
}
...
...
@@ -221,6 +228,15 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
ifs.clear();
ifs.seekg( 0, std::ios::beg );
}
size_t nbLabels{0};
ifs >> nbLabels;
m_ClassDictionary.resize(nbLabels);
for(size_t i=0; i<nbLabels; ++i)
{
unsigned int label;
ifs >> label;
m_ClassDictionary[i]=label;
}
shark::TextInArchive ia( ifs );
m_RFModel.load( ia, 0 );
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment