Commit 88238340 authored by Jordi Inglada's avatar Jordi Inglada

ENH: confidence in batch mode implemented

parent 4d812b33
......@@ -94,6 +94,9 @@ public:
itkGetMacro(ComputeMargin, bool);
itkSetMacro(ComputeMargin, bool);
itkGetMacro(ConfidenceBatchMode, bool);
itkSetMacro(ConfidenceBatchMode, bool);
protected:
/** Constructor */
......@@ -117,11 +120,14 @@ private:
unsigned int m_NodeSize;
float m_OobRatio;
bool m_ComputeMargin;
bool m_ConfidenceBatchMode;
/** Confidence list sample */
typename ConfidenceListSampleType::Pointer m_ConfidenceListSample;
ConfidenceValueType ComputeConfidence(shark::RealVector probas,
bool computeMargin) const;
};
};
} // end namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
......
......@@ -33,7 +33,9 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::SharkRandomForestsMachineLearningModel()
{
this->m_ConfidenceIndex = true;
this->m_IsRegressionSupported = true;
this->m_IsRegressionSupported = false;
this->m_ConfidenceBatchMode = false;
m_ConfidenceListSample = ConfidenceListSampleType::New();
}
......@@ -65,6 +67,28 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
}
template <class TInputValue, class TOutputValue>
typename SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::ConfidenceValueType
SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::ComputeConfidence(shark::RealVector probas, bool computeMargin) const
{
ConfidenceValueType conf{0};
if(computeMargin)
{
std::nth_element(probas.begin(), probas.begin()+1,
probas.end(), std::greater<double>());
conf = static_cast<ConfidenceValueType>(probas[0]-probas[1]);
}
else
{
auto max_proba = *(std::max_element(probas.begin(),
probas.end()));
conf = static_cast<ConfidenceValueType>(max_proba);
}
return conf;
}
template <class TInputValue, class TOutputValue>
typename SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
......@@ -79,18 +103,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
if (quality != NULL)
{
auto probas = m_RFModel(samples);
if(m_ComputeMargin)
{
std::nth_element(probas.begin(), probas.begin()+1,
probas.end(), std::greater<double>());
(*quality) = static_cast<ConfidenceValueType>(probas[0]-probas[1]);
}
else
{
auto max_proba = *(std::max_element(probas.begin(),
probas.end()));
(*quality) = static_cast<ConfidenceValueType>(max_proba);
}
(*quality) = ComputeConfidence(probas, m_ComputeMargin);
}
shark::ArgMaxConverter<shark::RFClassifier> amc;
amc.decisionFunction() = m_RFModel;
......@@ -106,11 +119,22 @@ void
SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::PredictAll()
{
// TODO : compute confidence using batches needs to change the api of the model, so te confidences can be stored in a vector as the labels for the predict all
std::vector<shark::RealVector> features;
Shark::ListSampleToSharkVector(this->GetInputListSample(), features);
shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange(features);
if(this->m_ConfidenceBatchMode)
{
auto probas = m_RFModel(inputSamples);
ConfidenceListSampleType * confidences = this->GetConfidenceListSample();
confidences->Clear();
for(const auto& p : probas.elements())
{
ConfidenceSampleType confidence;
auto conf = ComputeConfidence(p, m_ComputeMargin);
confidence[0] = static_cast<ConfidenceValueType>(conf);
confidences->PushBack(confidence);
}
}
shark::ArgMaxConverter<shark::RFClassifier> amc;
amc.decisionFunction() = m_RFModel;
auto prediction = amc(inputSamples);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment