Skip to content
Snippets Groups Projects
Commit a11653ed authored by Jordi Inglada's avatar Jordi Inglada
Browse files

ENH: API change of Shark's RF

parent 14935049
No related branches found
No related tags found
2 merge requests!31WIP: Update shark rf,!26Update Shark Random Forest implementation
......@@ -34,6 +34,7 @@
#pragma GCC diagnostic ignored "-Wcast-align"
#pragma GCC diagnostic ignored "-Wunknown-pragmas"
#endif
#include <shark/Models/Classifier.h>
#include "otb_shark.h"
#include "shark/Algorithms/Trainers/RFTrainer.h"
#if defined(__GNUC__) || defined(__clang__)
......@@ -154,8 +155,8 @@ private:
SharkRandomForestsMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
shark::RFClassifier m_RFModel;
shark::RFTrainer m_RFTrainer;
shark::RFClassifier<unsigned int> m_RFModel;
shark::RFTrainer<unsigned int> m_RFTrainer;
unsigned int m_NumberOfTrees;
unsigned int m_MTry;
......
......@@ -32,7 +32,6 @@
#pragma GCC diagnostic ignored "-Woverloaded-virtual"
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
#endif
#include <shark/Models/Converter.h>
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif
......@@ -82,7 +81,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
m_RFTrainer.setMTry(m_MTry);
m_RFTrainer.setNTrees(m_NumberOfTrees);
m_RFTrainer.setNodeSize(m_NodeSize);
m_RFTrainer.setOOBratio(m_OobRatio);
// m_RFTrainer.setOOBratio(m_OobRatio);
m_RFTrainer.train(m_RFModel, TrainSamples);
}
......@@ -125,13 +124,11 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
}
if (quality != ITK_NULLPTR)
{
shark::RealVector probas = m_RFModel(samples);
shark::RealVector probas = m_RFModel.decisionFunction()(samples);
(*quality) = ComputeConfidence(probas, m_ComputeMargin);
}
shark::ArgMaxConverter<shark::RFClassifier> amc;
amc.decisionFunction() = m_RFModel;
unsigned int res;
amc.eval(samples, res);
unsigned int res{0};
m_RFModel.eval(samples, res);
TargetSampleType target;
target[0] = static_cast<TOutputValue>(res);
return target;
......@@ -163,7 +160,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
if(quality != ITK_NULLPTR)
{
shark::Data<shark::RealVector> probas = m_RFModel(inputSamples);
shark::Data<shark::RealVector> probas = m_RFModel.decisionFunction()(inputSamples);
unsigned int id = startIndex;
for(shark::RealVector && p : probas.elements())
{
......@@ -175,9 +172,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
}
}
shark::ArgMaxConverter<shark::RFClassifier> amc;
amc.decisionFunction() = m_RFModel;
auto prediction = amc(inputSamples);
auto prediction = m_RFModel(inputSamples);
unsigned int id = startIndex;
for(const auto& p : prediction.elements())
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment