Skip to content
Snippets Groups Projects
Commit 96880ee0 authored by Julien Michel's avatar Julien Michel
Browse files

ENH: Make OpenMP sensitive to ITK_GLOBAL_DEFAULT_NUMBER_OF_THREADS

parent 1684e28a
Branches
Tags
No related merge requests found
...@@ -64,6 +64,10 @@ void ...@@ -64,6 +64,10 @@ void
SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::Train() ::Train()
{ {
#ifdef _OPENMP
omp_set_num_threads(itk::MultiThreader::GetGlobalDefaultNumberOfThreads());
#endif
std::vector<shark::RealVector> features; std::vector<shark::RealVector> features;
std::vector<unsigned int> class_labels; std::vector<unsigned int> class_labels;
...@@ -142,12 +146,15 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> ...@@ -142,12 +146,15 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
{ {
itkExceptionMacro(<<"requested range ["<<startIndex<<", "<<startIndex+size<<"[ partially outside input sample list range.[0,"<<input->Size()<<"["); itkExceptionMacro(<<"requested range ["<<startIndex<<", "<<startIndex+size<<"[ partially outside input sample list range.[0,"<<input->Size()<<"[");
} }
std::vector<shark::RealVector> features; std::vector<shark::RealVector> features;
Shark::ListSampleRangeToSharkVector(input, features,startIndex,size); Shark::ListSampleRangeToSharkVector(input, features,startIndex,size);
shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange(features); shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange(features);
#ifdef _OPENMP
omp_set_num_threads(itk::MultiThreader::GetGlobalDefaultNumberOfThreads());
#endif
auto probas = m_RFModel(inputSamples); auto probas = m_RFModel(inputSamples);
if(quality != ITK_NULLPTR) if(quality != ITK_NULLPTR)
...@@ -172,6 +179,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> ...@@ -172,6 +179,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
TargetSampleType target; TargetSampleType target;
target[0] = static_cast<TOutputValue>(p); target[0] = static_cast<TOutputValue>(p);
targets->SetMeasurementVector(id,target); targets->SetMeasurementVector(id,target);
++id;
} }
} }
......
...@@ -24,6 +24,8 @@ ...@@ -24,6 +24,8 @@
#include "otbMachineLearningModel.h" #include "otbMachineLearningModel.h"
#include "itkMultiThreader.h"
namespace otb namespace otb
{ {
...@@ -111,6 +113,8 @@ MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue> ...@@ -111,6 +113,8 @@ MachineLearningModel<TInputValue,TOutputValue,TConfidenceValue>
#pragma omp parallel shared(nb_threads,nb_batches) private(threadId) #pragma omp parallel shared(nb_threads,nb_batches) private(threadId)
{ {
// Get number of threads configured with ITK
omp_set_num_threads(itk::MultiThreader::GetGlobalDefaultNumberOfThreads());
nb_threads = omp_get_num_threads(); nb_threads = omp_get_num_threads();
threadId = omp_get_thread_num(); threadId = omp_get_thread_num();
nb_batches = std::min(nb_threads,(unsigned int)input->Size()); nb_batches = std::min(nb_threads,(unsigned int)input->Size());
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment