Skip to content
Snippets Groups Projects
Commit 34e4b411 authored by Julien Michel's avatar Julien Michel
Browse files

ENH: Conform shark random forest model to the new internal API of MachineLearningModel

parent 05764a37
Branches
Tags
No related merge requests found
...@@ -20,7 +20,18 @@ ...@@ -20,7 +20,18 @@
#include "itkLightObject.h" #include "itkLightObject.h"
#include "otbMachineLearningModel.h" #include "otbMachineLearningModel.h"
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wshadow"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Woverloaded-virtual"
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
#include "shark/Algorithms/Trainers/RFTrainer.h"
#pragma GCC diagnostic pop
#else
#include "shark/Algorithms/Trainers/RFTrainer.h" #include "shark/Algorithms/Trainers/RFTrainer.h"
#endif
namespace otb namespace otb
{ {
...@@ -42,8 +53,8 @@ public: ...@@ -42,8 +53,8 @@ public:
typedef typename Superclass::TargetSampleType TargetSampleType; typedef typename Superclass::TargetSampleType TargetSampleType;
typedef typename Superclass::TargetListSampleType TargetListSampleType; typedef typename Superclass::TargetListSampleType TargetListSampleType;
typedef typename Superclass::ConfidenceValueType ConfidenceValueType; typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
typedef itk::FixedArray<ConfidenceValueType,1> ConfidenceSampleType; typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType;
typedef itk::Statistics::ListSample<ConfidenceSampleType> ConfidenceListSampleType; typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
/** Run-time type information (and related methods). */ /** Run-time type information (and related methods). */
itkNewMacro(Self); itkNewMacro(Self);
...@@ -51,8 +62,6 @@ public: ...@@ -51,8 +62,6 @@ public:
/** Train the machine learning model */ /** Train the machine learning model */
virtual void Train(); virtual void Train();
/** Predict values using the model */
virtual TargetSampleType Predict(const InputSampleType& input, ConfidenceValueType *quality=NULL) const;
/** Save the model to file */ /** Save the model to file */
virtual void Save(const std::string & filename, const std::string & name=""); virtual void Save(const std::string & filename, const std::string & name="");
...@@ -60,8 +69,6 @@ public: ...@@ -60,8 +69,6 @@ public:
/** Load the model from file */ /** Load the model from file */
virtual void Load(const std::string & filename, const std::string & name=""); virtual void Load(const std::string & filename, const std::string & name="");
/** Classify all samples in InputListSample and fill TargetListSample with the associated label */
virtual void PredictAll() override;
/**\name Classification model file compatibility tests */ /**\name Classification model file compatibility tests */
//@{ //@{
/** Is the input model file readable and compatible with the corresponding classifier ? */ /** Is the input model file readable and compatible with the corresponding classifier ? */
...@@ -71,13 +78,6 @@ public: ...@@ -71,13 +78,6 @@ public:
virtual bool CanWriteFile(const std::string &); virtual bool CanWriteFile(const std::string &);
//@} //@}
/**\name Confidence accessors for batch mode */
//@{
/** Set the confidence samples (to be used before PredictAll) */
itkSetObjectMacro(ConfidenceListSample,ConfidenceListSampleType);
/** Get the confidence values (to be used after PredictAll) */
itkGetObjectMacro(ConfidenceListSample,ConfidenceListSampleType);
//@}
itkGetMacro(NumberOfTrees,unsigned int); itkGetMacro(NumberOfTrees,unsigned int);
itkSetMacro(NumberOfTrees,unsigned int); itkSetMacro(NumberOfTrees,unsigned int);
...@@ -94,10 +94,6 @@ public: ...@@ -94,10 +94,6 @@ public:
itkGetMacro(ComputeMargin, bool); itkGetMacro(ComputeMargin, bool);
itkSetMacro(ComputeMargin, bool); itkSetMacro(ComputeMargin, bool);
itkGetMacro(ConfidenceBatchMode, bool);
itkSetMacro(ConfidenceBatchMode, bool);
protected: protected:
/** Constructor */ /** Constructor */
SharkRandomForestsMachineLearningModel(); SharkRandomForestsMachineLearningModel();
...@@ -105,6 +101,12 @@ protected: ...@@ -105,6 +101,12 @@ protected:
/** Destructor */ /** Destructor */
virtual ~SharkRandomForestsMachineLearningModel(); virtual ~SharkRandomForestsMachineLearningModel();
/** Predict values using the model */
virtual TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=NULL) const ITK_OVERRIDE;
virtual void DoPredictBatch(const InputListSampleType *, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType *, ConfidenceListSampleType * = ITK_NULLPTR) const ITK_OVERRIDE;
/** PrintSelf method */ /** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const; void PrintSelf(std::ostream& os, itk::Indent indent) const;
...@@ -120,10 +122,8 @@ private: ...@@ -120,10 +122,8 @@ private:
unsigned int m_NodeSize; unsigned int m_NodeSize;
float m_OobRatio; float m_OobRatio;
bool m_ComputeMargin; bool m_ComputeMargin;
bool m_ConfidenceBatchMode;
/** Confidence list sample */ /** Confidence list sample */
typename ConfidenceListSampleType::Pointer m_ConfidenceListSample;
ConfidenceValueType ComputeConfidence(shark::RealVector probas, ConfidenceValueType ComputeConfidence(shark::RealVector probas,
bool computeMargin) const; bool computeMargin) const;
......
...@@ -21,7 +21,21 @@ ...@@ -21,7 +21,21 @@
#include <fstream> #include <fstream>
#include "itkMacro.h" #include "itkMacro.h"
#include "otbSharkRandomForestsMachineLearningModel.h" #include "otbSharkRandomForestsMachineLearningModel.h"
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wshadow"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Woverloaded-virtual"
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
#include <shark/Models/Converter.h>
#pragma GCC diagnostic pop
#else
#include <shark/Models/Converter.h> #include <shark/Models/Converter.h>
#endif
#include "otbSharkUtils.h" #include "otbSharkUtils.h"
#include <algorithm> #include <algorithm>
...@@ -34,8 +48,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> ...@@ -34,8 +48,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
{ {
this->m_ConfidenceIndex = true; this->m_ConfidenceIndex = true;
this->m_IsRegressionSupported = false; this->m_IsRegressionSupported = false;
this->m_ConfidenceBatchMode = false; this->m_IsDoPredictBatchMultiThreaded = true;
m_ConfidenceListSample = ConfidenceListSampleType::New();
} }
...@@ -93,7 +106,7 @@ template <class TInputValue, class TOutputValue> ...@@ -93,7 +106,7 @@ template <class TInputValue, class TOutputValue>
typename SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> typename SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType ::TargetSampleType
SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::Predict(const InputSampleType & value, ConfidenceValueType *quality) const ::DoPredict(const InputSampleType & value, ConfidenceValueType *quality) const
{ {
shark::RealVector samples; shark::RealVector samples;
for(size_t i = 0; i < value.Size();i++) for(size_t i = 0; i < value.Size();i++)
...@@ -117,38 +130,48 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> ...@@ -117,38 +130,48 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
template <class TInputValue, class TOutputValue> template <class TInputValue, class TOutputValue>
void void
SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue> SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::PredictAll() ::DoPredictBatch(const InputListSampleType *input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * targets, ConfidenceListSampleType * quality) const
{ {
// the samples to be predicted have to be set assert(input != ITK_NULLPTR);
assert(this->GetInputListSample() != ITK_NULLPTR); assert(targets != ITK_NULLPTR);
assert(input->Size()==targets->Size()&&"Input sample list and target label list do not have the same size.");
assert(((quality==ITK_NULLPTR)||(quality->Size()==input->Size()))&&"Quality samples list is not null and does not have the same size as input samples list");
if(startIndex+size>input->Size())
{
itkExceptionMacro(<<"requested range ["<<startIndex<<", "<<startIndex+size<<"[ partially outside input sample list range.[0,"<<input->Size()<<"[");
}
std::vector<shark::RealVector> features; std::vector<shark::RealVector> features;
Shark::ListSampleToSharkVector(this->GetInputListSample(), features); Shark::ListSampleRangeToSharkVector(input, features,startIndex,size);
shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange(features); shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange(features);
if(this->m_ConfidenceBatchMode)
auto probas = m_RFModel(inputSamples);
if(quality != ITK_NULLPTR)
{ {
//the confidence samples have to exist unsigned int id = startIndex;
assert(this->GetConfidenceListSample() != ITK_NULLPTR);
auto probas = m_RFModel(inputSamples);
ConfidenceListSampleType * confidences = this->GetConfidenceListSample();
confidences->Clear();
for(const auto& p : probas.elements()) for(const auto& p : probas.elements())
{ {
ConfidenceSampleType confidence; ConfidenceSampleType confidence;
auto conf = ComputeConfidence(p, m_ComputeMargin); auto conf = ComputeConfidence(p, m_ComputeMargin);
confidence[0] = static_cast<ConfidenceValueType>(conf); confidence[0] = static_cast<ConfidenceValueType>(conf);
confidences->PushBack(confidence); quality->SetMeasurementVector(id,confidence);
++id;
} }
} }
shark::ArgMaxConverter<shark::RFClassifier> amc; shark::ArgMaxConverter<shark::RFClassifier> amc;
amc.decisionFunction() = m_RFModel; amc.decisionFunction() = m_RFModel;
auto prediction = amc(inputSamples); auto prediction = amc(inputSamples);
TargetListSampleType * targets = this->GetTargetListSample(); unsigned int id = startIndex;
targets->Clear();
for(const auto& p : prediction.elements()) for(const auto& p : prediction.elements())
{ {
TargetSampleType target; TargetSampleType target;
target[0] = static_cast<TOutputValue>(p); target[0] = static_cast<TOutputValue>(p);
targets->PushBack(target); targets->SetMeasurementVector(id,target);
} }
} }
......
...@@ -19,33 +19,43 @@ ...@@ -19,33 +19,43 @@
#define __SharkUtils_h #define __SharkUtils_h
//#include <shark/Algorithms/Trainers/RFTrainer.h> //#include <shark/Algorithms/Trainers/RFTrainer.h>
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wshadow"
#include <shark/Data/Dataset.h>
#pragma GCC diagnostic pop
#else
#include <shark/Data/Dataset.h> #include <shark/Data/Dataset.h>
#endif
namespace otb namespace otb
{ {
namespace Shark namespace Shark
{ {
template <class T> void ListSampleToSharkVector(const T * listSample, std::vector<shark::RealVector> & output) template <class T> void ListSampleRangeToSharkVector(const T * listSample, std::vector<shark::RealVector> & output, const unsigned int & start, const unsigned int& size)
{ {
assert(listSample != ITK_NULLPTR); assert(listSample != ITK_NULLPTR);
assert(start+size<=listSample->Size());
// Sample index // Sample index
unsigned int sampleIdx = 0; unsigned int sampleIdx = start;
//Check for valid listSample //Check for valid listSample
if(listSample->Size()>0) if(listSample->Size()>0)
{ {
// Retrieve samples count // Retrieve samples count
output.clear(); output.clear();
// Build an iterator
typename T::ConstIterator sampleIt = listSample->Begin();
// Retrieve samples size alike // Retrieve samples size alike
const unsigned int sampleSize = listSample->GetMeasurementVectorSize(); const unsigned int sampleSize = listSample->GetMeasurementVectorSize();
// Fill the output vector // Fill the output vector
for(;sampleIt != listSample->End();++sampleIt,++sampleIdx) while(sampleIdx<start+size)
{ {
// Retrieve sample // Retrieve sample
typename T::MeasurementVectorType sample = sampleIt.GetMeasurementVector(); typename T::MeasurementVectorType sample = listSample->GetMeasurementVector(sampleIdx);
// Define a shark::RealVector // Define a shark::RealVector
shark::RealVector rv(sampleSize); shark::RealVector rv(sampleSize);
...@@ -55,33 +65,52 @@ template <class T> void ListSampleToSharkVector(const T * listSample, std::vecto ...@@ -55,33 +65,52 @@ template <class T> void ListSampleToSharkVector(const T * listSample, std::vecto
rv[i] = sample[i]; rv[i] = sample[i];
} }
output.push_back(rv); output.push_back(rv);
++sampleIdx;
} }
} }
} }
template <class T> void ListSampleToSharkVector(const T * listSample, std::vector<unsigned int> & output) template <class T> void ListSampleRangeToSharkVector(const T * listSample, std::vector<unsigned int> & output, const unsigned int & start, const unsigned int & size)
{ {
assert(listSample != ITK_NULLPTR); assert(listSample != ITK_NULLPTR);
assert(start+size<=listSample->Size());
// Sample index // Sample index
unsigned int sampleIdx = 0; unsigned int sampleIdx = start;
//Check for valid listSample //Check for valid listSample
if(listSample->Size()>0) if(listSample->Size()>0)
{ {
// Retrieve samples count
output.clear(); output.clear();
// Build an iterator
typename T::ConstIterator sampleIt = listSample->Begin();
// Fill the output vector // Fill the output vector
for(;sampleIt != listSample->End();++sampleIt,++sampleIdx) while(sampleIdx<start+size)
{ {
// Retrieve sample // Retrieve sample
typename T::MeasurementVectorType sample = sampleIt.GetMeasurementVector(); typename T::MeasurementVectorType sample = listSample->GetMeasurementVector(sampleIdx);
// Define a shark::RealVector
output.push_back(sample[0]); output.push_back(sample[0]);
++sampleIdx;
} }
} }
}
template <class T> void ListSampleToSharkVector(const T * listSample, std::vector<shark::RealVector> & output)
{
assert(listSample != ITK_NULLPTR);
ListSampleRangeToSharkVector(listSample,output,0U,static_cast<unsigned int>(listSample->Size()));
} }
template <class T> void ListSampleToSharkVector(const T * listSample, std::vector<unsigned int> & output)
{
assert(listSample != ITK_NULLPTR);
ListSampleRangeToSharkVector(listSample,output,0, static_cast<unsigned int>(listSample->Size()));
}
} }
} }
#endif #endif
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment