Commit 81bed5ae authored by Victor Poughon's avatar Victor Poughon

Merge remote-tracking branch 'origin/revert_shark' into release-6.6

parents f3749c05 b68d86d8
......@@ -291,15 +291,12 @@ protected:
itkExceptionMacro(<< "File : " << modelFileName << " couldn't be opened");
}
// get the line with the centroids (starts with "2 ")
// get the end line with the centroids
std::string line, centroidLine;
while(std::getline(infile,line))
{
if (line.size() > 2 && line[0] == '2' && line[1] == ' ')
{
if (!line.empty())
centroidLine = line;
break;
}
}
std::vector<std::string> centroidElm;
......
......@@ -33,9 +33,8 @@
#endif
#include "otb_shark.h"
#include <shark/Algorithms/StoppingCriteria/AbstractStoppingCriterion.h>
#include <shark/Models/LinearModel.h>
#include <shark/Models/ConcatenatedModel.h>
#include <shark/Models/NeuronLayers.h>
#include <shark/Models/FFNet.h>
#include <shark/Models/Autoencoder.h>
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif
......@@ -77,9 +76,9 @@ public:
typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
/// Neural network related typedefs
typedef shark::ConcatenatedModel<shark::RealVector> ModelType;
typedef shark::LinearModel<shark::RealVector,NeuronType> LayerType;
typedef shark::LinearModel<shark::RealVector, shark::LinearNeuron> OutLayerType;
typedef shark::Autoencoder<NeuronType,shark::LinearNeuron> OutAutoencoderType;
typedef shark::Autoencoder<NeuronType,NeuronType> AutoencoderType;
typedef shark::FFNet<NeuronType,shark::LinearNeuron> NetworkType;
itkNewMacro(Self);
itkTypeMacro(AutoencoderModel, DimensionalityReductionModel);
......@@ -128,16 +127,18 @@ public:
void Train() override;
template <class T>
template <class T, class Autoencoder>
void TrainOneLayer(
shark::AbstractStoppingCriterion<T> & criterion,
Autoencoder &,
unsigned int,
shark::Data<shark::RealVector> &,
std::ostream&);
template <class T>
template <class T, class Autoencoder>
void TrainOneSparseLayer(
shark::AbstractStoppingCriterion<T> & criterion,
Autoencoder &,
unsigned int,
shark::Data<shark::RealVector> &,
std::ostream&);
......@@ -165,9 +166,7 @@ protected:
private:
/** Internal Network */
ModelType m_Encoder;
std::vector<LayerType> m_InLayers;
OutLayerType m_OutLayer;
NetworkType m_Net;
itk::Array<unsigned int> m_NumberOfHiddenNeurons;
/** Training parameters */
unsigned int m_NumberOfIterations; // stop the training after a fixed number of iterations
......
......@@ -137,11 +137,11 @@ PCAModel<TInputValue>::Load(const std::string & filename, const std::string & /*
ifs.close();
if (this->m_Dimension ==0)
{
this->m_Dimension = m_Encoder.outputShape()[0];
this->m_Dimension = m_Encoder.outputSize();
}
auto eigenvectors = m_Encoder.matrix();
eigenvectors.resize(this->m_Dimension,m_Encoder.inputShape()[0]);
eigenvectors.resize(this->m_Dimension,m_Encoder.inputSize());
m_Encoder.setStructure(eigenvectors, m_Encoder.offset() );
}
......
......@@ -28,11 +28,7 @@ otb_module(OTBLearningBase
OTBImageBase
OTBITK
OPTIONAL_DEPENDS
OTBShark
TEST_DEPENDS
OTBBoost
TEST_DEPENDS
OTBTestKernel
OTBImageIO
......
......@@ -32,10 +32,6 @@ otbKMeansImageClassificationFilterNew.cxx
otbMachineLearningModelTemplates.cxx
)
if(OTB_USE_SHARK)
set(OTBLearningBaseTests ${OTBLearningBaseTests} otbSharkUtilsTests.cxx)
endif()
add_executable(otbLearningBaseTestDriver ${OTBLearningBaseTests})
target_link_libraries(otbLearningBaseTestDriver ${OTBLearningBase-Test_LIBRARIES})
otb_module_target_label(otbLearningBaseTestDriver)
......@@ -72,7 +68,3 @@ otb_add_test(NAME leTuDecisionTreeNew COMMAND otbLearningBaseTestDriver
otb_add_test(NAME leTuKMeansImageClassificationFilterNew COMMAND otbLearningBaseTestDriver
otbKMeansImageClassificationFilterNew)
if(OTB_USE_SHARK)
otb_add_test(NAME leTuSharkNormalizeLabels COMMAND otbLearningBaseTestDriver
otbSharkNormalizeLabels)
endif()
......@@ -29,7 +29,4 @@ void RegisterTests()
REGISTER_TEST(otbSEMClassifierNew);
REGISTER_TEST(otbDecisionTreeNew);
REGISTER_TEST(otbKMeansImageClassificationFilterNew);
#ifdef OTB_USE_SHARK
REGISTER_TEST(otbSharkNormalizeLabels);
#endif
}
/*
* Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
*
* This file is part of Orfeo Toolbox
*
* https://www.orfeo-toolbox.org/
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "itkMacro.h"
#include "otbSharkUtils.h"
/**
 * Unit test for otb::Shark::NormalizeLabelsAndGetDictionary.
 *
 * Feeds a small label vector containing a repeated value to the function and
 * compares both outputs (the rewritten labels and the dictionary) with
 * hard-coded expectations.
 *
 * \return EXIT_SUCCESS when both vectors match, EXIT_FAILURE otherwise (after
 *         printing the differences to std::cout).
 */
int otbSharkNormalizeLabels(int itkNotUsed(argc), char* itkNotUsed(argv) [])
{
  const std::vector<unsigned int> inLabels = {2, 2, 3, 20, 1};
  const std::vector<unsigned int> expectedDictionary = {2, 3, 20, 1};
  const std::vector<unsigned int> expectedLabels = {0, 0, 1, 2, 3};

  // The function rewrites its first argument in place, so work on a copy.
  auto newLabels = inLabels;
  std::vector<unsigned int> labelDict;
  otb::Shark::NormalizeLabelsAndGetDictionary(newLabels, labelDict);

  // Prints a side-by-side report of two vectors that compared unequal.
  // Only the common prefix is indexed: the vectors may differ in length, and
  // indexing the shorter one with the longer one's size would read out of
  // bounds.
  auto reportMismatch = [](const char* header,
                           const std::vector<unsigned int>& got,
                           const std::vector<unsigned int>& expected)
    {
    std::cout << header;
    if(got.size() != expected.size())
      {
      std::cout << "Got " << got.size() << " elements, expected "
                << expected.size() << '\n';
      }
    const size_t common = got.size() < expected.size() ? got.size() : expected.size();
    for(size_t i = 0; i < common; ++i)
      {
      std::cout << "Got " << got[i] << " expected " << expected[i] << '\n';
      }
    };

  if(newLabels != expectedLabels)
    {
    reportMismatch("Wrong new labels\n", newLabels, expectedLabels);
    return EXIT_FAILURE;
    }

  if(labelDict != expectedDictionary)
    {
    reportMismatch("Wrong dictionary\n", labelDict, expectedDictionary);
    return EXIT_FAILURE;
    }

  return EXIT_SUCCESS;
}
......@@ -36,7 +36,6 @@
#pragma GCC diagnostic ignored "-Wheader-guard"
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
#endif
#include <shark/Models/Classifier.h>
#include "otb_shark.h"
#include "shark/Algorithms/Trainers/RFTrainer.h"
#if defined(__GNUC__) || defined(__clang__)
......@@ -137,10 +136,6 @@ public:
/** If true, margin confidence value will be computed */
itkSetMacro(ComputeMargin, bool);
/** If true, class labels will be normalised in [0 ... nbClasses] */
itkGetMacro(NormalizeClassLabels, bool);
itkSetMacro(NormalizeClassLabels, bool);
protected:
/** Constructor */
SharkRandomForestsMachineLearningModel();
......@@ -161,10 +156,8 @@ private:
SharkRandomForestsMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
shark::RFClassifier<unsigned int> m_RFModel;
shark::RFTrainer<unsigned int> m_RFTrainer;
std::vector<unsigned int> m_ClassDictionary;
bool m_NormalizeClassLabels;
shark::RFClassifier m_RFModel;
shark::RFTrainer m_RFTrainer;
unsigned int m_NumberOfTrees;
unsigned int m_MTry;
......
......@@ -32,6 +32,7 @@
#pragma GCC diagnostic ignored "-Woverloaded-virtual"
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
#endif
#include <shark/Models/Converter.h>
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif
......@@ -51,7 +52,6 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
this->m_ConfidenceIndex = true;
this->m_IsRegressionSupported = false;
this->m_IsDoPredictBatchMultiThreaded = true;
this->m_NormalizeClassLabels = true;
}
......@@ -76,17 +76,13 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
Shark::ListSampleToSharkVector(this->GetInputListSample(), features);
Shark::ListSampleToSharkVector(this->GetTargetListSample(), class_labels);
if(m_NormalizeClassLabels)
{
Shark::NormalizeLabelsAndGetDictionary(class_labels, m_ClassDictionary);
}
shark::ClassificationDataset TrainSamples = shark::createLabeledDataFromRange(features,class_labels);
//Set parameters
m_RFTrainer.setMTry(m_MTry);
m_RFTrainer.setNTrees(m_NumberOfTrees);
m_RFTrainer.setNodeSize(m_NodeSize);
// m_RFTrainer.setOOBratio(m_OobRatio);
m_RFTrainer.setOOBratio(m_OobRatio);
m_RFTrainer.train(m_RFModel, TrainSamples);
}
......@@ -129,20 +125,15 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
}
if (quality != ITK_NULLPTR)
{
shark::RealVector probas = m_RFModel.decisionFunction()(samples);
shark::RealVector probas = m_RFModel(samples);
(*quality) = ComputeConfidence(probas, m_ComputeMargin);
}
unsigned int res{0};
m_RFModel.eval(samples, res);
shark::ArgMaxConverter<shark::RFClassifier> amc;
amc.decisionFunction() = m_RFModel;
unsigned int res;
amc.eval(samples, res);
TargetSampleType target;
if(m_NormalizeClassLabels)
{
target[0] = m_ClassDictionary[static_cast<TOutputValue>(res)];
}
else
{
target[0] = static_cast<TOutputValue>(res);
}
target[0] = static_cast<TOutputValue>(res);
return target;
}
......@@ -166,13 +157,13 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
Shark::ListSampleRangeToSharkVector(input, features,startIndex,size);
shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange(features);
#ifdef _OPENMP
#ifdef _OPENMP
omp_set_num_threads(itk::MultiThreader::GetGlobalDefaultNumberOfThreads());
#endif
#endif
if(quality != ITK_NULLPTR)
{
shark::Data<shark::RealVector> probas = m_RFModel.decisionFunction()(inputSamples);
shark::Data<shark::RealVector> probas = m_RFModel(inputSamples);
unsigned int id = startIndex;
for(shark::RealVector && p : probas.elements())
{
......@@ -184,19 +175,14 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
}
}
auto prediction = m_RFModel(inputSamples);
shark::ArgMaxConverter<shark::RFClassifier> amc;
amc.decisionFunction() = m_RFModel;
auto prediction = amc(inputSamples);
unsigned int id = startIndex;
for(const auto& p : prediction.elements())
{
TargetSampleType target;
if(m_NormalizeClassLabels)
{
target[0] = m_ClassDictionary[static_cast<TOutputValue>(p)];
}
else
{
target[0] = static_cast<TOutputValue>(p);
}
target[0] = static_cast<TOutputValue>(p);
targets->SetMeasurementVector(id,target);
++id;
}
......@@ -213,18 +199,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
itkExceptionMacro(<< "Error opening " << filename.c_str() );
}
// Add comment with model file name
ofs << "#" << m_RFModel.name();
if(m_NormalizeClassLabels) ofs << " with_dictionary";
ofs << std::endl;
if(m_NormalizeClassLabels)
{
ofs << m_ClassDictionary.size() << " ";
for(const auto& l : m_ClassDictionary)
{
ofs << l << " ";
}
ofs << std::endl;
}
ofs << "#" << m_RFModel.name() << std::endl;
shark::TextOutArchive oa(ofs);
m_RFModel.save(oa,0);
}
......@@ -244,10 +219,6 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
{
if( line.find( m_RFModel.name() ) == std::string::npos )
itkExceptionMacro( "The model file : " + filename + " cannot be read." );
if( line.find( "with_dictionary" ) == std::string::npos )
{
m_NormalizeClassLabels=false;
}
}
else
{
......@@ -255,18 +226,6 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
ifs.clear();
ifs.seekg( 0, std::ios::beg );
}
if(m_NormalizeClassLabels)
{
size_t nbLabels{0};
ifs >> nbLabels;
m_ClassDictionary.resize(nbLabels);
for(size_t i=0; i<nbLabels; ++i)
{
unsigned int label;
ifs >> label;
m_ClassDictionary[i]=label;
}
}
shark::TextInArchive ia( ifs );
m_RFModel.load( ia, 0 );
}
......
......@@ -55,7 +55,6 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
m_Normalized( false ), m_K(2), m_MaximumNumberOfIterations( 10 )
{
// Default set HardClusteringModel
this->m_ConfidenceIndex = true;
m_ClusteringModel = boost::make_shared<ClusteringModelType>( &m_Centroids );
}
......@@ -175,7 +174,7 @@ SharkKMeansMachineLearningModel<TInputValue, TOutputValue>
// Change quality measurement only if SoftClustering or other clustering method is used.
if( quality != ITK_NULLPTR )
{
for( unsigned int qid = startIndex; qid < startIndex+size; ++qid )
for( unsigned int qid = startIndex; qid < size; ++qid )
{
quality->SetMeasurementVector( qid, static_cast<ConfidenceValueType>(1.) );
}
......
......@@ -23,7 +23,6 @@
#include <stdexcept>
#include <string>
#include <unordered_map>
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic push
......@@ -128,27 +127,6 @@ template <class T> void ListSampleToSharkVector(const T * listSample, std::vecto
assert(listSample != nullptr);
ListSampleRangeToSharkVector(listSample,output,0, static_cast<unsigned int>(listSample->Size()));
}
/** Shark assumes that labels are 0 ... (nbClasses-1). This function rewrites
 * the labels of the input vector in place so that they fall in that range,
 * numbering the distinct values in order of first appearance.
 *
 * \param labels     in/out: original labels on entry, normalised labels on exit.
 * \param dictionary out: previous contents are discarded; on exit its size is
 *                   the number of distinct labels and it maps a normalised
 *                   label back to the original one:
 *                   oldLabel = dictionary[newLabel].
 */
template <typename T> void NormalizeLabelsAndGetDictionary(std::vector<T>& labels,
                                                           std::vector<T>& dictionary)
{
  std::unordered_map<T, T> normalisedOf;
  T nextLabel{0};
  dictionary.clear();
  for(auto& label : labels)
    {
    // insert() leaves the map untouched when the key is already known and
    // returns an iterator to the existing entry in that case.
    const auto slot = normalisedOf.insert({label, nextLabel});
    if(slot.second)
      {
      // First time this value is seen: it becomes dictionary[nextLabel].
      dictionary.push_back(label);
      ++nextLabel;
      }
    label = slot.first->second;
    }
}
}
}
......
......@@ -30,8 +30,8 @@ ADD_SUPERBUILD_CMAKE_VAR(SHARK BOOST_LIBRARYDIR)
ExternalProject_Add(SHARK
PREFIX SHARK
URL "https://github.com/Shark-ML/Shark/archive/67990bcd2c4a90a27be97d933b3740931e9da141.zip"
URL_MD5 9ad7480a4f9832b63086b9a683566187
URL "https://github.com/Shark-ML/Shark/archive/v3.1.4.zip"
URL_MD5 149e7d2e458cbe65c6373c2e89876b3e
SOURCE_DIR ${SHARK_SB_SRC}
BINARY_DIR ${SHARK_SB_BUILD_DIR}
INSTALL_DIR ${SB_INSTALL_PREFIX}
......
diff -burN Shark.orig/CMakeLists.txt Shark/CMakeLists.txt
--- Shark.orig/CMakeLists.txt 2018-02-05 18:04:58.012612932 +0100
+++ Shark/CMakeLists.txt 2018-02-05 18:20:50.032233165 +0100
@@ -415,6 +415,9 @@
#####################################################################
# General Path settings
#####################################################################
+if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
+ add_definitions(-fext-numeric-literals)
+endif()
include_directories( ${shark_SOURCE_DIR}/include )
include_directories( ${shark_BINARY_DIR}/include )
add_subdirectory( include )
diff -burN Shark-349f29bd71c370e0f88f7fc9aa66fa5c4768fcb0.orig/CMakeLists.txt Shark-349f29bd71c370e0f88f7fc9aa66fa5c4768fcb0/CMakeLists.txt
--- Shark-349f29bd71c370e0f88f7fc9aa66fa5c4768fcb0.orig/CMakeLists.txt 2017-08-22 11:31:50.472052695 +0200
+++ Shark-349f29bd71c370e0f88f7fc9aa66fa5c4768fcb0/CMakeLists.txt 2017-08-22 11:32:36.448358789 +0200
@@ -141,10 +141,8 @@
find_package(
Boost 1.48.0 REQUIRED COMPONENTS
- system date_time filesystem
- program_options serialization thread
- unit_test_framework
-)
+ serialization
+ )
if(NOT Boost_FOUND)
message(FATAL_ERROR "Please make sure Boost 1.48.0 is installed on your system")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment