Commit 16b6132c authored by Victor Poughon's avatar Victor Poughon

Merge branch 'sort_shark_labels' into 'develop'

BUG: sort Shark labels before encoding so that we know the order for the probability image

See merge request !453
parents 67881ae4 8ae44713
Pipeline #967 passed with stage
in 25 minutes and 29 seconds
......@@ -25,8 +25,8 @@
int otbSharkNormalizeLabels(int itkNotUsed(argc), char* itkNotUsed(argv) [])
{
std::vector<unsigned int> inLabels = {2, 2, 3, 20, 1};
std::vector<unsigned int> expectedDictionary = {2, 3, 20, 1};
std::vector<unsigned int> expectedLabels = {0, 0, 1, 2, 3};
std::vector<unsigned int> expectedDictionary = {1, 2, 3, 20};
std::vector<unsigned int> expectedLabels = {1, 1, 2, 3, 0};
auto newLabels = inLabels;
std::vector<unsigned int> labelDict;
......
......@@ -53,6 +53,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
this->m_IsRegressionSupported = false;
this->m_IsDoPredictBatchMultiThreaded = true;
this->m_NormalizeClassLabels = true;
this->m_ComputeMargin = false;
}
......
......@@ -135,13 +135,18 @@ template <class T> void ListSampleToSharkVector(const T * listSample, std::vecto
}
/** Shark assumes that labels are 0 ... (nbClasses-1). This function modifies the labels contained in the input vector and returns a vector with size = nbClasses which allows the translation from the normalised labels to the new ones oldLabel = dictionary[newLabel].
When we want to generate the image containing the probability for each class, we need to ensure that the probabilities are in the correct order wrt the incoming labels. We therefore sort the labels before building the encoding.
*/
template <typename T> void NormalizeLabelsAndGetDictionary(std::vector<T>& labels,
std::vector<T>& dictionary)
{
std::vector<T> sorted_labels = labels;
std::sort(std::begin(sorted_labels), std::end(sorted_labels));
auto last = std::unique(std::begin(sorted_labels), std::end(sorted_labels));
sorted_labels.erase(last, std::end(sorted_labels));
std::unordered_map<T, T> dictMap;
T labelCount{0};
for(const auto& l : labels)
for(const auto& l : sorted_labels)
{
if(dictMap.find(l)==dictMap.end())
dictMap.insert({l, labelCount++});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment