    /*=========================================================================
     Program:   ORFEO Toolbox
     Language:  C++
     Date:      $Date$
     Version:   $Revision$
    
    
     Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
     See OTBCopyright.txt for details.
    
    
     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.
    
     =========================================================================*/
    #ifndef cbLearningApplicationBaseDR_txx
    #define cbLearningApplicationBaseDR_txx
    
    
    #include "otbTrainDimensionalityReductionApplicationBase.h"
    
    
    namespace otb
    {
    namespace Wrapper
    {
    
    /** Default constructor.
     *
     * NOTE(review): the constructor body was missing from this copy of the
     * file (the declaration ran straight into the destructor, which does not
     * compile). An empty body is restored here; confirm against version
     * control that no member initialization was lost.
     */
    template <class TInputValue, class TOutputValue>
    TrainDimensionalityReductionApplicationBase<TInputValue,TOutputValue>
    ::TrainDimensionalityReductionApplicationBase()
    {
    }

    /** Destructor: calls ModelFactoryType::CleanFactories().
     *
     * Each out-of-class member definition of a class template needs its own
     * template header, so it is repeated here.
     */
    template <class TInputValue, class TOutputValue>
    TrainDimensionalityReductionApplicationBase<TInputValue,TOutputValue>
    ::~TrainDimensionalityReductionApplicationBase()
    {
      // Presumably undoes model-factory registration performed while the
      // application ran — TODO confirm against ModelFactoryType.
      ModelFactoryType::CleanFactories();
    }
    
    // NOTE(review): this section of the file is damaged. Lines appear to have
    // been dropped (e.g. by a web/blame export), and as shown it does not compile:
    //   * the first definition below has a template header and a `void` return
    //     type, but no member-function name (presumably `::DoInit()` — TODO
    //     confirm), no opening `{` and no closing `}`;
    //   * the `::Reduce(...)` definition further down has no `template <...>`
    //     header, no return type and no body.
    // Restore the missing lines from version control rather than guessing them.
    template <class TInputValue, class TOutputValue>
    void
    
    TrainDimensionalityReductionApplicationBase<TInputValue,TOutputValue>
    
      // main choice parameter that will contain all dimensionality reduction options
    
      // Registers the "model" choice parameter; the available model options are
      // presumably added as sub-choices elsewhere — TODO confirm.
      AddParameter(ParameterType_Choice, "model", "model to use for the training");
    
      SetParameterDescription("model", "Choice of the dimensionality reduction model to use for the training.");
    
    // Reduce(validationListSample, modelPath): only the qualified name and the
    // parameter list survive in this copy; the body is missing.
    TrainDimensionalityReductionApplicationBase<TInputValue,TOutputValue>
    
    ::Reduce(typename ListSampleType::Pointer validationListSample,std::string modelPath)
    
    TrainDimensionalityReductionApplicationBase<TInputValue,TOutputValue>
    
    ::Train(typename ListSampleType::Pointer trainingListSample,
            std::string modelPath)
    {
     
    
     // get the name of the chosen machine learning model
     const std::string modelName = GetParameterString("model");
     // call specific train function
    
    	{
    		BeforeTrainSOM(trainingListSample,modelPath);
    	}
    
     if(modelName == "autoencoder")
    
        BeforeTrainAutoencoder(trainingListSample,modelPath);
    #else
        otbAppLogFATAL("Module SharkLearning is not installed. You should consider turning OTB_USE_SHARK on during cmake configuration.");
    #endif
    
    		#ifdef OTB_USE_SHARK
    		TrainAutoencoder<TiedAutoencoderModelType>(trainingListSample,modelPath);
    		#else
    		otbAppLogFATAL("Module SharkLearning is not installed. You should consider turning OTB_USE_SHARK on during cmake configuration.");
    		#endif
    
    		#ifdef OTB_USE_SHARK
    		TrainPCA(trainingListSample,modelPath);
    		#else
    		otbAppLogFATAL("Module SharkLearning is not installed. You should consider turning OTB_USE_SHARK on during cmake configuration.");
    		#endif