Skip to content
Snippets Groups Projects
Commit 16b6132c authored by Victor Poughon's avatar Victor Poughon
Browse files

Merge branch 'sort_shark_labels' into 'develop'

BUG: sort Shark labels before encoding so that we know the order for the probability image

See merge request orfeotoolbox/otb!453
parents 67881ae4 8ae44713
No related branches found
No related tags found
No related merge requests found
......@@ -25,8 +25,8 @@
int otbSharkNormalizeLabels(int itkNotUsed(argc), char* itkNotUsed(argv) [])
{
std::vector<unsigned int> inLabels = {2, 2, 3, 20, 1};
std::vector<unsigned int> expectedDictionary = {2, 3, 20, 1};
std::vector<unsigned int> expectedLabels = {0, 0, 1, 2, 3};
std::vector<unsigned int> expectedDictionary = {1, 2, 3, 20};
std::vector<unsigned int> expectedLabels = {1, 1, 2, 3, 0};
auto newLabels = inLabels;
std::vector<unsigned int> labelDict;
......
......@@ -53,6 +53,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
this->m_IsRegressionSupported = false;
this->m_IsDoPredictBatchMultiThreaded = true;
this->m_NormalizeClassLabels = true;
this->m_ComputeMargin = false;
}
......
......@@ -135,13 +135,18 @@ template <class T> void ListSampleToSharkVector(const T * listSample, std::vecto
}
/** Shark assumes that labels are 0 ... (nbClasses-1). This function modifies the labels contained in the input vector and returns a vector with size = nbClasses which allows the translation from the normalised labels to the new ones oldLabel = dictionary[newLabel].
When we want to generate the image containing the probability for each class, we need to ensure that the probabilities are in the correct order wrt the incoming labels. We therefore sort the labels before building the encoding.
*/
template <typename T> void NormalizeLabelsAndGetDictionary(std::vector<T>& labels,
std::vector<T>& dictionary)
{
std::vector<T> sorted_labels = labels;
std::sort(std::begin(sorted_labels), std::end(sorted_labels));
auto last = std::unique(std::begin(sorted_labels), std::end(sorted_labels));
sorted_labels.erase(last, std::end(sorted_labels));
std::unordered_map<T, T> dictMap;
T labelCount{0};
for(const auto& l : labels)
for(const auto& l : sorted_labels)
{
if(dictMap.find(l)==dictMap.end())
dictMap.insert({l, labelCount++});
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment