Commit ed47e3a2 authored by Manuel Grizonnet's avatar Manuel Grizonnet

Merge branch 'develop' into 1552-contribute-md-does-not-mention-ccla-and-icla

parents de6ef391 cc1d161c
......@@ -277,6 +277,12 @@ macro(otb_module_test)
foreach(dep IN LISTS OTB_MODULE_${otb-module-test}_DEPENDS)
list(APPEND ${otb-module-test}_LIBRARIES "${${dep}_LIBRARIES}")
endforeach()
# make sure the test can link with optional libs
foreach(dep IN LISTS OTB_MODULE_${otb-module}_OPTIONAL_DEPENDS)
if (${dep}_ENABLED)
list(APPEND ${otb-module-test}_LIBRARIES "${${dep}_LIBRARIES}")
endif()
endforeach()
endmacro()
macro(otb_module_warnings_disable)
......
......@@ -58,6 +58,13 @@ then send a merge request.
Note that we also accept PRs on our [GitHub mirror](https://github.com/orfeotoolbox/OTB)
which we will manually merge.
Feature branches are tested on multiple platforms on the OTB test infrastructure (a.k.a the [Dashboard](https://dash.orfeo-toolbox.org/)). They appear in the FeatureBranches section.
Caveat: even if the Dashboard build on develop branch is broken, it is not
allowed to push fixes directly on develop. The developer trying to fix the
build should create a merge request and submit it for review. Direct push to
develop without review must be avoided.
### Commit message
On your feature branch, write a good [commit message](https://xkcd.com/1296/):
......@@ -93,7 +100,11 @@ OTB team.
* Merge requests **must receive at least 2 positives votes from core developers** (members of Main Repositories group in Gitlab with at least "Developer" level; this includes PSC members) before being merged
* The merger is responsible for checking that the branch is up-to-date with develop
* Merge requests can be merged by anyone (not just PSC or RM) with push access to develop
* Merge requests can be merged once the dashboard is proven green for this branch
* Merge requests can be merged once the dashboard is proven green for this branch.
This condition is mandatory unless reviewers and authors explicitely agree that
it can be skipped (for instance in case of documentation merges or compilation
fixes on develop). Branches of that sort can be identified with the ~patch label,
which tells the reviewer that the author would like to merge without dashboard testing.
Branches can be registered for dashboard testing by adding one line in `Config/feature_branches.txt` in [otb-devutils repository](https://gitlab.orfeo-toolbox.org/orfeotoolbox/otb-devutils.git).
......@@ -162,6 +173,7 @@ Regarding labels, we use the following set:
correspond to a Request for Comments that has turned into a development action
* ~bug: Bug, crash or unexpected behavior, reported by a user or a developer
* ~feature: Feature request expressed by an OTB user/developer
* ~patch: A small patch fixing build warnings, compilation errors, typos in logs or documentation
* ~"To Do": action is planned
* ~Doing: work in progress
* ~api ~app ~documentation ~monteverdi ~packaging ~qgis: optional context information
......@@ -211,7 +211,7 @@ latex_elements = {
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
('index_TOC', 'CookBook-@OTB_VERSION_MAJOR@.@OTB_VERSION_MINOR@.tex', u'OTB CookBook Documentation',
('index_TOC', 'CookBook-@OTB_VERSION_MAJOR@.@OTB_VERSION_MINOR@.@OTB_VERSION_PATCH@.tex', u'OTB CookBook Documentation',
u'OTB Team', 'manual'),
]
......
......@@ -291,12 +291,15 @@ protected:
itkExceptionMacro(<< "File : " << modelFileName << " couldn't be opened");
}
// get the end line with the centroids
// get the line with the centroids (starts with "2 ")
std::string line, centroidLine;
while(std::getline(infile,line))
{
if (!line.empty())
if (line.size() > 2 && line[0] == '2' && line[1] == ' ')
{
centroidLine = line;
break;
}
}
std::vector<std::string> centroidElm;
......
......@@ -51,7 +51,7 @@ public:
private:
SampleAugmentation() {}
void DoInit()
void DoInit() override
{
SetName("SampleAugmentation");
SetDescription("Generates synthetic samples from a sample data file.");
......@@ -145,7 +145,7 @@ private:
SetOfficialDocLink();
}
void DoUpdateParameters()
void DoUpdateParameters() override
{
if ( HasValue("in") )
{
......@@ -182,7 +182,7 @@ private:
}
}
void DoExecute()
void DoExecute() override
{
ogr::DataSource::Pointer vectors;
ogr::DataSource::Pointer output;
......
......@@ -124,12 +124,12 @@ protected:
private:
void DoInit() override
{
SetName("DimensionalityReduction");
SetName("ImageDimensionalityReduction");
SetDescription("Performs dimensionality reduction of the input image "
"according to a dimensionality reduction model file.");
// Documentation
SetDocName("DimensionalityReduction");
SetDocName("Image Dimensionality Reduction");
SetDocLongDescription("This application reduces the dimension of an input"
" image, based on a machine learning model file produced by"
" the TrainDimensionalityReduction application. Pixels of the "
......
......@@ -87,7 +87,7 @@ otb_test_application(NAME apTvRaOpticalCalibration_UnknownSensor
-acqui.sun.elev 62.7
-acqui.sun.azim 152.7
-acqui.view.elev 87.5
-acqui.view.azim -77.0
-acqui.view.azim 283
-acqui.solarilluminations ${INPUTDATA}/apTvRaOpticalCalibrationUnknownSensorSolarIllumations2.txt
-atmo.rsr ${INPUTDATA}/apTvRaOpticalCalibrationUnknownSensorRSR.txt
-atmo.pressure 1013.0
......
......@@ -59,8 +59,15 @@ public:
// Overwrite this to provide custom formatting of log entries
std::string BuildFormattedEntry(itk::Logger::PriorityLevelType, std::string const&) override;
/** Output logs about the RAM, caching and multi-threading settings */
void LogSetupInformation();
/** Return true if the LogSetupInformation has already been called*/
bool IsLogSetupInformationDone();
/** Set the flag m_LogSetupInfoDone to true */
void LogSetupInformationDone();
protected:
Logger();
virtual ~Logger() ITK_OVERRIDE;
......@@ -71,6 +78,8 @@ private:
static Pointer CreateInstance();
bool m_LogSetupInfoDone;
}; // class Logger
} // namespace otb
......
......@@ -38,9 +38,6 @@ Logger::Pointer Logger::CreateInstance()
defaultOutput->SetStream(std::cout);
instance->AddLogOutput(defaultOutput);
// Log setup information
instance->LogSetupInformation();
return instance;
}
......@@ -61,6 +58,8 @@ Logger::Logger()
this->SetTimeStampFormat(itk::LoggerBase::HUMANREADABLE);
this->SetHumanReadableFormat("%Y-%m-%d %H:%M:%S");
m_LogSetupInfoDone = false;
}
Logger::~Logger()
......@@ -69,22 +68,29 @@ Logger::~Logger()
void Logger::LogSetupInformation()
{
std::ostringstream oss;
oss<<"Default RAM limit for OTB is "<<otb::ConfigurationManager::GetMaxRAMHint()<<" MB"<<std::endl;
this->Info(oss.str());
oss.str("");
oss.clear();
oss<<"GDAL maximum cache size is "<<GDALGetCacheMax64()/(1024*1024)<<" MB"<<std::endl;
this->Info(oss.str());
oss.str("");
oss.clear();
oss<<"OTB will use at most "<<itk::MultiThreader::GetGlobalDefaultNumberOfThreads()<<" threads"<<std::endl;
this->Info(oss.str());
oss.str("");
oss.clear();
if (! IsLogSetupInformationDone())
{
std::ostringstream oss;
oss<<"Default RAM limit for OTB is "<<otb::ConfigurationManager::GetMaxRAMHint()<<" MB"<<std::endl;
this->Info(oss.str());
oss.str("");
oss.clear();
oss<<"GDAL maximum cache size is "<<GDALGetCacheMax64()/(1024*1024)<<" MB"<<std::endl;
this->Info(oss.str());
oss.str("");
oss.clear();
oss<<"OTB will use at most "<<itk::MultiThreader::GetGlobalDefaultNumberOfThreads()<<" threads"<<std::endl;
this->Info(oss.str());
oss.str("");
oss.clear();
// only switch the flag for the singleton, so that other instances can call
// LogSetupInformation() several times
Instance()->LogSetupInformationDone();
}
}
std::string Logger::BuildFormattedEntry(itk::Logger::PriorityLevelType level, std::string const & content)
......@@ -116,4 +122,14 @@ std::string Logger::BuildFormattedEntry(itk::Logger::PriorityLevelType level, st
return s.str();
}
bool Logger::IsLogSetupInformationDone()
{
return m_LogSetupInfoDone;
}
void Logger::LogSetupInformationDone()
{
m_LogSetupInfoDone = true;
}
} // namespace otb
......@@ -183,6 +183,8 @@ void
StreamingImageVirtualWriter<TInputImage>
::GenerateData(void)
{
otb::Logger::Instance()->LogSetupInformation();
/**
* Prepare all the outputs. This may deallocate previous bulk data.
*/
......
......@@ -156,7 +156,7 @@ StreamingManager<TImage>::EstimateOptimalNumberOfDivisions(itk::DataObject * inp
unsigned int optimalNumberOfDivisions =
otb::PipelineMemoryPrintCalculator::EstimateOptimalNumberOfStreamDivisions(pipelineMemoryPrint, availableRAMInBytes);
otbLogMacro(Info,<<"Estimated memory for full processing: "<<pipelineMemoryPrint * otb::PipelineMemoryPrintCalculator::ByteToMegabyte<<"MB (avail.: "<<availableRAMInBytes * otb::PipelineMemoryPrintCalculator::ByteToMegabyte<<" NB), optimal image partitioning: "<<optimalNumberOfDivisions<<" blocks");
otbLogMacro(Info,<<"Estimated memory for full processing: "<<pipelineMemoryPrint * otb::PipelineMemoryPrintCalculator::ByteToMegabyte<<"MB (avail.: "<<availableRAMInBytes * otb::PipelineMemoryPrintCalculator::ByteToMegabyte<<" MB), optimal image partitioning: "<<optimalNumberOfDivisions<<" blocks");
return optimalNumberOfDivisions;
}
......
......@@ -591,7 +591,7 @@ PersistentStreamingStatisticsVectorImageFilter<TInputImage, TPrecision>
{
for (unsigned int c = 0; c < threadSecondOrder.Cols(); ++c)
{
threadSecondOrder(r, c) += vectorValue[r] * vectorValue[c];
threadSecondOrder(r, c) += static_cast<PrecisionType>(vectorValue[r]) * static_cast<PrecisionType>(vectorValue[c]);
}
}
threadSecondOrderComponent += vectorValue.GetSquaredNorm();
......
......@@ -380,6 +380,17 @@ ImageFileReader<TOutputImage, ConvertPixelTraits>
spacing[i] = 1.0;
}
origin[i] = 0.5*spacing[i];
for (unsigned j = 0; j < TOutputImage::ImageDimension; ++j)
{
if (i == j)
{
direction[j][i] = 1.0;
}
else
{
direction[j][i] = 0.0;
}
}
}
}
......
......@@ -279,6 +279,8 @@ ImageFileWriter<TInputImage>
itkExceptionMacro(<< "No input to writer");
}
otb::Logger::Instance()->LogSetupInformation();
/** Parse streaming modes */
if(m_FilenameHelper->StreamingTypeIsSet())
{
......
......@@ -41,6 +41,8 @@ namespace otb
* When the user gives a number of lines per strip or a tile size, the value
* is interpreted on the first input to deduce the number of streams. This
* number of streams is then used to split the other inputs.
*
* \ingroup OTBImageIO
*/
class OTBImageIO_EXPORT MultiImageFileWriter: public itk::ProcessObject
{
......@@ -226,7 +228,11 @@ private:
bool m_IsObserving;
unsigned long m_ObserverID;
/** Internal base wrapper class to handle each ImageFileWriter */
/** \class SinkBase
* Internal base wrapper class to handle each ImageFileWriter
*
* \ingroup OTBImageIO
*/
class SinkBase
{
public:
......@@ -248,6 +254,8 @@ private:
/** \class Sink
* Wrapper class for each ImageFileWriter
*
* \ingroup OTBImageIO
*/
template <class TImage>
class Sink : public SinkBase
......
......@@ -28,7 +28,7 @@
#include <iostream>
#include "itkMultiThreader.h"
#include "itkMacro.h"
#include "otbMacro.h"
#include "otbOGRDriversInit.h"
#include "otbTestHelper.h"
......@@ -298,6 +298,7 @@ int main(int ac, char* av[])
}
else
{
otb::Logger::Instance()->LogSetupInformation();
MainFuncPointer f = j->second;
int result;
try
......
......@@ -33,8 +33,9 @@
#endif
#include "otb_shark.h"
#include <shark/Algorithms/StoppingCriteria/AbstractStoppingCriterion.h>
#include <shark/Models/FFNet.h>
#include <shark/Models/Autoencoder.h>
#include <shark/Models/LinearModel.h>
#include <shark/Models/ConcatenatedModel.h>
#include <shark/Models/NeuronLayers.h>
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif
......@@ -76,9 +77,9 @@ public:
typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType;
/// Neural network related typedefs
typedef shark::Autoencoder<NeuronType,shark::LinearNeuron> OutAutoencoderType;
typedef shark::Autoencoder<NeuronType,NeuronType> AutoencoderType;
typedef shark::FFNet<NeuronType,shark::LinearNeuron> NetworkType;
typedef shark::ConcatenatedModel<shark::RealVector> ModelType;
typedef shark::LinearModel<shark::RealVector,NeuronType> LayerType;
typedef shark::LinearModel<shark::RealVector, shark::LinearNeuron> OutLayerType;
itkNewMacro(Self);
itkTypeMacro(AutoencoderModel, DimensionalityReductionModel);
......@@ -127,18 +128,16 @@ public:
void Train() override;
template <class T, class Autoencoder>
template <class T>
void TrainOneLayer(
shark::AbstractStoppingCriterion<T> & criterion,
Autoencoder &,
unsigned int,
shark::Data<shark::RealVector> &,
std::ostream&);
template <class T, class Autoencoder>
template <class T>
void TrainOneSparseLayer(
shark::AbstractStoppingCriterion<T> & criterion,
Autoencoder &,
unsigned int,
shark::Data<shark::RealVector> &,
std::ostream&);
......@@ -166,7 +165,9 @@ protected:
private:
/** Internal Network */
NetworkType m_Net;
ModelType m_Encoder;
std::vector<LayerType> m_InLayers;
OutLayerType m_OutLayer;
itk::Array<unsigned int> m_NumberOfHiddenNeurons;
/** Training parameters */
unsigned int m_NumberOfIterations; // stop the training after a fixed number of iterations
......
......@@ -137,11 +137,11 @@ PCAModel<TInputValue>::Load(const std::string & filename, const std::string & /*
ifs.close();
if (this->m_Dimension ==0)
{
this->m_Dimension = m_Encoder.outputSize();
this->m_Dimension = m_Encoder.outputShape()[0];
}
auto eigenvectors = m_Encoder.matrix();
eigenvectors.resize(this->m_Dimension,m_Encoder.inputSize());
eigenvectors.resize(this->m_Dimension,m_Encoder.inputShape()[0]);
m_Encoder.setStructure(eigenvectors, m_Encoder.offset() );
}
......
......@@ -28,7 +28,11 @@ otb_module(OTBLearningBase
OTBImageBase
OTBITK
TEST_DEPENDS
OPTIONAL_DEPENDS
OTBShark
TEST_DEPENDS
OTBBoost
OTBTestKernel
OTBImageIO
......
......@@ -32,6 +32,10 @@ otbKMeansImageClassificationFilterNew.cxx
otbMachineLearningModelTemplates.cxx
)
if(OTB_USE_SHARK)
set(OTBLearningBaseTests ${OTBLearningBaseTests} otbSharkUtilsTests.cxx)
endif()
add_executable(otbLearningBaseTestDriver ${OTBLearningBaseTests})
target_link_libraries(otbLearningBaseTestDriver ${OTBLearningBase-Test_LIBRARIES})
otb_module_target_label(otbLearningBaseTestDriver)
......@@ -68,3 +72,7 @@ otb_add_test(NAME leTuDecisionTreeNew COMMAND otbLearningBaseTestDriver
otb_add_test(NAME leTuKMeansImageClassificationFilterNew COMMAND otbLearningBaseTestDriver
otbKMeansImageClassificationFilterNew)
if(OTB_USE_SHARK)
otb_add_test(NAME leTuSharkNormalizeLabels COMMAND otbLearningBaseTestDriver
otbSharkNormalizeLabels)
endif()
......@@ -29,4 +29,7 @@ void RegisterTests()
REGISTER_TEST(otbSEMClassifierNew);
REGISTER_TEST(otbDecisionTreeNew);
REGISTER_TEST(otbKMeansImageClassificationFilterNew);
#ifdef OTB_USE_SHARK
REGISTER_TEST(otbSharkNormalizeLabels);
#endif
}
/*
* Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
*
* This file is part of Orfeo Toolbox
*
* https://www.orfeo-toolbox.org/
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "itkMacro.h"
#include "otbSharkUtils.h"
int otbSharkNormalizeLabels(int itkNotUsed(argc), char* itkNotUsed(argv) [])
{
std::vector<unsigned int> inLabels = {2, 2, 3, 20, 1};
std::vector<unsigned int> expectedDictionary = {2, 3, 20, 1};
std::vector<unsigned int> expectedLabels = {0, 0, 1, 2, 3};
auto newLabels = inLabels;
std::vector<unsigned int> labelDict;
otb::Shark::NormalizeLabelsAndGetDictionary(newLabels, labelDict);
if(newLabels != expectedLabels)
{
std::cout << "Wrong new labels\n";
for(size_t i = 0; i<newLabels.size(); ++i)
std::cout << "Got " << newLabels[i] << " expected " << expectedLabels[i] << '\n';
return EXIT_FAILURE;
}
if(labelDict != expectedDictionary)
{
std::cout << "Wrong dictionary\n";
for(size_t i = 0; i<labelDict.size(); ++i)
std::cout << "Got " << labelDict[i] << " expected " << expectedDictionary[i] << '\n';
return EXIT_FAILURE;
}
return EXIT_SUCCESS;
}
......@@ -233,7 +233,7 @@ SamplingRateCalculator
std::string::size_type pos5 = line.find_first_not_of(" \t", parts[2].begin() - line.begin());
std::string::size_type pos6 = line.find_last_not_of(" \t", parts[2].end() - line.begin() -1);
std::string::size_type pos7 = line.find_first_not_of(" \t", parts[3].begin() - line.begin());
std::string::size_type pos8 = line.find_last_not_of(" \t", parts[3].end() - line.begin() -1);
std::string::size_type pos8 = line.find_last_not_of(" \t\r", parts[3].end() - line.begin() -1);
if (pos2 != std::string::npos && pos1 <= pos2 &&
pos4 != std::string::npos && pos3 <= pos4 &&
pos6 != std::string::npos && pos5 <= pos6 &&
......@@ -336,7 +336,7 @@ SamplingRateCalculator
std::string::size_type pos1 = line.find_first_not_of(" \t", parts[0].begin() - line.begin());
std::string::size_type pos2 = line.find_last_not_of(" \t", parts[0].end() - line.begin() -1);
std::string::size_type pos3 = line.find_first_not_of(" \t", parts[1].begin() - line.begin());
std::string::size_type pos4 = line.find_last_not_of(" \t", parts[1].end() - line.begin() -1);
std::string::size_type pos4 = line.find_last_not_of(" \t\r", parts[1].end() - line.begin() -1);
if (pos2 != std::string::npos && pos1 <= pos2 &&
pos4 != std::string::npos && pos3 <= pos4)
{
......
......@@ -33,7 +33,10 @@
#pragma GCC diagnostic ignored "-Wsign-compare"
#pragma GCC diagnostic ignored "-Wcast-align"
#pragma GCC diagnostic ignored "-Wunknown-pragmas"
#pragma GCC diagnostic ignored "-Wheader-guard"
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
#endif
#include <shark/Models/Classifier.h>
#include "otb_shark.h"
#include "shark/Algorithms/Trainers/RFTrainer.h"
#if defined(__GNUC__) || defined(__clang__)
......@@ -134,6 +137,10 @@ public:
/** If true, margin confidence value will be computed */
itkSetMacro(ComputeMargin, bool);
/** If true, class labels will be normalised in [0 ... nbClasses] */
itkGetMacro(NormalizeClassLabels, bool);
itkSetMacro(NormalizeClassLabels, bool);
protected:
/** Constructor */
SharkRandomForestsMachineLearningModel();
......@@ -154,8 +161,10 @@ private:
SharkRandomForestsMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
shark::RFClassifier m_RFModel;
shark::RFTrainer m_RFTrainer;
shark::RFClassifier<unsigned int> m_RFModel;
shark::RFTrainer<unsigned int> m_RFTrainer;
std::vector<unsigned int> m_ClassDictionary;
bool m_NormalizeClassLabels;
unsigned int m_NumberOfTrees;
unsigned int m_MTry;
......
......@@ -32,7 +32,6 @@
#pragma GCC diagnostic ignored "-Woverloaded-virtual"
#pragma GCC diagnostic ignored "-Wignored-qualifiers"
#endif
#include <shark/Models/Converter.h>
#if defined(__GNUC__) || defined(__clang__)
#pragma GCC diagnostic pop
#endif
......@@ -52,6 +51,7 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
this->m_ConfidenceIndex = true;
this->m_IsRegressionSupported = false;
this->m_IsDoPredictBatchMultiThreaded = true;
this->m_NormalizeClassLabels = true;
}
......@@ -76,13 +76,17 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
Shark::ListSampleToSharkVector(this->GetInputListSample(), features);
Shark::ListSampleToSharkVector(this->GetTargetListSample(), class_labels);
if(m_NormalizeClassLabels)
{
Shark::NormalizeLabelsAndGetDictionary(class_labels, m_ClassDictionary);
}
shark::ClassificationDataset TrainSamples = shark::createLabeledDataFromRange(features,class_labels);
//Set parameters
m_RFTrainer.setMTry(m_MTry);
m_RFTrainer.setNTrees(m_NumberOfTrees);
m_RFTrainer.setNodeSize(m_NodeSize);
m_RFTrainer.setOOBratio(m_OobRatio);
// m_RFTrainer.setOOBratio(m_OobRatio);
m_RFTrainer.train(m_RFModel, TrainSamples);
}
......@@ -125,15 +129,20 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
}
if (quality != ITK_NULLPTR)
{
shark::RealVector probas = m_RFModel(samples);
shark::RealVector probas = m_RFModel.decisionFunction()(samples);
(*quality) = ComputeConfidence(probas, m_ComputeMargin);
}
shark::ArgMaxConverter<shark::RFClassifier> amc;
amc.decisionFunction() = m_RFModel;
unsigned int res;
amc.eval(samples, res);
unsigned int res{0};
m_RFModel.eval(samples, res);
TargetSampleType target;
target[0] = static_cast<TOutputValue>(res);
if(m_NormalizeClassLabels)
{
target[0] = m_ClassDictionary[static_cast<TOutputValue>(res)];
}
else
{
target[0] = static_cast<TOutputValue>(res);
}
return target;
}
......@@ -157,13 +166,13 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
Shark::ListSampleRangeToSharkVector(input, features,startIndex,size);
shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange(features);
#ifdef _OPENMP
#ifdef _OPENMP
omp_set_num_threads(itk::MultiThreader::GetGlobalDefaultNumberOfThreads());
#endif
#endif
if(quality != ITK_NULLPTR)
{
shark::Data<shark::RealVector> probas = m_RFModel(inputSamples);
shark::Data<shark::RealVector> probas = m_RFModel.decisionFunction()(inputSamples);
unsigned int id = startIndex;
for(shark::RealVector && p : probas.elements())
{
......@@ -175,14 +184,19 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
}
}
shark::ArgMaxConverter<shark::RFClassifier> amc;
amc.decisionFunction() = m_RFModel;
auto prediction = amc(inputSamples);
auto prediction = m_RFModel(inputSamples);
unsigned int id = startIndex;
for(const auto& p : prediction.elements())
{
TargetSampleType target;
target[0] = static_cast<TOutputValue>(p);
if(m_NormalizeClassLabels)
{
target[0] = m_ClassDictionary[static_cast<TOutputValue>(p)];
}
else
{
target[0] = static_cast<TOutputValue>(p);
}
targets->SetMeasurementVector(id,target);
++id;
}
......@@ -199,7 +213,18 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
itkExceptionMacro(<< "Error opening " << filename.c_str() );
}
// Add comment with model file name
ofs << "#" << m_RFModel.name() << std::endl;
ofs << "#" << m_RFModel.name();
if(m_NormalizeClassLabels) ofs << " with_dictionary";
ofs << std::endl;
if(m_NormalizeClassLabels)
{
ofs << m_ClassDictionary.size() << " ";
for(const auto& l : m_ClassDictionary)
{
ofs << l << " ";
}
ofs << std::endl;
}
shark::TextOutArchive oa(ofs);
m_RFModel.save(oa,0);
}
......@@ -219,6 +244,10 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
{
if( line.find( m_RFModel.name() ) == std::string::npos )
itkExceptionMacro( "The model file : " + filename + " cannot be read." );
if( line.find( "with_dictionary" ) == std::string::npos )
{
m_NormalizeClassLabels=false;
}
}
else
{
......@@ -226,6 +255,18 @@ SharkRandomForestsMachineLearningModel<TInputValue,TOutputValue>
ifs.clear();
ifs.seekg( 0, std::ios::beg );
}
if(m_NormalizeClassLabels)
{
size_t nbLabels{0};
ifs >> nbLabels;
m_ClassDictionary.resize(nbLabels);
for(size_t i=0; i<nbLabels; ++i)
{