Skip to content
Snippets Groups Projects
PCAModel.txx 4.1 KiB
Newer Older
  • Learn to ignore specific revisions
  • 
    #include <cstddef>
    #include <fstream>
    #include <string>
    #include <vector>

    #include <shark/Data/Dataset.h>
    #include <shark/Algorithms/Trainers/PCA.h> // PCA trainer used by Train()
    //include train function
    #include <shark/ObjectiveFunctions/ErrorFunction.h>
    #include <shark/Algorithms/GradientDescent/Rprop.h>// the RProp optimization algorithm
    #include <shark/ObjectiveFunctions/Loss/SquaredLoss.h> // squared loss used for regression
    #include <shark/ObjectiveFunctions/Regularizer.h> //L2 regularization

    #include "itkMacro.h"
    #include "otbSharkUtils.h"
    
    namespace otb
    {
    
    
    /** Construct an untrained PCA model. */
    template <class TInputValue>
    PCAModel<TInputValue>::PCAModel()
    {
      // Enable the base-class multi-threaded batch prediction path:
      // DoPredictBatch is const and each call writes only the target
      // indices of its own [startIndex, startIndex + size) range.
      this->m_IsDoPredictBatchMultiThreaded = true;
    }
    
    
    template <class TInputValue>
    
    {
    }
    
    
    /**
     * Fit a PCA on the input list sample and store the resulting linear
     * projection in m_encoder.
     *
     * If m_Dimension is 0 on entry, it is set to the encoder's output size
     * afterwards (same convention as Load()).
     */
    template <class TInputValue>
    void PCAModel<TInputValue>::Train()
    {
      // Convert the OTB list sample into Shark's dataset type.
      std::vector<shark::RealVector> features;
      Shark::ListSampleToSharkVector(this->GetInputListSample(), features);
      shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange( features );

      // Fit the PCA and export the first m_Dimension principal axes as a
      // linear encoder. Per Shark's PCA::encoder, a size of 0 keeps every
      // axis -- verify against the Shark version in use.
      shark::PCA pca(inputSamples);
      pca.encoder(m_encoder, this->m_Dimension);

      // An unset dimension means "keep every component" (mirrors Load()).
      if (this->m_Dimension == 0)
      {
        this->m_Dimension = m_encoder.outputSize();
      }
    }
    
    
    /**
     * Report whether filename can be read as a PCA model.
     *
     * The check is a full Load() attempt. Any exception (unreadable file,
     * wrong header tag, deserialization error) yields false. On success, the
     * model state has been replaced by the file's content.
     */
    template <class TInputValue>
    bool PCAModel<TInputValue>::CanReadFile(const std::string & filename)
    {
      try
      {
        this->Load(filename);
      }
      catch(...)
      {
        // Deliberate catch-all: this is a capability probe, not an error path.
        return false;
      }
      return true;
    }
    
    
    /**
     * Report whether the model can be written to filename.
     *
     * No check is made on the path or extension; every filename is accepted.
     */
    template <class TInputValue>
    bool PCAModel<TInputValue>::CanWriteFile(const std::string & filename)
    {
      static_cast<void>(filename); // intentionally unused
      return true;
    }
    
    /**
     * Write the model to filename.
     *
     * Format: a first text line "pca" (the tag Load() checks), followed by
     * m_encoder serialized through a Boost polymorphic text archive.
     *
     * @param filename destination path
     * @param name     unused
     * Throws an itk::ExceptionObject (via itkExceptionMacro) if the file
     * cannot be opened or the write fails.
     */
    template <class TInputValue>
    void PCAModel<TInputValue>::Save(const std::string & filename, const std::string & name)
    {
      std::ofstream ofs(filename);
      if (!ofs)
      {
        itkExceptionMacro(<< "Error opening " << filename.c_str() );
      }

      ofs << "pca" << std::endl; // first line: model-type tag

      // Boost.Serialization requires the archive to be destroyed before its
      // stream is closed, so the archive lives in its own scope.
      {
        boost::archive::polymorphic_text_oarchive oa(ofs);
        m_encoder.write(oa);
      }

      ofs.close();
      if (ofs.fail())
      {
        itkExceptionMacro(<< "Error writing " << filename.c_str() );
      }
    }
    
    /**
     * Read a model written by Save() and restrict it to m_Dimension components.
     *
     * If m_Dimension is 0, it is set to the number of components stored in the
     * file. Otherwise, only the first m_Dimension rows of the projection matrix
     * (and the matching offset entries) are kept.
     *
     * @param filename source path
     * @param name     unused
     * Throws an itk::ExceptionObject (via itkExceptionMacro) if the header tag
     * is not "pca" or if m_Dimension exceeds the stored component count.
     */
    template <class TInputValue>
    void PCAModel<TInputValue>::Load(const std::string & filename, const std::string & name)
    {
      std::ifstream ifs(filename);

      // The first line is the tag written by Save(). An unreadable file
      // yields an empty string here and is rejected by the same check.
      std::string encoderstr;
      std::getline(ifs, encoderstr);
      if (encoderstr != "pca")
      {
        itkExceptionMacro(<< "Error opening " << filename.c_str() );
      }

      // Deserialize the encoder from the remainder of the stream.
      {
        boost::archive::polymorphic_text_iarchive ia(ifs);
        m_encoder.read(ia);
      }
      ifs.close();

      const std::size_t storedDimension = m_encoder.outputSize();
      if (this->m_Dimension == 0)
      {
        this->m_Dimension = storedDimension;
      }
      else if (this->m_Dimension > storedDimension)
      {
        itkExceptionMacro(<< "Requested dimension " << this->m_Dimension
                          << " exceeds the " << storedDimension
                          << " components stored in " << filename.c_str() );
      }

      // Keep the first m_Dimension rows of the projection matrix. The offset is
      // cut to the same length, because each offset entry belongs to one matrix
      // row and the two must stay the same size.
      const shark::RealMatrix & fullMatrix = m_encoder.matrix();
      const shark::RealVector & fullOffset = m_encoder.offset();
      const std::size_t inputSize = m_encoder.inputSize();

      shark::RealMatrix truncatedMatrix(this->m_Dimension, inputSize);
      for (std::size_t row = 0; row < this->m_Dimension; ++row)
      {
        for (std::size_t col = 0; col < inputSize; ++col)
        {
          truncatedMatrix(row, col) = fullMatrix(row, col);
        }
      }

      // An empty offset means "no offset"; keep it empty in that case.
      const std::size_t offsetSize = (fullOffset.size() == 0) ? std::size_t(0) : std::size_t(this->m_Dimension);
      shark::RealVector truncatedOffset(offsetSize);
      for (std::size_t row = 0; row < offsetSize; ++row)
      {
        truncatedOffset(row) = fullOffset(row);
      }

      m_encoder.setStructure(truncatedMatrix, truncatedOffset);
    }
    
    
    /**
     * Project a single sample onto the first m_Dimension principal axes.
     *
     * @param value   input sample; its components are copied in order
     * @param quality unused
     * @return a vector of m_Dimension projected components
     */
    template <class TInputValue>
    typename PCAModel<TInputValue>::TargetSampleType
    PCAModel<TInputValue>::DoPredict(const InputSampleType & value, ConfidenceValueType * quality) const
    {
      // Copy the OTB sample into a Shark vector.
      shark::RealVector samples(value.Size());
      for (size_t i = 0; i < value.Size(); i++)
      {
        samples[i] = value[i];
      }

      // The encoder operates on shark::Data, so wrap the single sample.
      std::vector<shark::RealVector> features;
      features.push_back(samples);
      shark::Data<shark::RealVector> data = shark::createDataFromRange(features);

      data = m_encoder(data);

      TargetSampleType target;
      target.SetSize(this->m_Dimension);
      for (unsigned int a = 0; a < this->m_Dimension; ++a)
      {
        target[a] = data.element(0)[a];
      }
      return target;
    }
    
    
    /**
     * Project the samples [startIndex, startIndex + size) of input and store
     * the results at the same indices in targets.
     *
     * @param input      source list sample
     * @param startIndex first sample index to process
     * @param size       number of samples to process
     * @param targets    output list sample; entries startIndex.. are overwritten
     * @param quality    unused
     */
    template <class TInputValue>
    void PCAModel<TInputValue>
    ::DoPredictBatch(const InputListSampleType *input, const unsigned int & startIndex, const unsigned int & size, TargetListSampleType * targets, ConfidenceListSampleType * quality) const
    {
      // Convert the whole range at once and project it in a single encoder call.
      std::vector<shark::RealVector> features;
      Shark::ListSampleRangeToSharkVector(input, features, startIndex, size);
      shark::Data<shark::RealVector> data = shark::createDataFromRange(features);
      data = m_encoder(data);

      // The output vector must be sized before its components are indexed.
      TargetSampleType target;
      target.SetSize(this->m_Dimension);

      unsigned int id = startIndex;
      for (const auto & p : data.elements())
      {
        for (unsigned int a = 0; a < this->m_Dimension; ++a)
        {
          target[a] = p[a];
        }
        targets->SetMeasurementVector(id, target);
        ++id;
      }
    }
    
    
    } // namespace otb
    #endif