Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
What's new
10
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Open sidebar
Main Repositories
otb
Commits
a11653ed
Commit
a11653ed
authored
Nov 19, 2017
by
Jordi Inglada
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
ENH: API change of Shark's RF
parent
14935049
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
13 deletions
+9
-13
Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h
...vised/include/otbSharkRandomForestsMachineLearningModel.h
+3
-2
Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx
...sed/include/otbSharkRandomForestsMachineLearningModel.txx
+6
-11
No files found.
Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h
View file @
a11653ed
...
...
@@ -34,6 +34,7 @@
#pragma GCC diagnostic ignored "-Wcast-align"
#pragma GCC diagnostic ignored "-Wunknown-pragmas"
#endif
#include <shark/Models/Classifier.h>
#include "otb_shark.h"
#include "shark/Algorithms/Trainers/RFTrainer.h"
#if defined(__GNUC__) || defined(__clang__)
...
...
@@ -154,8 +155,8 @@ private:
SharkRandomForestsMachineLearningModel
(
const
Self
&
);
//purposely not implemented
void
operator
=
(
const
Self
&
);
//purposely not implemented
shark
::
RFClassifier
m_RFModel
;
shark
::
RFTrainer
m_RFTrainer
;
shark
::
RFClassifier
<
unsigned
int
>
m_RFModel
;
shark
::
RFTrainer
<
unsigned
int
>
m_RFTrainer
;
unsigned
int
m_NumberOfTrees
;
unsigned
int
m_MTry
;
...
...
Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx
View file @
a11653ed
...
...
@@ -32,7 +32,6 @@
#pragma GCC diagnostic ignored "-Woverloaded-virtual"
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
#endif
#include <shark/Models/Converter.h>
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif
...
...
@@ -82,7 +81,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
m_RFTrainer.setMTry(m_MTry);
m_RFTrainer.setNTrees(m_NumberOfTrees);
m_RFTrainer.setNodeSize(m_NodeSize);
m_RFTrainer.setOOBratio(m_OobRatio);
//
m_RFTrainer.setOOBratio(m_OobRatio);
m_RFTrainer.train(m_RFModel, TrainSamples);
}
...
...
@@ -125,13 +124,11 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
}
if (quality != ITK_NULLPTR)
{
shark::RealVector probas = m_RFModel(samples);
shark::RealVector probas = m_RFModel
.decisionFunction()
(samples);
(*quality) = ComputeConfidence(probas, m_ComputeMargin);
}
shark::ArgMaxConverter<shark::RFClassifier> amc;
amc.decisionFunction() = m_RFModel;
unsigned int res;
amc.eval(samples, res);
unsigned int res{0};
m_RFModel.eval(samples, res);
TargetSampleType target;
target[0] = static_cast<TOutputValue>(res);
return target;
...
...
@@ -163,7 +160,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
if(quality != ITK_NULLPTR)
{
shark::Data<shark::RealVector> probas = m_RFModel(inputSamples);
shark::Data<shark::RealVector> probas = m_RFModel
.decisionFunction()
(inputSamples);
unsigned int id = startIndex;
for(shark::RealVector && p : probas.elements())
{
...
...
@@ -175,9 +172,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
}
}
shark::ArgMaxConverter<shark::RFClassifier> amc;
amc.decisionFunction() = m_RFModel;
auto prediction = amc(inputSamples);
auto prediction = m_RFModel(inputSamples);
unsigned int id = startIndex;
for(const auto& p : prediction.elements())
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment