Commit 112206fe authored by Jordi Inglada

ENH: update AE tests with new Shark API

parent 5a8c649e
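
This commit moves the autoencoder training code to the newer Shark API: ErrorFunction and TwoNormRegularizer gained template argument lists, the SteepestDescent header is swapped for Adam.h, and the full-batch IRprop+ optimizer (shark::IRpropPlusFull) is replaced by shark::Adam<>. A minimal sketch of the resulting training pattern, assuming a Shark 4-style API; TrainWithAdam, Network, regularizationWeight, and iterations are illustrative names, not identifiers from the commit:

    #include <shark/Data/Dataset.h>
    #include <shark/ObjectiveFunctions/ErrorFunction.h>
    #include <shark/ObjectiveFunctions/Loss/SquaredLoss.h>
    #include <shark/ObjectiveFunctions/Regularizer.h>
    #include <shark/Algorithms/GradientDescent/Adam.h>

    template <class Network>
    void TrainWithAdam(Network& net,
                       shark::LabeledData<shark::RealVector, shark::RealVector> const& trainSet,
                       double regularizationWeight, std::size_t iterations)
    {
      shark::SquaredLoss<shark::RealVector> loss;
      // The new API templates ErrorFunction on the search point type (RealVector by default).
      shark::ErrorFunction<> error(trainSet, &net, &loss);
      shark::TwoNormRegularizer<> regularizer(error.numberOfVariables());
      error.setRegularizer(regularizationWeight, &regularizer);

      shark::Adam<> optimizer;   // replaces shark::IRpropPlusFull
      error.init();
      optimizer.init(error);
      for (std::size_t i = 0; i != iterations; ++i)
        optimizer.step(error);   // one Adam update per call
      // Write the trained weights back into the network.
      net.setParameterVector(optimizer.solution().point);
    }

Note that IRprop+ adapts per-weight step sizes from the signs of full-batch gradients, while Adam uses first- and second-moment estimates of the gradient, so the iteration budget and effective step sizes may need retuning after this change.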
@@ -44,7 +44,7 @@
 #include <shark/Algorithms/StoppingCriteria/MaxIterations.h> //A simple stopping criterion that stops after a fixed number of iterations
 #include <shark/Algorithms/StoppingCriteria/TrainingProgress.h> //Stops when the algorithm seems to converge, Tracks the progress of the training error over a period of time
-#include <shark/Algorithms/GradientDescent/SteepestDescent.h>
+#include <shark/Algorithms/GradientDescent/Adam.h>
 #if defined(__GNUC__) || defined(__clang__)
 #pragma GCC diagnostic pop
 #endif
@@ -167,12 +167,12 @@ AutoencoderModel<TInputValue,NeuronType>
 shark::LabeledData<shark::RealVector,shark::RealVector> trainSet(samples,samples); //labels identical to inputs
 shark::SquaredLoss<shark::RealVector> loss;
 //~ shark::ErrorFunction error(trainSet, &model, &loss);
-shark::ErrorFunction error(trainSet, &net, &loss);
-shark::TwoNormRegularizer regularizer(error.numberOfVariables());
+shark::ErrorFunction<> error(trainSet, &net, &loss);
+shark::TwoNormRegularizer<> regularizer(error.numberOfVariables());
 error.setRegularizer(m_Regularization[layer_index],&regularizer);
-shark::IRpropPlusFull optimizer;
+shark::Adam<> optimizer;
 error.init();
 optimizer.init(error);
@@ -221,11 +221,11 @@ void AutoencoderModel<TInputValue,NeuronType>::TrainOneSparseLayer(
 shark::SquaredLoss<shark::RealVector> loss;
 //~ shark::SparseAutoencoderError error(trainSet,&net, &loss, m_Rho[layer_index], m_Beta[layer_index]);
 // SparseAutoencoderError doesn't exist anymore, for now use a plain ErrorFunction
-shark::ErrorFunction error(trainSet, &net, &loss);
-shark::TwoNormRegularizer regularizer(error.numberOfVariables());
+shark::ErrorFunction<> error(trainSet, &net, &loss);
+shark::TwoNormRegularizer<> regularizer(error.numberOfVariables());
 error.setRegularizer(m_Regularization[layer_index],&regularizer);
-shark::IRpropPlusFull optimizer;
+shark::Adam<> optimizer;
 error.init();
 optimizer.init(error);
@@ -269,11 +269,11 @@ AutoencoderModel<TInputValue,NeuronType>
 shark::LabeledData<shark::RealVector,shark::RealVector> trainSet(samples,samples);
 shark::SquaredLoss<shark::RealVector> loss;
-shark::ErrorFunction error(trainSet, &net, &loss);
-shark::TwoNormRegularizer regularizer(error.numberOfVariables());
+shark::ErrorFunction<> error(trainSet, &net, &loss);
+shark::TwoNormRegularizer<> regularizer(error.numberOfVariables());
 error.setRegularizer(m_Regularization[0],&regularizer);
-shark::IRpropPlusFull optimizer;
+shark::Adam<> optimizer;
 error.init();
 optimizer.init(error);
 otbMsgDevMacro(<<"Error before training : " << optimizer.solution().value);
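
The otbMsgDevMacro line works because optimizer.init(error) evaluates the objective at the starting point, so optimizer.solution().value already holds the error before any step. A sketch of how the subsequent loop could report the error after training, pairing the step calls with the MaxIterations stopping criterion whose header is included at the top of the file; the loop itself and m_NumberOfIterations are assumptions, not part of this commit:

    shark::MaxIterations<> criterion(m_NumberOfIterations);  // illustrative iteration budget
    do
    {
      optimizer.step(error);                                 // one Adam update
    }
    while (!criterion.stop(optimizer.solution()));
    otbMsgDevMacro(<< "Error after training : " << optimizer.solution().value);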
...