From 112206feb2d2d05fd4de276dad4c71f7ae5c80a0 Mon Sep 17 00:00:00 2001 From: Jordi Inglada <jordi.inglada@cesbio.cnes.fr> Date: Mon, 26 Mar 2018 09:20:05 +0200 Subject: [PATCH] ENH: update AE tests with new Shark API --- .../include/otbAutoencoderModel.txx | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/Modules/Learning/DimensionalityReductionLearning/include/otbAutoencoderModel.txx b/Modules/Learning/DimensionalityReductionLearning/include/otbAutoencoderModel.txx index 254143a501..e5a26e9ee3 100644 --- a/Modules/Learning/DimensionalityReductionLearning/include/otbAutoencoderModel.txx +++ b/Modules/Learning/DimensionalityReductionLearning/include/otbAutoencoderModel.txx @@ -44,7 +44,7 @@ #include <shark/Algorithms/StoppingCriteria/MaxIterations.h> //A simple stopping criterion that stops after a fixed number of iterations #include <shark/Algorithms/StoppingCriteria/TrainingProgress.h> //Stops when the algorithm seems to converge, Tracks the progress of the training error over a period of time -#include <shark/Algorithms/GradientDescent/SteepestDescent.h> +#include <shark/Algorithms/GradientDescent/Adam.h> #if defined(__GNUC__) || defined(__clang__) #pragma GCC diagnostic pop #endif @@ -167,12 +167,12 @@ AutoencoderModel<TInputValue,NeuronType> shark::LabeledData<shark::RealVector,shark::RealVector> trainSet(samples,samples);//labels identical to inputs shark::SquaredLoss<shark::RealVector> loss; //~ shark::ErrorFunction error(trainSet, &model, &loss); - shark::ErrorFunction error(trainSet, &net, &loss); + shark::ErrorFunction<> error(trainSet, &net, &loss); - shark::TwoNormRegularizer regularizer(error.numberOfVariables()); + shark::TwoNormRegularizer<> regularizer(error.numberOfVariables()); error.setRegularizer(m_Regularization[layer_index],®ularizer); - shark::IRpropPlusFull optimizer; + shark::Adam<> optimizer; error.init(); optimizer.init(error); @@ -221,11 +221,11 @@ void AutoencoderModel<TInputValue,NeuronType>::TrainOneSparseLayer( shark::SquaredLoss<shark::RealVector> loss; //~ shark::SparseAutoencoderError error(trainSet,&net, &loss, m_Rho[layer_index], m_Beta[layer_index]); // SparseAutoencoderError doesn't exist anymore, for now use a plain ErrorFunction - shark::ErrorFunction error(trainSet, &net, &loss); + shark::ErrorFunction<> error(trainSet, &net, &loss); - shark::TwoNormRegularizer regularizer(error.numberOfVariables()); + shark::TwoNormRegularizer<> regularizer(error.numberOfVariables()); error.setRegularizer(m_Regularization[layer_index],®ularizer); - shark::IRpropPlusFull optimizer; + shark::Adam<> optimizer; error.init(); optimizer.init(error); @@ -269,11 +269,11 @@ AutoencoderModel<TInputValue,NeuronType> shark::LabeledData<shark::RealVector,shark::RealVector> trainSet(samples,samples); shark::SquaredLoss<shark::RealVector> loss; - shark::ErrorFunction error(trainSet, &net, &loss); - shark::TwoNormRegularizer regularizer(error.numberOfVariables()); + shark::ErrorFunction<> error(trainSet, &net, &loss); + shark::TwoNormRegularizer<> regularizer(error.numberOfVariables()); error.setRegularizer(m_Regularization[0],®ularizer); - shark::IRpropPlusFull optimizer; + shark::Adam<> optimizer; error.init(); optimizer.init(error); otbMsgDevMacro(<<"Error before training : " << optimizer.solution().value); -- GitLab