Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
What's new
7
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Open sidebar
Main Repositories
otb
Commits
046afb48
Commit
046afb48
authored
Mar 06, 2018
by
Jordi Inglada
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
ENH: use the class dictionary in Shark RF and store it in the serialised model
parent
f96785af
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
4 deletions
+21
-4
Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h
...vised/include/otbSharkRandomForestsMachineLearningModel.h
+1
-0
Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx
...sed/include/otbSharkRandomForestsMachineLearningModel.txx
+20
-4
No files found.
Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h
View file @
046afb48
...
@@ -157,6 +157,7 @@ private:
...
@@ -157,6 +157,7 @@ private:
shark
::
RFClassifier
<
unsigned
int
>
m_RFModel
;
shark
::
RFClassifier
<
unsigned
int
>
m_RFModel
;
shark
::
RFTrainer
<
unsigned
int
>
m_RFTrainer
;
shark
::
RFTrainer
<
unsigned
int
>
m_RFTrainer
;
std
::
vector
<
unsigned
int
>
m_ClassDictionary
;
unsigned
int
m_NumberOfTrees
;
unsigned
int
m_NumberOfTrees
;
unsigned
int
m_MTry
;
unsigned
int
m_MTry
;
...
...
Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx
View file @
046afb48
...
@@ -75,6 +75,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
...
@@ -75,6 +75,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
Shark::ListSampleToSharkVector(this->GetInputListSample(), features);
Shark::ListSampleToSharkVector(this->GetInputListSample(), features);
Shark::ListSampleToSharkVector(this->GetTargetListSample(), class_labels);
Shark::ListSampleToSharkVector(this->GetTargetListSample(), class_labels);
Shark::NormalizeLabelsAndGetDictionary(class_labels, m_ClassDictionary);
shark::ClassificationDataset TrainSamples = shark::createLabeledDataFromRange(features,class_labels);
shark::ClassificationDataset TrainSamples = shark::createLabeledDataFromRange(features,class_labels);
//Set parameters
//Set parameters
...
@@ -130,7 +131,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
...
@@ -130,7 +131,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
unsigned int res{0};
unsigned int res{0};
m_RFModel.eval(samples, res);
m_RFModel.eval(samples, res);
TargetSampleType target;
TargetSampleType target;
target[0] = static_cast<TOutputValue>(res);
target[0] =
m_ClassDictionary[
static_cast<TOutputValue>(res)
]
;
return target;
return target;
}
}
...
@@ -154,9 +155,9 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
...
@@ -154,9 +155,9 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
Shark::ListSampleRangeToSharkVector(input, features,startIndex,size);
Shark::ListSampleRangeToSharkVector(input, features,startIndex,size);
shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange(features);
shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange(features);
#ifdef _OPENMP
#ifdef _OPENMP
omp_set_num_threads(itk::MultiThreader::GetGlobalDefaultNumberOfThreads());
omp_set_num_threads(itk::MultiThreader::GetGlobalDefaultNumberOfThreads());
#endif
#endif
if(quality != ITK_NULLPTR)
if(quality != ITK_NULLPTR)
{
{
...
@@ -177,7 +178,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
...
@@ -177,7 +178,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
for(const auto& p : prediction.elements())
for(const auto& p : prediction.elements())
{
{
TargetSampleType target;
TargetSampleType target;
target[0] = static_cast<TOutputValue>(p);
target[0] =
m_ClassDictionary[
static_cast<TOutputValue>(p)
]
;
targets->SetMeasurementVector(id,target);
targets->SetMeasurementVector(id,target);
++id;
++id;
}
}
...
@@ -195,6 +196,12 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
...
@@ -195,6 +196,12 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
}
}
// Add comment with model file name
// Add comment with model file name
ofs << "#" << m_RFModel.name() << std::endl;
ofs << "#" << m_RFModel.name() << std::endl;
ofs << m_ClassDictionary.size() << " ";
for(const auto& l : m_ClassDictionary)
{
ofs << l << " ";
}
ofs << std::endl;
shark::TextOutArchive oa(ofs);
shark::TextOutArchive oa(ofs);
m_RFModel.save(oa,0);
m_RFModel.save(oa,0);
}
}
...
@@ -221,6 +228,15 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
...
@@ -221,6 +228,15 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
ifs.clear();
ifs.clear();
ifs.seekg( 0, std::ios::beg );
ifs.seekg( 0, std::ios::beg );
}
}
size_t nbLabels{0};
ifs >> nbLabels;
m_ClassDictionary.resize(nbLabels);
for(size_t i=0; i<nbLabels; ++i)
{
unsigned int label;
ifs >> label;
m_ClassDictionary[i]=label;
}
shark::TextInArchive ia( ifs );
shark::TextInArchive ia( ifs );
m_RFModel.load( ia, 0 );
m_RFModel.load( ia, 0 );
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment