Commit 7b78f3d7 authored by Cédric Traizet's avatar Cédric Traizet
Browse files

the weight initial value during training can now be set with a parameter

parent e9635127
......@@ -222,7 +222,7 @@ private:
}
// Rescale vector image
m_Rescaler->SetScale(stddevMeasurementVector);
m_Rescaler->SetScale(stddevMeasurementVector*3);
m_Rescaler->SetShift(meanMeasurementVector);
m_Rescaler->SetInput(inImage);
......
......@@ -124,7 +124,7 @@ private:
ShiftScaleFilterType::Pointer trainingShiftScaleFilter = ShiftScaleFilterType::New();
trainingShiftScaleFilter->SetInput(input);
trainingShiftScaleFilter->SetShifts(meanMeasurementVector);
trainingShiftScaleFilter->SetScales(stddevMeasurementVector);
trainingShiftScaleFilter->SetScales(stddevMeasurementVector*3);
trainingShiftScaleFilter->Update();
ListSampleType::Pointer trainingListSample= trainingShiftScaleFilter->GetOutput();
......
......@@ -223,7 +223,7 @@ class CbDimensionalityReductionVector : public Application
ShiftScaleFilterType::Pointer trainingShiftScaleFilter = ShiftScaleFilterType::New();
trainingShiftScaleFilter->SetInput(input);
trainingShiftScaleFilter->SetShifts(meanMeasurementVector);
trainingShiftScaleFilter->SetScales(stddevMeasurementVector);
trainingShiftScaleFilter->SetScales(stddevMeasurementVector*3);
trainingShiftScaleFilter->Update();
otbAppLogINFO("mean used: " << meanMeasurementVector);
otbAppLogINFO("standard deviation used: " << stddevMeasurementVector);
......
......@@ -54,6 +54,9 @@ public:
itkGetMacro(Epsilon,double);
itkSetMacro(Epsilon,double);
itkGetMacro(InitFactor,double);
itkSetMacro(InitFactor,double);
itkGetMacro(Regularization,itk::Array<double>);
itkSetMacro(Regularization,itk::Array<double>);
......@@ -113,6 +116,7 @@ private:
itk::Array<double> m_Noise; // probability for an input to be set to 0 (denosing autoencoder)
itk::Array<double> m_Rho; // Sparsity parameter
itk::Array<double> m_Beta; // Sparsity regularization parameter
double m_InitFactor; // Weight initialization factor (the weights are intialized at m_initfactor/sqrt(inputDimension) )
bool m_WriteLearningCurve; // Flag for writting the learning curve into a txt file
std::string m_LearningCurveFileName; // Name of the output learning curve printed after training
......
......@@ -151,7 +151,7 @@ void AutoencoderModel<TInputValue,NeuronType>::TrainOneLayer(shark::AbstractStop
std::size_t inputs = dataDimension(samples);
net.setStructure(inputs, nbneuron);
initRandomUniform(net,-0.1*std::sqrt(1.0/inputs),0.1*std::sqrt(1.0/inputs));
initRandomUniform(net,-m_InitFactor*std::sqrt(1.0/inputs),m_InitFactor*std::sqrt(1.0/inputs));
//initRandomUniform(net,-1,1);
shark::ImpulseNoiseModel noise(noise_strength,0.0); //set an input pixel with probability m_Noise to 0
shark::ConcatenatedModel<shark::RealVector,shark::RealVector> model = noise>> net;
......@@ -200,7 +200,7 @@ void AutoencoderModel<TInputValue,NeuronType>::TrainOneSparseLayer(shark::Abstra
std::size_t inputs = dataDimension(samples);
net.setStructure(inputs, nbneuron);
initRandomUniform(net,-0.1*std::sqrt(1.0/inputs),0.1*std::sqrt(1.0/inputs));
initRandomUniform(net,-m_InitFactor*std::sqrt(1.0/inputs),m_InitFactor*std::sqrt(1.0/inputs));
//initRandomUniform(net,-1,1);
shark::LabeledData<shark::RealVector,shark::RealVector> trainSet(samples,samples);//labels identical to inputs
shark::SquaredLoss<shark::RealVector> loss;
......
......@@ -52,6 +52,12 @@ cbLearningApplicationBaseDR<TInputValue,TOutputValue>
" ");
AddParameter(ParameterType_Float, "model.autoencoder.initfactor",
" ");
SetParameterFloat("model.autoencoder.initfactor",1, false);
SetParameterDescription(
"model.autoencoder.initfactor", "parameter that control the weight initialization of the autoencoder");
//Number Of Hidden Neurons
AddParameter(ParameterType_StringList , "model.autoencoder.nbneuron", "Size");
/*AddParameter(ParameterType_Int, "model.autoencoder.nbneuron",
......@@ -146,6 +152,7 @@ void cbLearningApplicationBaseDR<TInputValue,TOutputValue>
dimredTrainer->SetNumberOfHiddenNeurons(nb_neuron);
dimredTrainer->SetNumberOfIterations(GetParameterInt("model.autoencoder.nbiter"));
dimredTrainer->SetEpsilon(GetParameterFloat("model.autoencoder.epsilon"));
dimredTrainer->SetInitFactor(GetParameterFloat("model.autoencoder.initfactor"));
dimredTrainer->SetRegularization(regularization);
dimredTrainer->SetNoise(noise);
dimredTrainer->SetRho(rho);
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment