Skip to content
Snippets Groups Projects
Commit 1b8af99c authored by Cédric Traizet's avatar Cédric Traizet
Browse files

Autoencoder now outputs the encoder matrix and the hidden bias if asked

parent 9cb87753
No related branches found
No related tags found
No related merge requests found
...@@ -61,6 +61,9 @@ public: ...@@ -61,6 +61,9 @@ public:
itkGetMacro(WriteLearningCurve,bool); itkGetMacro(WriteLearningCurve,bool);
itkSetMacro(WriteLearningCurve,bool); itkSetMacro(WriteLearningCurve,bool);
itkSetMacro(WriteWeights, bool);
itkGetMacro(WriteWeights, bool);
itkGetMacro(LearningCurveFileName,std::string); itkGetMacro(LearningCurveFileName,std::string);
itkSetMacro(LearningCurveFileName,std::string); itkSetMacro(LearningCurveFileName,std::string);
...@@ -105,7 +108,7 @@ private: ...@@ -105,7 +108,7 @@ private:
bool m_WriteLearningCurve; // Flag for writting the learning curve into a txt file bool m_WriteLearningCurve; // Flag for writting the learning curve into a txt file
std::string m_LearningCurveFileName; // Name of the output learning curve printed after training std::string m_LearningCurveFileName; // Name of the output learning curve printed after training
bool m_WriteWeights;
}; };
} // end namespace otb } // end namespace otb
......
...@@ -208,6 +208,22 @@ void AutoencoderModel<TInputValue,AutoencoderType>::Save(const std::string & fil ...@@ -208,6 +208,22 @@ void AutoencoderModel<TInputValue,AutoencoderType>::Save(const std::string & fil
//m_net.write(oa); //m_net.write(oa);
oa << m_net; oa << m_net;
ofs.close(); ofs.close();
if (this->m_WriteWeights == true) // output the map vectors in a txt file
{
std::ofstream otxt(filename+".txt");
for (unsigned int i = 0 ; i < m_NumberOfHiddenNeurons.Size(); ++i)
{
otxt << m_net[i].encoderMatrix() << std::endl;
otxt << m_net[i].hiddenBias() << std::endl;
}
otxt.close();
}
} }
template <class TInputValue, class AutoencoderType> template <class TInputValue, class AutoencoderType>
......
...@@ -42,6 +42,9 @@ public: ...@@ -42,6 +42,9 @@ public:
*/ */
itkSetMacro(Do_resize_flag,bool); itkSetMacro(Do_resize_flag,bool);
itkSetMacro(WriteEigenvectors, bool);
itkGetMacro(WriteEigenvectors, bool);
bool CanReadFile(const std::string & filename); bool CanReadFile(const std::string & filename);
bool CanWriteFile(const std::string & filename); bool CanWriteFile(const std::string & filename);
...@@ -66,6 +69,7 @@ private: ...@@ -66,6 +69,7 @@ private:
shark::PCA m_pca; shark::PCA m_pca;
//unsigned int m_Dimension; //unsigned int m_Dimension;
bool m_Do_resize_flag; bool m_Do_resize_flag;
bool m_WriteEigenvectors;
}; };
} // end namespace otb } // end namespace otb
......
...@@ -152,6 +152,7 @@ void cbLearningApplicationBaseDR<TInputValue,TOutputValue> ...@@ -152,6 +152,7 @@ void cbLearningApplicationBaseDR<TInputValue,TOutputValue>
dimredTrainer->SetRho(rho); dimredTrainer->SetRho(rho);
dimredTrainer->SetBeta(beta); dimredTrainer->SetBeta(beta);
dimredTrainer->SetWriteWeights(true);
if (HasValue("model.autoencoder.learningcurve") && IsParameterEnabled("model.autoencoder.learningcurve")) if (HasValue("model.autoencoder.learningcurve") && IsParameterEnabled("model.autoencoder.learningcurve"))
{ {
std::cout << "yo" << std::endl; std::cout << "yo" << std::endl;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment