Commit 34e4b411 authored by Julien Michel

ENH: Conform shark random forest model to the new internal API of MachineLearningModel

parent 05764a37
......@@ -20,7 +20,18 @@
#include "itkLightObject.h"
#include "otbMachineLearningModel.h"
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wshadow"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Woverloaded-virtual"
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
#include "shark/Algorithms/Trainers/RFTrainer.h"
#pragma GCC diagnostic pop
#else
#include "shark/Algorithms/Trainers/RFTrainer.h"
#endif
namespace otb
{
......@@ -42,8 +53,8 @@ public:
typedef typename Superclass::TargetSampleType TargetSampleType;
typedef typename Superclass::TargetListSampleType TargetListSampleType;
typedef typename Superclass::ConfidenceValueType ConfidenceValueType;
typedef itk::FixedArray<ConfidenceValueType,1> ConfidenceSampleType;
typedef itk::Statistics::ListSample<ConfidenceSampleType> ConfidenceListSampleType;
typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType;
typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
/** Run-time type information (and related methods). */
itkNewMacro(Self);
......@@ -51,8 +62,6 @@ public:
/** Train the machine learning model */
virtual void Train();
/** Predict values using the model */
virtual TargetSampleType Predict(const InputSampleType& input, ConfidenceValueType *quality=NULL) const;
/** Save the model to file */
virtual void Save(const std::string & filename, const std::string & name="");
......@@ -60,8 +69,6 @@ public:
/** Load the model from file */
virtual void Load(const std::string & filename, const std::string & name="");
/** Classify all samples in InputListSample and fill TargetListSample with the associated label */
virtual void PredictAll() override;
/**\name Classification model file compatibility tests */
//@{
/** Is the input model file readable and compatible with the corresponding classifier ? */
......@@ -71,13 +78,6 @@ public:
virtual bool CanWriteFile(const std::string &);
//@}
/**\name Confidence accessors for batch mode */
//@{
/** Set the confidence samples (to be used before PredictAll) */
itkSetObjectMacro(ConfidenceListSample,ConfidenceListSampleType);
/** Get the confidence values (to be used after PredictAll) */
itkGetObjectMacro(ConfidenceListSample,ConfidenceListSampleType);
//@}
itkGetMacro(NumberOfTrees,unsigned int);
itkSetMacro(NumberOfTrees,unsigned int);
......@@ -94,10 +94,6 @@ public:
itkGetMacro(ComputeMargin, bool);
itkSetMacro(ComputeMargin, bool);
itkGetMacro(ConfidenceBatchMode, bool);
itkSetMacro(ConfidenceBatchMode, bool);
protected:
/** Constructor */
SharkRandomForestsMachineLearningModel();
......@@ -105,6 +101,12 @@ protected:
/** Destructor */
virtual ~SharkRandomForestsMachineLearningModel();
/** Predict values using the model
 *  \param input   sample to classify
 *  \param quality optional output; presumably receives the confidence of the
 *                 prediction when non-null (definition only partly visible here — confirm)
 *  \return the predicted target sample
 */
virtual TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=NULL) const ITK_OVERRIDE;
/** Predict a contiguous range of samples in one call.
 *  Samples [startIndex, startIndex+size) of the input list are classified; the
 *  resulting labels are stored at the same indices of the target list, and,
 *  when the confidence list is non-null, the confidences are stored at the
 *  same indices of that list. The target list (and the confidence list, if
 *  given) must already have the same size as the input list. An exception is
 *  raised when the requested range exceeds the input list.
 */
virtual void DoPredictBatch(const InputListSampleType *, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType *, ConfidenceListSampleType * = ITK_NULLPTR) const ITK_OVERRIDE;
/** PrintSelf method */
void PrintSelf(std::ostream& os, itk::Indent indent) const;
......@@ -120,10 +122,8 @@ private:
unsigned int m_NodeSize;
float m_OobRatio;
bool m_ComputeMargin;
bool m_ConfidenceBatchMode;
/** Confidence list sample */
typename ConfidenceListSampleType::Pointer m_ConfidenceListSample;
ConfidenceValueType ComputeConfidence(shark::RealVector probas,
bool computeMargin) const;
......
......@@ -21,7 +21,21 @@
#include <fstream>
#include "itkMacro.h"
#include "otbSharkRandomForestsMachineLearningModel.h"
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wshadow"
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Woverloaded-virtual"
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
#include <shark/Models/Converter.h>
#pragma GCC diagnostic pop
#else
#include <shark/Models/Converter.h>
#endif
#include "otbSharkUtils.h"
#include <algorithm>
......@@ -34,8 +48,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
{
this->m_ConfidenceIndex = true;
this->m_IsRegressionSupported = false;
this->m_ConfidenceBatchMode = false;
m_ConfidenceListSample = ConfidenceListSampleType::New();
this->m_IsDoPredictBatchMultiThreaded = true;
}
......@@ -93,7 +106,7 @@ template <class TInputValue, class TOutputValue>
typename SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::Predict(const InputSampleType & value, ConfidenceValueType *quality) const
::DoPredict(const InputSampleType & value, ConfidenceValueType *quality) const
{
shark::RealVector samples;
for(size_t i = 0; i < value.Size();i++)
......@@ -117,38 +130,48 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
/**
 * Classify the samples [startIndex, startIndex+size) of \p input in one batch.
 *
 * \param input      list of samples to classify (must not be null)
 * \param startIndex index of the first sample to process
 * \param size       number of consecutive samples to process
 * \param targets    pre-sized list (same size as \p input) receiving the
 *                   predicted label of sample i at index i
 * \param quality    optional pre-sized list (same size as \p input) receiving
 *                   the confidence of sample i at index i; ignored when null
 *
 * Results are written in place with SetMeasurementVector() at the index of
 * the corresponding input sample, so a call only touches entries inside its
 * own [startIndex, startIndex+size) range.
 *
 * \throw itk::ExceptionObject if the requested range exceeds \p input.
 */
template <class TInputValue, class TOutputValue>
void
SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::DoPredictBatch(const InputListSampleType *input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * targets, ConfidenceListSampleType * quality) const
{
  assert(input != ITK_NULLPTR);
  assert(targets != ITK_NULLPTR);
  assert(input->Size()==targets->Size()&&"Input sample list and target label list do not have the same size.");
  assert(((quality==ITK_NULLPTR)||(quality->Size()==input->Size()))&&"Quality samples list is not null and does not have the same size as input samples list");

  if(startIndex+size>input->Size())
    {
    itkExceptionMacro(<<"requested range ["<<startIndex<<", "<<startIndex+size<<"[ partially outside input sample list range.[0,"<<input->Size()<<"[");
    }

  // Convert only the requested range to the Shark representation
  std::vector<shark::RealVector> features;
  Shark::ListSampleRangeToSharkVector(input, features,startIndex,size);
  shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange(features);

  if(quality != ITK_NULLPTR)
    {
    // Class probabilities are only needed to derive confidences, so the
    // forest is evaluated for them only when the caller asked for quality.
    auto probas = m_RFModel(inputSamples);
    unsigned int confidenceId = startIndex;
    for(const auto& p : probas.elements())
      {
      ConfidenceSampleType confidence;
      auto conf = ComputeConfidence(p, m_ComputeMargin);
      confidence[0] = static_cast<ConfidenceValueType>(conf);
      quality->SetMeasurementVector(confidenceId,confidence);
      ++confidenceId;
      }
    }

  // The predicted label is the class with the highest probability
  shark::ArgMaxConverter<shark::RFClassifier> amc;
  amc.decisionFunction() = m_RFModel;
  auto prediction = amc(inputSamples);

  unsigned int targetId = startIndex;
  for(const auto& p : prediction.elements())
    {
    TargetSampleType target;
    target[0] = static_cast<TOutputValue>(p);
    targets->SetMeasurementVector(targetId,target);
    // Advance to the next output slot; without this every label would
    // overwrite the entry at startIndex.
    ++targetId;
    }
}
......
......@@ -19,33 +19,43 @@
#define __SharkUtils_h
//#include <shark/Algorithms/Trainers/RFTrainer.h>
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wshadow"
#include <shark/Data/Dataset.h>
#pragma GCC diagnostic pop
#else
#include <shark/Data/Dataset.h>
#endif
namespace otb
{
namespace Shark
{
template <class T> void ListSampleToSharkVector(const T * listSample, std::vector<shark::RealVector> & output)
template <class T> void ListSampleRangeToSharkVector(const T * listSample, std::vector<shark::RealVector> & output, const unsigned int & start, const unsigned int& size)
{
assert(listSample != ITK_NULLPTR);
assert(start+size<=listSample->Size());
// Sample index
unsigned int sampleIdx = 0;
unsigned int sampleIdx = start;
//Check for valid listSample
if(listSample->Size()>0)
{
// Retrieve samples count
output.clear();
// Build an iterator
typename T::ConstIterator sampleIt = listSample->Begin();
// Retrieve samples size alike
const unsigned int sampleSize = listSample->GetMeasurementVectorSize();
// Fill the output vector
for(;sampleIt != listSample->End();++sampleIt,++sampleIdx)
while(sampleIdx<start+size)
{
// Retrieve sample
typename T::MeasurementVectorType sample = sampleIt.GetMeasurementVector();
typename T::MeasurementVectorType sample = listSample->GetMeasurementVector(sampleIdx);
// Define a shark::RealVector
shark::RealVector rv(sampleSize);
......@@ -55,33 +65,52 @@ template <class T> void ListSampleToSharkVector(const T * listSample, std::vecto
rv[i] = sample[i];
}
output.push_back(rv);
++sampleIdx;
}
}
}
/**
 * Copy the first component of the samples [start, start+size) of a list
 * sample into a vector of unsigned int (typically class labels).
 *
 * \param listSample source list sample (must not be null)
 * \param output     destination vector; any previous content is discarded
 * \param start      index of the first sample to copy
 * \param size       number of consecutive samples to copy
 *
 * The requested range must lie inside the list sample (checked by assert).
 */
template <class T> void ListSampleRangeToSharkVector(const T * listSample, std::vector<unsigned int> & output, const unsigned int & start, const unsigned int & size)
{
  assert(listSample != nullptr);
  assert(start+size<=listSample->Size());

  // Discard previous content unconditionally so that an empty list sample
  // yields an empty output instead of leaving stale values behind.
  output.clear();

  //Check for valid listSample
  if(listSample->Size()>0)
    {
    output.reserve(size);

    const unsigned int stop = start+size;
    for(unsigned int sampleIdx = start; sampleIdx<stop; ++sampleIdx)
      {
      // Bind by const reference: avoids copying the measurement vector when
      // the accessor returns a reference, and is still valid if it returns
      // by value.
      const typename T::MeasurementVectorType & sample = listSample->GetMeasurementVector(sampleIdx);
      output.push_back(sample[0]);
      }
    }
}
}
/**
 * Convert every sample of a list sample into a shark::RealVector.
 * Forwards to ListSampleRangeToSharkVector() over the range
 * [0, listSample->Size()).
 */
template <class T> void ListSampleToSharkVector(const T * listSample, std::vector<shark::RealVector> & output)
{
  assert(listSample != ITK_NULLPTR);

  // Whole-list conversion is the range conversion starting at the first sample
  const unsigned int firstSample = 0U;
  const unsigned int sampleCount = static_cast<unsigned int>(listSample->Size());
  ListSampleRangeToSharkVector(listSample, output, firstSample, sampleCount);
}
/**
 * Convert the first component of every sample of a list sample into an
 * unsigned int. Forwards to ListSampleRangeToSharkVector() over the range
 * [0, listSample->Size()).
 */
template <class T> void ListSampleToSharkVector(const T * listSample, std::vector<unsigned int> & output)
{
  assert(listSample != ITK_NULLPTR);

  // Whole-list conversion is the range conversion starting at the first sample
  const unsigned int firstSample = 0U;
  const unsigned int sampleCount = static_cast<unsigned int>(listSample->Size());
  ListSampleRangeToSharkVector(listSample, output, firstSample, sampleCount);
}
}
}
#endif
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment