Commit a11653ed authored by Jordi Inglada's avatar Jordi Inglada

ENH: API change of Shark's RF

parent 14935049
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#pragma GCC diagnostic ignored "-Wcast-align" #pragma GCC diagnostic ignored "-Wcast-align"
#pragma GCC diagnostic ignored "-Wunknown-pragmas" #pragma GCC diagnostic ignored "-Wunknown-pragmas"
#endif #endif
#include <shark/Models/Classifier.h>
#include "otb_shark.h" #include "otb_shark.h"
#include "shark/Algorithms/Trainers/RFTrainer.h" #include "shark/Algorithms/Trainers/RFTrainer.h"
#if defined(__GNUC__) || defined(__clang__) #if defined(__GNUC__) || defined(__clang__)
...@@ -154,8 +155,8 @@ private: ...@@ -154,8 +155,8 @@ private:
SharkRandomForestsMachineLearningModel(const Self &); //purposely not implemented SharkRandomForestsMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented void operator =(const Self&); //purposely not implemented
shark::RFClassifier m_RFModel; shark::RFClassifier<unsigned int> m_RFModel;
shark::RFTrainer m_RFTrainer; shark::RFTrainer<unsigned int> m_RFTrainer;
unsigned int m_NumberOfTrees; unsigned int m_NumberOfTrees;
unsigned int m_MTry; unsigned int m_MTry;
......
...@@ -32,7 +32,6 @@ ...@@ -32,7 +32,6 @@
#pragma GCC diagnostic ignored "-Woverloaded-virtual" #pragma GCC diagnostic ignored "-Woverloaded-virtual"
#pragma GCC diagnostic ignored "-Wignored-qualifiers" #pragma GCC diagnostic ignored "-Wignored-qualifiers"
#endif #endif
#include <shark/Models/Converter.h>
#if defined(__GNUC__) || defined(__clang__) #if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
#endif #endif
...@@ -82,7 +81,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> ...@@ -82,7 +81,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
m_RFTrainer.setMTry(m_MTry); m_RFTrainer.setMTry(m_MTry);
m_RFTrainer.setNTrees(m_NumberOfTrees); m_RFTrainer.setNTrees(m_NumberOfTrees);
m_RFTrainer.setNodeSize(m_NodeSize); m_RFTrainer.setNodeSize(m_NodeSize);
m_RFTrainer.setOOBratio(m_OobRatio); // m_RFTrainer.setOOBratio(m_OobRatio);
m_RFTrainer.train(m_RFModel, TrainSamples); m_RFTrainer.train(m_RFModel, TrainSamples);
} }
...@@ -125,13 +124,11 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> ...@@ -125,13 +124,11 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
} }
if (quality != ITK_NULLPTR) if (quality != ITK_NULLPTR)
{ {
shark::RealVector probas = m_RFModel(samples); shark::RealVector probas = m_RFModel.decisionFunction()(samples);
(*quality) = ComputeConfidence(probas, m_ComputeMargin); (*quality) = ComputeConfidence(probas, m_ComputeMargin);
} }
shark::ArgMaxConverter<shark::RFClassifier> amc; unsigned int res{0};
amc.decisionFunction() = m_RFModel; m_RFModel.eval(samples, res);
unsigned int res;
amc.eval(samples, res);
TargetSampleType target; TargetSampleType target;
target[0] = static_cast<TOutputValue>(res); target[0] = static_cast<TOutputValue>(res);
return target; return target;
...@@ -163,7 +160,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> ...@@ -163,7 +160,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
if(quality != ITK_NULLPTR) if(quality != ITK_NULLPTR)
{ {
shark::Data<shark::RealVector> probas = m_RFModel(inputSamples); shark::Data<shark::RealVector> probas = m_RFModel.decisionFunction()(inputSamples);
unsigned int id = startIndex; unsigned int id = startIndex;
for(shark::RealVector && p : probas.elements()) for(shark::RealVector && p : probas.elements())
{ {
...@@ -175,9 +172,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> ...@@ -175,9 +172,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
} }
} }
shark::ArgMaxConverter<shark::RFClassifier> amc; auto prediction = m_RFModel(inputSamples);
amc.decisionFunction() = m_RFModel;
auto prediction = amc(inputSamples);
unsigned int id = startIndex; unsigned int id = startIndex;
for(const auto& p : prediction.elements()) for(const auto& p : prediction.elements())
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment