From a11653ed84ae1524a6f93d472c3cc63c9b4bf870 Mon Sep 17 00:00:00 2001 From: Jordi Inglada Date: Sun, 19 Nov 2017 22:06:12 +0100 Subject: [PATCH] ENH: API change of Shark's RF --- .../otbSharkRandomForestsMachineLearningModel.h | 5 +++-- ...tbSharkRandomForestsMachineLearningModel.txx | 17 ++++++----------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h b/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h index 8d2c0c8123..484de95428 100644 --- a/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.h @@ -34,6 +34,7 @@ #pragma GCC diagnostic ignored "-Wcast-align" #pragma GCC diagnostic ignored "-Wunknown-pragmas" #endif +#include #include "otb_shark.h" #include "shark/Algorithms/Trainers/RFTrainer.h" #if defined(__GNUC__) || defined(__clang__) @@ -154,8 +155,8 @@ private: SharkRandomForestsMachineLearningModel(const Self &); //purposely not implemented void operator =(const Self&); //purposely not implemented - shark::RFClassifier m_RFModel; - shark::RFTrainer m_RFTrainer; + shark::RFClassifier m_RFModel; + shark::RFTrainer m_RFTrainer; unsigned int m_NumberOfTrees; unsigned int m_MTry; diff --git a/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx b/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx index 207f1abdd7..4d7d9e83e3 100644 --- a/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx +++ b/Modules/Learning/Supervised/include/otbSharkRandomForestsMachineLearningModel.txx @@ -32,7 +32,6 @@ #pragma GCC diagnostic ignored "-Woverloaded-virtual" #pragma GCC diagnostic ignored "-Wignored-qualifiers" #endif -#include #if defined(__GNUC__) || defined(__clang__) #pragma GCC diagnostic pop #endif @@ -82,7 +81,7 @@ SharkRandomForestsMachineLearningModel m_RFTrainer.setMTry(m_MTry); m_RFTrainer.setNTrees(m_NumberOfTrees); m_RFTrainer.setNodeSize(m_NodeSize); - m_RFTrainer.setOOBratio(m_OobRatio); + // m_RFTrainer.setOOBratio(m_OobRatio); m_RFTrainer.train(m_RFModel, TrainSamples); } @@ -125,13 +124,11 @@ SharkRandomForestsMachineLearningModel } if (quality != ITK_NULLPTR) { - shark::RealVector probas = m_RFModel(samples); + shark::RealVector probas = m_RFModel.decisionFunction()(samples); (*quality) = ComputeConfidence(probas, m_ComputeMargin); } - shark::ArgMaxConverter amc; - amc.decisionFunction() = m_RFModel; - unsigned int res; - amc.eval(samples, res); + unsigned int res{0}; + m_RFModel.eval(samples, res); TargetSampleType target; target[0] = static_cast(res); return target; @@ -163,7 +160,7 @@ SharkRandomForestsMachineLearningModel if(quality != ITK_NULLPTR) { - shark::Data probas = m_RFModel(inputSamples); + shark::Data probas = m_RFModel.decisionFunction()(inputSamples); unsigned int id = startIndex; for(shark::RealVector && p : probas.elements()) { @@ -175,9 +172,7 @@ SharkRandomForestsMachineLearningModel } } - shark::ArgMaxConverter amc; - amc.decisionFunction() = m_RFModel; - auto prediction = amc(inputSamples); + auto prediction = m_RFModel(inputSamples); unsigned int id = startIndex; for(const auto& p : prediction.elements()) { -- GitLab