Skip to content
Snippets Groups Projects
Commit 73d6e308 authored by Jordi Inglada's avatar Jordi Inglada
Browse files

TEST: add test for shark relabelling function

parent 73f79bba
No related branches found
No related tags found
No related merge requests found
......@@ -28,7 +28,10 @@ otb_module(OTBLearningBase
OTBImageBase
OTBITK
TEST_DEPENDS
OPTIONAL_DEPENDS
OTBShark
TEST_DEPENDS
OTBTestKernel
OTBImageIO
......
......@@ -30,9 +30,12 @@ otbSEMClassifierNew.cxx
otbDecisionTreeNew.cxx
otbKMeansImageClassificationFilterNew.cxx
otbMachineLearningModelTemplates.cxx
otbSharkUtilsTests.cxx
)
if(OTB_USE_SHARK)
set(OTBLearningBaseTests ${OTBLearningBaseTests} otbSharkUtilsTests.cxx)
endif()
add_executable(otbLearningBaseTestDriver ${OTBLearningBaseTests})
target_link_libraries(otbLearningBaseTestDriver ${OTBLearningBase-Test_LIBRARIES})
otb_module_target_label(otbLearningBaseTestDriver)
......@@ -69,5 +72,7 @@ otb_add_test(NAME leTuDecisionTreeNew COMMAND otbLearningBaseTestDriver
otb_add_test(NAME leTuKMeansImageClassificationFilterNew COMMAND otbLearningBaseTestDriver
otbKMeansImageClassificationFilterNew)
otb_add_test(NAME leTuSharkNormalizeLabels COMMAND otbLearningBaseTestDriver
otbSharkNormalizeLabels)
if(OTB_USE_SHARK)
otb_add_test(NAME leTuSharkNormalizeLabels COMMAND otbLearningBaseTestDriver
otbSharkNormalizeLabels)
endif()
......@@ -29,5 +29,7 @@ void RegisterTests()
REGISTER_TEST(otbSEMClassifierNew);
REGISTER_TEST(otbDecisionTreeNew);
REGISTER_TEST(otbKMeansImageClassificationFilterNew);
#ifdef OTB_USE_SHARK
REGISTER_TEST(otbSharkNormalizeLabels);
#endif
}
......@@ -23,6 +23,7 @@
#include <stdexcept>
#include <string>
#include <unordered_map>
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
......@@ -127,6 +128,27 @@ template <class T> void ListSampleToSharkVector(const T * listSample, std::vecto
assert(listSample != nullptr);
ListSampleRangeToSharkVector(listSample,output,0, static_cast<unsigned int>(listSample->Size()));
}
/** Shark assumes that labels are 0 ... (nbClasses-1). This function modifies the labels contained in the input vector and returns a vector with size = nbClasses which allows the translation from the normalised labels to the new ones oldLabel = dictionary[newLabel].
*/
template <typename T> void NormalizeLabelsAndGetDictionary(std::vector<T>& labels,
std::vector<T>& dictionary)
{
std::unordered_map<T, T> dictMap;
T labelCount{0};
for(const auto& l : labels)
{
if(dictMap.find(l)==dictMap.end())
dictMap.insert({l, labelCount++});
}
dictionary.resize(labelCount);
for(auto& l : labels)
{
auto newLabel = dictMap[l];
dictionary[newLabel] = l;
l = newLabel;
}
}
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment