Commit d97a5e7b authored by Julien Michel's avatar Julien Michel

ENH: Apply comments from RFC review

parent 42f65efa
......@@ -124,7 +124,7 @@ private:
bool m_ComputeMargin;
/** Confidence list sample */
ConfidenceValueType ComputeConfidence(shark::RealVector probas,
ConfidenceValueType ComputeConfidence(shark::RealVector & probas,
bool computeMargin) const;
};
......
......@@ -88,8 +88,11 @@ template <class TInputValue, class TOutputValue>
typename SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::ConfidenceValueType
SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
::ComputeConfidence(shark::RealVector probas, bool computeMargin) const
::ComputeConfidence(shark::RealVector & probas, bool computeMargin) const
{
assert(!probas.empty()&&"probas vector is empty");
assert((!computeMargin||probas.size()>1)&&"probas size should be at least 2 if computeMargin is true");
ConfidenceValueType conf{0};
if(computeMargin)
{
......@@ -119,7 +122,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
}
if (quality != ITK_NULLPTR)
{
auto probas = m_RFModel(samples);
shark::RealVector probas = m_RFModel(samples);
(*quality) = ComputeConfidence(probas, m_ComputeMargin);
}
shark::ArgMaxConverter<shark::RFClassifier> amc;
......@@ -155,12 +158,12 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
omp_set_num_threads(itk::MultiThreader::GetGlobalDefaultNumberOfThreads());
#endif
auto probas = m_RFModel(inputSamples);
shark::Data<shark::RealVector> probas = m_RFModel(inputSamples);
if(quality != ITK_NULLPTR)
{
unsigned int id = startIndex;
for(const auto& p : probas.elements())
for(shark::RealVector && p : probas.elements())
{
ConfidenceSampleType confidence;
auto conf = ComputeConfidence(p, m_ComputeMargin);
......
......@@ -33,21 +33,23 @@ namespace otb
{
namespace Shark
{
template <class T> void ListSampleRangeToSharkVector(const T * listSample, std::vector<shark::RealVector> & output, const unsigned int & start, const unsigned int& size)
template <class T> void ListSampleRangeToSharkVector(const T * listSample, std::vector<shark::RealVector> & output, unsigned int start, unsigned int size)
{
assert(listSample != ITK_NULLPTR);
assert(start+size<=listSample->Size());
if(start+size>listSample->Size())
{
itkGenericExceptionMacro(<<"Requested range ["<<start<<", "<<start+size<<"[ is out of bound for input list sample (range [0, "<<listSample->Size()<<"[");
}
output.clear();
// Sample index
unsigned int sampleIdx = start;
//Check for valid listSample
if(listSample->Size()>0)
{
// Retrieve samples count
output.clear();
// Retrieve samples size alike
const unsigned int sampleSize = listSample->GetMeasurementVectorSize();
......@@ -56,7 +58,7 @@ template <class T> void ListSampleRangeToSharkVector(const T * listSample, std::
for (auto const endOfRange = start+size ; sampleIdx < endOfRange ; ++sampleIdx)
{
// Retrieve sample
typename T::MeasurementVectorType sample = listSample->GetMeasurementVector(sampleIdx);
typename T::MeasurementVectorType const & sample = listSample->GetMeasurementVector(sampleIdx);
// Define a shark::RealVector
shark::RealVector rv(sampleSize);
......@@ -65,30 +67,34 @@ template <class T> void ListSampleRangeToSharkVector(const T * listSample, std::
{
rv[i] = sample[i];
}
output.push_back(rv);
using std::move;
output.emplace_back(move(rv));
}
}
}
template <class T> void ListSampleRangeToSharkVector(const T * listSample, std::vector<unsigned int> & output, const unsigned int & start, const unsigned int & size)
template <class T> void ListSampleRangeToSharkVector(const T * listSample, std::vector<unsigned int> & output, unsigned int start, unsigned int size)
{
assert(listSample != ITK_NULLPTR);
assert(start+size<=listSample->Size());
if(start+size>listSample->Size())
{
itkGenericExceptionMacro(<<"Requested range ["<<start<<", "<<start+size<<"[ is out of bound for input list sample (range [0, "<<listSample->Size()<<"[");
}
output.clear();
// Sample index
unsigned int sampleIdx = start;
//Check for valid listSample
if(listSample->Size()>0)
{
// Retrieve samples count
output.clear();
// Fill the output vector
while(sampleIdx<start+size)
{
// Retrieve sample
typename T::MeasurementVectorType sample = listSample->GetMeasurementVector(sampleIdx);
typename T::MeasurementVectorType const & sample = listSample->GetMeasurementVector(sampleIdx);
// Define a shark::RealVector
output.push_back(sample[0]);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment