diff --git a/app/cbDimensionalityReduction.cxx b/app/cbDimensionalityReduction.cxx index 5a019b3989e26dc5210ed11b2776895cac03f78d..7b1063147b8ee836e5821c604c2315453f3da641 100644 --- a/app/cbDimensionalityReduction.cxx +++ b/app/cbDimensionalityReduction.cxx @@ -28,7 +28,6 @@ #include "otbImageToVectorImageCastFilter.h" #include "DimensionalityReductionModelFactory.h" - namespace otb { namespace Functor @@ -273,7 +272,6 @@ private: ModelPointerType m_Model; RescalerType::Pointer m_Rescaler; OutputRescalerType::Pointer m_OutRescaler; - }; diff --git a/include/AutoencoderModelFactory.h b/include/AutoencoderModelFactory.h index 0b8f538272f6e6ee43b485177ea3013b436e9d45..73d841e9a8cbd512d8fcd036e13bb3794b46a3f3 100644 --- a/include/AutoencoderModelFactory.h +++ b/include/AutoencoderModelFactory.h @@ -49,14 +49,14 @@ private: - +/* template <class TInputValue, class TTargetValue> -class ITK_EXPORT AutoencoderModelFactory : public AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron>> {}; +using AutoencoderModelFactory = AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron>> ; template <class TInputValue, class TTargetValue> -class ITK_EXPORT TiedAutoencoderModelFactory : public AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::TiedAutoencoder< shark::TanhNeuron, shark::LinearNeuron>> {}; - +using TiedAutoencoderModelFactory = AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::TiedAutoencoder< shark::TanhNeuron, shark::LinearNeuron>> ; +*/ } //namespace otb diff --git a/include/DimensionalityReductionModelFactory.h b/include/DimensionalityReductionModelFactory.h index eebade611d3455077e6cd3fc06cd5fe4925ab2ed..e7a913cc07d89c379a4939e69203d4358bc50fdb 100644 --- a/include/DimensionalityReductionModelFactory.h +++ b/include/DimensionalityReductionModelFactory.h @@ -50,6 +50,7 @@ public: /** Mode in which the files is intended to be used */ typedef enum { ReadMode, WriteMode } FileModeType; + /** Create the appropriate MachineLearningModel depending on the particulars of the file. */ static DimensionalityReductionModelTypePointer CreateDimensionalityReductionModel(const std::string& path, FileModeType mode); diff --git a/include/DimensionalityReductionModelFactory.txx b/include/DimensionalityReductionModelFactory.txx index 51064244036152bdb3d4992370eb50547f2f1cc8..0c05937d39c7ba448c1f43ec360d10f62fcaea80 100644 --- a/include/DimensionalityReductionModelFactory.txx +++ b/include/DimensionalityReductionModelFactory.txx @@ -32,6 +32,15 @@ namespace otb { + +template <class TInputValue, class TTargetValue> +using AutoencoderModelFactory = AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron>> ; + + +template <class TInputValue, class TTargetValue> +using TiedAutoencoderModelFactory = AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::TiedAutoencoder< shark::TanhNeuron, shark::LinearNeuron>> ; + + template <class TInputValue, class TOutputValue> typename DimensionalityReductionModel<TInputValue,TOutputValue>::Pointer DimensionalityReductionModelFactory<TInputValue,TOutputValue> @@ -88,6 +97,13 @@ DimensionalityReductionModelFactory<TInputValue,TOutputValue> #ifdef OTB_USE_SHARK + + + // using AutoencoderModelFactory = AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron>> {}; + + + //using TiedAutoencoderModelFactory = public AutoencoderModelFactoryBase<TInputValue, TTargetValue, shark::TiedAutoencoder< shark::TanhNeuron, shark::LinearNeuron>> {}; + RegisterFactory(PCAModelFactory<TInputValue,TOutputValue>::New()); RegisterFactory(AutoencoderModelFactory<TInputValue,TOutputValue>::New()); RegisterFactory(TiedAutoencoderModelFactory<TInputValue,TOutputValue>::New()); diff --git a/include/PCAModel.h b/include/PCAModel.h index 96d0e235c0cddadf90bddd390da83586799b5435..4ea1370a9fa2c23d3c27793cede22876334e9694 100644 --- a/include/PCAModel.h +++ b/include/PCAModel.h @@ -28,7 +28,7 @@ public: typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType; itkNewMacro(Self); - itkTypeMacro(AutoencoderModel, DimensionalityReductionModel); + itkTypeMacro(PCAModel, DimensionalityReductionModel); unsigned int GetDimension() {return m_Dimension;}; itkSetMacro(Dimension,unsigned int); diff --git a/include/PCAModel.txx b/include/PCAModel.txx index c7625e661e8bbae0eb54cd81281b96807d5640d1..92fdc3ee4cc47c75488fee156a58842e35a701a6 100644 --- a/include/PCAModel.txx +++ b/include/PCAModel.txx @@ -38,7 +38,6 @@ void PCAModel<TInputValue>::Train() Shark::ListSampleToSharkVector(this->GetInputListSample(), features); shark::Data<shark::RealVector> inputSamples = shark::createDataFromRange( features ); - //m_pca.train(m_encoder,inputSamples); m_pca.setData(inputSamples); m_pca.encoder(m_encoder, m_Dimension); m_pca.decoder(m_decoder, m_Dimension); diff --git a/include/SOMModel.h b/include/SOMModel.h new file mode 100644 index 0000000000000000000000000000000000000000..1a469e44369509cf81ccfe97876839222306897e --- /dev/null +++ b/include/SOMModel.h @@ -0,0 +1,132 @@ +#ifndef SOMModel_h +#define SOMModel_h + +#include "DimensionalityReductionModel.h" +#include "otbSOMMap.h" + +#include "otbSOM.h" + +#include "itkEuclideanDistanceMetric.h" // the distance function + +#include "otbCzihoSOMLearningBehaviorFunctor.h" +#include "otbCzihoSOMNeighborhoodBehaviorFunctor.h" + + + +namespace otb +{ +template <class TInputValue> +class ITK_EXPORT SOMModel: public DimensionalityReductionModel<TInputValue,TInputValue> +{ + +public: + + typedef SOMModel Self; + typedef DimensionalityReductionModel<TInputValue,TInputValue> Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + + typedef typename Superclass::InputValueType InputValueType; + typedef typename Superclass::InputSampleType InputSampleType; + typedef typename Superclass::InputListSampleType InputListSampleType; + typedef typename InputListSampleType::Pointer ListSamplePointerType; + typedef typename Superclass::TargetValueType TargetValueType; + typedef typename Superclass::TargetSampleType TargetSampleType; + typedef typename Superclass::TargetListSampleType TargetListSampleType; + typedef typename Superclass::ConfidenceValueType ConfidenceValueType; + typedef typename Superclass::ConfidenceSampleType ConfidenceSampleType; + typedef typename Superclass::ConfidenceListSampleType ConfidenceListSampleType; + + typedef SOMMap<itk::VariableLengthVector<TInputValue>,itk::Statistics::EuclideanDistanceMetric<itk::VariableLengthVector<TInputValue>>, 2> MapType; + typedef typename MapType::SizeType SizeType; + + typedef otb::SOM<InputListSampleType, MapType> EstimatorType; + + + typedef Functor::CzihoSOMLearningBehaviorFunctor SOMLearningBehaviorFunctorType; + typedef Functor::CzihoSOMNeighborhoodBehaviorFunctor SOMNeighborhoodBehaviorFunctorType; + + itkNewMacro(Self); + itkTypeMacro(SOMModel, DimensionalityReductionModel); + + /** Accessors */ + itkSetMacro(NumberOfIterations, unsigned int); + itkGetMacro(NumberOfIterations, unsigned int); + itkSetMacro(BetaInit, double); + itkGetMacro(BetaInit, double); + itkSetMacro(BetaEnd, double); + itkGetMacro(BetaEnd, double); + itkSetMacro(MinWeight, InputValueType); + itkGetMacro(MinWeight, InputValueType); + itkSetMacro(MaxWeight, InputValueType); + itkGetMacro(MaxWeight, InputValueType); + itkSetMacro(MapSize, SizeType); + itkGetMacro(MapSize, SizeType); + itkSetMacro(NeighborhoodSizeInit, SizeType); + itkGetMacro(NeighborhoodSizeInit, SizeType); + itkSetMacro(RandomInit, bool); + itkGetMacro(RandomInit, bool); + itkSetMacro(Seed, unsigned int); + itkGetMacro(Seed, unsigned int); + itkGetObjectMacro(ListSample, InputListSampleType); + itkSetObjectMacro(ListSample, InputListSampleType); + + bool CanReadFile(const std::string & filename); + bool CanWriteFile(const std::string & filename); + + void Save(const std::string & filename, const std::string & name="") ITK_OVERRIDE; + void Load(const std::string & filename, const std::string & name="") ITK_OVERRIDE; + + void Train() ITK_OVERRIDE; + //void Dimensionality_reduction() {}; // Dimensionality reduction is done by DoPredict + + unsigned int GetDimension() { return MapType::ImageDimension;}; +protected: + SOMModel(); + ~SOMModel() ITK_OVERRIDE; + + virtual TargetSampleType DoPredict(const InputSampleType& input, ConfidenceValueType *quality=ITK_NULLPTR) const ITK_OVERRIDE; + +private: + typename MapType::Pointer m_SOMMap; + + + /** Map Parameters used for training */ + + SizeType m_MapSize; + /** Number of iterations */ + unsigned int m_NumberOfIterations; + /** Initial learning coefficient */ + double m_BetaInit; + /** Final learning coefficient */ + double m_BetaEnd; + /** Initial neighborhood size */ + SizeType m_NeighborhoodSizeInit; + /** Minimum initial neuron weights */ + InputValueType m_MinWeight; + /** Maximum initial neuron weights */ + InputValueType m_MaxWeight; + /** Random initialization bool */ + bool m_RandomInit; + /** Seed for random initialization */ + unsigned int m_Seed; + /** The input list sample */ + ListSamplePointerType m_ListSample; + /** Behavior of the Learning weightening (link to the beta coefficient) */ + SOMLearningBehaviorFunctorType m_BetaFunctor; + /** Behavior of the Neighborhood extent */ + SOMNeighborhoodBehaviorFunctorType m_NeighborhoodSizeFunctor; + +}; + + +} // end namespace otb + + +#ifndef OTB_MANUAL_INSTANTIATION +#include "SOMModel.txx" +#endif + + +#endif + diff --git a/include/SOMModel.txx b/include/SOMModel.txx new file mode 100644 index 0000000000000000000000000000000000000000..78df8415c8470d3fb365030606bb1a60e45f0e8f --- /dev/null +++ b/include/SOMModel.txx @@ -0,0 +1,101 @@ + +#ifndef SOMModel_txx +#define SOMModel_txx + +#include "otbImageFileReader.h" +#include "otbImageFileWriter.h" + +#include "itkMacro.h" + +namespace otb +{ + + +template <class TInputValue> +SOMModel<TInputValue>::SOMModel() +{ +} + + +template <class TInputValue> +SOMModel<TInputValue>::~SOMModel() +{ +} + + +template <class TInputValue> +void SOMModel<TInputValue>::Train() +{ + + typename EstimatorType::Pointer estimator = EstimatorType::New(); + + estimator->SetListSample(m_ListSample); + estimator->SetMapSize(m_MapSize); + estimator->SetNeighborhoodSizeInit(m_NeighborhoodSizeInit); + estimator->SetNumberOfIterations(m_NumberOfIterations); + estimator->SetBetaInit(m_BetaInit); + estimator->SetBetaEnd(m_BetaEnd); + estimator->SetMaxWeight(m_MaxWeight); + //AddProcess(estimator,"Learning"); + std::cout << "list = " << m_ListSample << std::endl; + std::cout << "size = " << m_MapSize << std::endl; + std::cout << "neigsize = " << m_NeighborhoodSizeInit << std::endl; + std::cout << "n iter = " << m_NumberOfIterations << std::endl; + std::cout << "bi = " << m_BetaInit << std::endl; + std::cout << "be = " << m_BetaEnd << std::endl; + std::cout << "mw = " << m_MaxWeight << std::endl; + + estimator->Update(); + + m_SOMMap = estimator->GetOutput(); +} + + +template <class TInputValue> +bool SOMModel<TInputValue>::CanReadFile(const std::string & filename) +{ + return true; +} + + +template <class TInputValue> +bool SOMModel<TInputValue>::CanWriteFile(const std::string & filename) +{ + return true; +} + +template <class TInputValue> +void SOMModel<TInputValue>::Save(const std::string & filename, const std::string & name) +{ + std::cout << m_SOMMap->GetNumberOfComponentsPerPixel() << std::endl; + +//Ecriture + auto kwl = m_SOMMap->GetImageKeywordlist(); + //kwl.AddKey("MachineLearningModelType", "SOM"); + //m_SOMMap->SetImageKeywordList(kwl); + auto writer = otb::ImageFileWriter<MapType>::New(); + writer->SetInput(m_SOMMap); + writer->SetFileName(filename); + writer->Update(); + +} + +template <class TInputValue> +void SOMModel<TInputValue>::Load(const std::string & filename, const std::string & name) +{ + auto reader = otb::ImageFileReader<MapType>::New(); + reader->SetFileName(filename); + reader->Update(); + std::cout << reader->GetOutput()->GetImageKeywordlist().GetMetadataByKey("MachineLearningModelType") << '\n'; + m_SOMMap = reader->GetOutput(); +} + + +template <class TInputValue> +typename SOMModel<TInputValue>::TargetSampleType +SOMModel<TInputValue>::DoPredict(const InputSampleType & value, ConfidenceValueType *quality) const +{ +} + +} // namespace otb +#endif diff --git a/include/SOMModelFactory.h b/include/SOMModelFactory.h new file mode 100644 index 0000000000000000000000000000000000000000..56c0b88c6c3d85414d8f0d95b93f64047d2fbd06 --- /dev/null +++ b/include/SOMModelFactory.h @@ -0,0 +1,59 @@ +#ifndef PCAModelFactory_h +#define PCAModelFactory_h + + +#include "itkObjectFactoryBase.h" +#include "itkImageIOBase.h" + +namespace otb +{ + +template <class TInputValue, class TTargetValue> +class ITK_EXPORT PCAModelFactory : public itk::ObjectFactoryBase +{ +public: + /** Standard class typedefs. */ + typedef PCAModelFactory Self; + typedef itk::ObjectFactoryBase Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + + /** Class methods used to interface with the registered factories. */ + const char* GetITKSourceVersion(void) const ITK_OVERRIDE; + const char* GetDescription(void) const ITK_OVERRIDE; + + /** Method for class instantiation. */ + itkFactorylessNewMacro(Self); + + /** Run-time type information (and related methods). */ + itkTypeMacro(PCAModelFactory, itk::ObjectFactoryBase); + + /** Register one factory of this type */ + static void RegisterOneFactory(void) + { + Pointer PCAFactory = PCAModelFactory::New(); + itk::ObjectFactoryBase::RegisterFactory(PCAFactory); + } + +protected: + PCAModelFactory(); + ~PCAModelFactory() ITK_OVERRIDE; + +private: + PCAModelFactory(const Self &); //purposely not implemented + void operator =(const Self&); //purposely not implemented + +}; + + + +} //namespace otb + + +#ifndef OTB_MANUAL_INSTANTIATION +#include "PCAModelFactory.txx" +#endif + +#endif + + diff --git a/include/SOMModelFactory.txx b/include/SOMModelFactory.txx new file mode 100644 index 0000000000000000000000000000000000000000..bfaa4f6f624b12d8351defb1bc52ad9a4d37253e --- /dev/null +++ b/include/SOMModelFactory.txx @@ -0,0 +1,65 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef PCAFactory_txx +#define PCAFactory_txx + + +#include "PCAModelFactory.h" + +#include "itkCreateObjectFunction.h" +#include "PCAModel.h" +//#include <shark/Algorithms/Trainers/PCA.h> +#include "itkVersion.h" + +namespace otb +{ +template <class TInputValue, class TOutputValue> +PCAModelFactory<TInputValue,TOutputValue>::PCAModelFactory() +{ + + std::string classOverride = std::string("DimensionalityReductionModel"); + std::string subclass = std::string("PCAModel"); + + this->RegisterOverride(classOverride.c_str(), + subclass.c_str(), + "Shark PCA ML Model", + 1, + // itk::CreateObjectFunction<AutoencoderModel<TInputValue,TOutputValue> >::New()); + itk::CreateObjectFunction<PCAModel<TInputValue>>::New()); +} + +template <class TInputValue, class TOutputValue> +PCAModelFactory<TInputValue,TOutputValue>::~PCAModelFactory() +{ +} + +template <class TInputValue, class TOutputValue> +const char* PCAModelFactory<TInputValue,TOutputValue>::GetITKSourceVersion(void) const +{ + return ITK_SOURCE_VERSION; +} + +template <class TInputValue, class TOutputValue> +const char* PCAModelFactory<TInputValue,TOutputValue>::GetDescription() const +{ + return "PCA model factory"; +} + +} // end namespace otb + +#endif diff --git a/include/cbLearningApplicationBaseDR.h b/include/cbLearningApplicationBaseDR.h index 7a467e242e39ce709d97df95ce30a19f83de3cf4..9ef37ef948487c3c2f729d9a0b9e973af8de4d95 100644 --- a/include/cbLearningApplicationBaseDR.h +++ b/include/cbLearningApplicationBaseDR.h @@ -14,6 +14,8 @@ //Estimator #include "DimensionalityReductionModelFactory.h" +#include "SOMModel.h" + #ifdef OTB_USE_SHARK #include "AutoencoderModel.h" #include "PCAModel.h" @@ -75,7 +77,6 @@ public: typedef otb::VectorImage<InputValueType> SampleImageType; typedef typename SampleImageType::PixelType PixelType; - // Machine Learning models typedef otb::DimensionalityReductionModelFactory< InputValueType, OutputValueType> ModelFactoryType; typedef typename ModelFactoryType::DimensionalityReductionModelTypePointer ModelPointerType; @@ -84,6 +85,11 @@ public: typedef typename ModelType::InputSampleType SampleType; typedef typename ModelType::InputListSampleType ListSampleType; + // Dimensionality reduction models + + typedef SOMMap<itk::VariableLengthVector<TInputValue>,itk::Statistics::EuclideanDistanceMetric<itk::VariableLengthVector<TInputValue>>, 2> MapType; + typedef otb::SOM<ListSampleType, MapType> EstimatorType; + typedef otb::SOMModel<InputValueType> SOMModelType; #ifdef OTB_USE_SHARK typedef shark::Autoencoder< shark::TanhNeuron, shark::LinearNeuron> AutoencoderType; @@ -120,9 +126,11 @@ private: #ifdef OTB_USE_SHARK void InitAutoencoderParams(); void InitPCAParams(); + void InitSOMParams(); template <class autoencoderchoice> void TrainAutoencoder(typename ListSampleType::Pointer trainingListSample, std::string modelPath); void TrainPCA(typename ListSampleType::Pointer trainingListSample, std::string modelPath); + void TrainSOM(typename ListSampleType::Pointer trainingListSample, std::string modelPath); #endif //@} }; @@ -132,6 +140,7 @@ private: #ifndef OTB_MANUAL_INSTANTIATION #include "cbLearningApplicationBaseDR.txx" +#include "cbTrainSOM.txx" #ifdef OTB_USE_SHARK #include "cbTrainAutoencoder.txx" #include "cbTrainPCA.txx" diff --git a/include/cbLearningApplicationBaseDR.txx b/include/cbLearningApplicationBaseDR.txx index 562b5e395b5599d11e0885cf2d4c73f9096b908a..885484638cb6d7ac11aeb8060196f85175d056ee 100644 --- a/include/cbLearningApplicationBaseDR.txx +++ b/include/cbLearningApplicationBaseDR.txx @@ -48,7 +48,7 @@ cbLearningApplicationBaseDR<TInputValue,TOutputValue> AddParameter(ParameterType_Choice, "model", "moddel to use for the training"); SetParameterDescription("model", "Choice of the dimensionality reduction model to use for the training."); - + InitSOMParams(); #ifdef OTB_USE_SHARK InitAutoencoderParams(); InitPCAParams(); @@ -98,7 +98,11 @@ cbLearningApplicationBaseDR<TInputValue,TOutputValue> // get the name of the chosen machine learning model const std::string modelName = GetParameterString("model"); // call specific train function - + + if(modelName == "som") + { + TrainSOM(trainingListSample,modelPath); + } if(modelName == "autoencoder") { #ifdef OTB_USE_SHARK diff --git a/include/cbTrainSOM.txx b/include/cbTrainSOM.txx new file mode 100644 index 0000000000000000000000000000000000000000..9e2b620dc52349f5a2f16eec97501b44d5998580 --- /dev/null +++ b/include/cbTrainSOM.txx @@ -0,0 +1,95 @@ + +#ifndef cbTrainSOM_txx +#define cbTrainSOM_txx +#include "cbLearningApplicationBaseDR.h" + +namespace otb +{ +namespace Wrapper +{ + +template <class TInputValue, class TOutputValue> +void +cbLearningApplicationBaseDR<TInputValue,TOutputValue> +::InitSOMParams() +{ + + + AddChoice("model.som", "OTB SOM"); + SetParameterDescription("model.som", + "This group of parameters allows setting SOM parameters. " + ); + + + AddParameter(ParameterType_Int, "model.som.sx", "SizeX"); + SetParameterDescription("model.som.sx", "X size of the SOM map"); + MandatoryOff("model.som.sx"); + + AddParameter(ParameterType_Int, "model.som.sy", "SizeY"); + SetParameterDescription("model.som.sy", "Y size of the SOM map"); + MandatoryOff("model.som.sy"); + + AddParameter(ParameterType_Int, "model.som.nx", "NeighborhoodX"); + SetParameterDescription("model.som.nx", "X size of the initial neighborhood in the SOM map"); + MandatoryOff("model.som.nx"); + + AddParameter(ParameterType_Int, "model.som.ny", "NeighborhoodY"); + SetParameterDescription("model.som.ny", "Y size of the initial neighborhood in the SOM map"); + MandatoryOff("model.som.nx"); + + AddParameter(ParameterType_Int, "model.som.ni", "NumberIteration"); + SetParameterDescription("model.som.ni", "Number of iterations for SOM learning"); + MandatoryOff("model.som.ni"); + + AddParameter(ParameterType_Float, "model.som.bi", "BetaInit"); + SetParameterDescription("model.som.bi", "Initial learning coefficient"); + MandatoryOff("model.som.bi"); + + AddParameter(ParameterType_Float, "model.som.bf", "BetaFinal"); + SetParameterDescription("model.som.bf", "Final learning coefficient"); + MandatoryOff("model.som.bf"); + + AddParameter(ParameterType_Float, "model.som.iv", "InitialValue"); + SetParameterDescription("model.som.iv", "Maximum initial neuron weight"); + MandatoryOff("model.som.iv"); + + SetDefaultParameterInt("model.som.sx", 32); + SetDefaultParameterInt("model.som.sy", 32); + SetDefaultParameterInt("model.som.nx", 10); + SetDefaultParameterInt("model.som.ny", 10); + SetDefaultParameterInt("model.som.ni", 5); + SetDefaultParameterFloat("model.som.bi", 1.0); + SetDefaultParameterFloat("model.som.bf", 0.1); + SetDefaultParameterFloat("model.som.iv", 10.0); + + +} + +template <class TInputValue, class TOutputValue> +void cbLearningApplicationBaseDR<TInputValue,TOutputValue> +::TrainSOM(typename ListSampleType::Pointer trainingListSample,std::string modelPath) +{ + + typename SOMModelType::Pointer dimredTrainer = SOMModelType::New(); + dimredTrainer->SetNumberOfIterations(GetParameterInt("model.som.ni")); + dimredTrainer->SetBetaInit(GetParameterFloat("model.som.bi")); + dimredTrainer->SetBetaEnd(GetParameterFloat("model.som.bf")); + dimredTrainer->SetMaxWeight(GetParameterFloat("model.som.iv")); + typename EstimatorType::SizeType size; + size[0]=GetParameterInt("model.som.sx"); + size[1]=GetParameterInt("model.som.sy"); + dimredTrainer->SetMapSize(size); + typename EstimatorType::SizeType radius; + radius[0] = GetParameterInt("model.som.nx"); + radius[1] = GetParameterInt("model.som.ny"); + dimredTrainer->SetNeighborhoodSizeInit(radius); + std::cout << trainingListSample << std::endl; + dimredTrainer->SetListSample(trainingListSample); + dimredTrainer->Train(); + dimredTrainer->Save(modelPath); +} + +} //end namespace wrapper +} //end namespace otb + +#endif