From 7b78f3d759965867c18ba713b9698d463c0ae51c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?C=C3=A9dric=20Traizet?= <traizetc@cesbio.cnes.fr>
Date: Thu, 27 Jul 2017 10:45:03 +0200
Subject: [PATCH] the weight initial value during training can now be set with
 a parameter

---
 app/cbDimensionalityReduction.cxx        | 2 +-
 app/cbDimensionalityReductionTrainer.cxx | 2 +-
 app/cbDimensionalityReductionVector.cxx  | 2 +-
 include/AutoencoderModel.h               | 4 ++++
 include/AutoencoderModel.txx             | 4 ++--
 include/cbTrainAutoencoder.txx           | 7 +++++++
 6 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/app/cbDimensionalityReduction.cxx b/app/cbDimensionalityReduction.cxx
index 106f609d48..bc9a5754fe 100644
--- a/app/cbDimensionalityReduction.cxx
+++ b/app/cbDimensionalityReduction.cxx
@@ -222,7 +222,7 @@ private:
         }
         
       // Rescale vector image
-      m_Rescaler->SetScale(stddevMeasurementVector);
+      m_Rescaler->SetScale(stddevMeasurementVector*3);
       m_Rescaler->SetShift(meanMeasurementVector);
       m_Rescaler->SetInput(inImage);
 
diff --git a/app/cbDimensionalityReductionTrainer.cxx b/app/cbDimensionalityReductionTrainer.cxx
index 4cb042427b..55232ef871 100644
--- a/app/cbDimensionalityReductionTrainer.cxx
+++ b/app/cbDimensionalityReductionTrainer.cxx
@@ -124,7 +124,7 @@ private:
 		ShiftScaleFilterType::Pointer trainingShiftScaleFilter = ShiftScaleFilterType::New();
 		trainingShiftScaleFilter->SetInput(input);
 		trainingShiftScaleFilter->SetShifts(meanMeasurementVector);
-		trainingShiftScaleFilter->SetScales(stddevMeasurementVector);
+		trainingShiftScaleFilter->SetScales(stddevMeasurementVector*3);
 		trainingShiftScaleFilter->Update();
 
 		ListSampleType::Pointer trainingListSample= trainingShiftScaleFilter->GetOutput();
diff --git a/app/cbDimensionalityReductionVector.cxx b/app/cbDimensionalityReductionVector.cxx
index 12e1307ad1..cf2caed548 100644
--- a/app/cbDimensionalityReductionVector.cxx
+++ b/app/cbDimensionalityReductionVector.cxx
@@ -223,7 +223,7 @@ class CbDimensionalityReductionVector : public Application
 			ShiftScaleFilterType::Pointer trainingShiftScaleFilter = ShiftScaleFilterType::New();
 			trainingShiftScaleFilter->SetInput(input);
 			trainingShiftScaleFilter->SetShifts(meanMeasurementVector);
-			trainingShiftScaleFilter->SetScales(stddevMeasurementVector);
+			trainingShiftScaleFilter->SetScales(stddevMeasurementVector*3);
 			trainingShiftScaleFilter->Update();
 			otbAppLogINFO("mean used: " << meanMeasurementVector);
 			otbAppLogINFO("standard deviation used: " << stddevMeasurementVector);
diff --git a/include/AutoencoderModel.h b/include/AutoencoderModel.h
index 5dbca52e27..9e63558a2c 100644
--- a/include/AutoencoderModel.h
+++ b/include/AutoencoderModel.h
@@ -54,6 +54,9 @@ public:
 	itkGetMacro(Epsilon,double);
 	itkSetMacro(Epsilon,double);
 
+	itkGetMacro(InitFactor,double);
+	itkSetMacro(InitFactor,double);
+
 	itkGetMacro(Regularization,itk::Array<double>);
 	itkSetMacro(Regularization,itk::Array<double>);
 
@@ -113,6 +116,7 @@ private:
 	itk::Array<double> m_Noise;  // probability for an input to be set to 0 (denosing autoencoder)
 	itk::Array<double> m_Rho; // Sparsity parameter
 	itk::Array<double> m_Beta; // Sparsity regularization parameter
+	double m_InitFactor; // Weight initialization factor (the weights are intialized at m_initfactor/sqrt(inputDimension)  )
 	
 	bool m_WriteLearningCurve; // Flag for writting the learning curve into a txt file
 	std::string m_LearningCurveFileName; // Name of the output learning curve printed after training
diff --git a/include/AutoencoderModel.txx b/include/AutoencoderModel.txx
index 668bb7d6ad..f45d055224 100644
--- a/include/AutoencoderModel.txx
+++ b/include/AutoencoderModel.txx
@@ -151,7 +151,7 @@ void AutoencoderModel<TInputValue,NeuronType>::TrainOneLayer(shark::AbstractStop
 
 	std::size_t inputs = dataDimension(samples);
 	net.setStructure(inputs, nbneuron);
-	initRandomUniform(net,-0.1*std::sqrt(1.0/inputs),0.1*std::sqrt(1.0/inputs));
+	initRandomUniform(net,-m_InitFactor*std::sqrt(1.0/inputs),m_InitFactor*std::sqrt(1.0/inputs));
 	//initRandomUniform(net,-1,1);
 	shark::ImpulseNoiseModel noise(noise_strength,0.0); //set an input pixel with probability m_Noise to 0
 	shark::ConcatenatedModel<shark::RealVector,shark::RealVector> model = noise>> net;
@@ -200,7 +200,7 @@ void AutoencoderModel<TInputValue,NeuronType>::TrainOneSparseLayer(shark::Abstra
 
 	std::size_t inputs = dataDimension(samples);
 	net.setStructure(inputs, nbneuron);
-	initRandomUniform(net,-0.1*std::sqrt(1.0/inputs),0.1*std::sqrt(1.0/inputs));
+	initRandomUniform(net,-m_InitFactor*std::sqrt(1.0/inputs),m_InitFactor*std::sqrt(1.0/inputs));
 	//initRandomUniform(net,-1,1);
 	shark::LabeledData<shark::RealVector,shark::RealVector> trainSet(samples,samples);//labels identical to inputs
 	shark::SquaredLoss<shark::RealVector> loss;
diff --git a/include/cbTrainAutoencoder.txx b/include/cbTrainAutoencoder.txx
index 57d6b86144..3272376657 100644
--- a/include/cbTrainAutoencoder.txx
+++ b/include/cbTrainAutoencoder.txx
@@ -52,6 +52,12 @@ cbLearningApplicationBaseDR<TInputValue,TOutputValue>
     " ");
   
   
+  AddParameter(ParameterType_Float, "model.autoencoder.initfactor",
+               " ");
+  SetParameterFloat("model.autoencoder.initfactor",1, false);
+  SetParameterDescription(
+    "model.autoencoder.initfactor", "parameter that control the weight initialization of the autoencoder");
+  
    //Number Of Hidden Neurons
   AddParameter(ParameterType_StringList ,  "model.autoencoder.nbneuron",   "Size");
   /*AddParameter(ParameterType_Int, "model.autoencoder.nbneuron",
@@ -146,6 +152,7 @@ void cbLearningApplicationBaseDR<TInputValue,TOutputValue>
 		dimredTrainer->SetNumberOfHiddenNeurons(nb_neuron);
 		dimredTrainer->SetNumberOfIterations(GetParameterInt("model.autoencoder.nbiter"));
 		dimredTrainer->SetEpsilon(GetParameterFloat("model.autoencoder.epsilon"));
+		dimredTrainer->SetInitFactor(GetParameterFloat("model.autoencoder.initfactor"));
 		dimredTrainer->SetRegularization(regularization);
 		dimredTrainer->SetNoise(noise);
 		dimredTrainer->SetRho(rho);
-- 
GitLab