diff --git a/Modules/Applications/AppClassification/include/otbLearningApplicationBase.h b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.h index 145aecb7382e8c3a740885e73376156830034502..8c451322064e3b70060cf72543be33c358fd6183 100644 --- a/Modules/Applications/AppClassification/include/otbLearningApplicationBase.h +++ b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.h @@ -160,15 +160,31 @@ protected: std::string modelPath); /** Init method that creates all the parameters for machine learning models */ - void DoInit(); + void DoInit() ITK_OVERRIDE; /** Flag to switch between classification and regression mode. * False by default, child classes may change it in their constructor */ bool m_RegressionFlag; -private: + /** enum use to selected classifier category */ + enum ClassifierCategory { + Supervised, + Unsupervised + }; + + /** Enum to switch between unsupervised or supervised classification. + * Supervised by default, child classes may change it in their constructor */ + ClassifierCategory m_ClassifierCategory; +private: /** Specific Init and Train methods for each machine learning model */ + + /** Init Parameters for Supervised Classifier */ + void InitSupervisedClassifierParams(); + + /** Init Parameters for Unsupervised Classifier */ + void InitUnsupervisedClassifierParams(); + //@{ #ifdef OTB_USE_LIBSVM void InitLibSVMParams(); diff --git a/Modules/Applications/AppClassification/include/otbLearningApplicationBase.txx b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.txx index f3731e9495abbf638a28d6d7b516935ffc043eda..59ffcb2b4dfd9914b9df51bd3e9ba9a5490911bd 100644 --- a/Modules/Applications/AppClassification/include/otbLearningApplicationBase.txx +++ b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.txx @@ -28,7 +28,7 @@ namespace Wrapper template <class TInputValue, class TOutputValue> LearningApplicationBase<TInputValue,TOutputValue> -::LearningApplicationBase() : m_RegressionFlag(false) +::LearningApplicationBase() : m_RegressionFlag(false), m_ClassifierCategory(Supervised) { } @@ -50,8 +50,25 @@ LearningApplicationBase<TInputValue,TOutputValue> AddParameter(ParameterType_Choice, "classifier", "Classifier to use for the training"); SetParameterDescription("classifier", "Choice of the classifier to use for the training."); + switch(m_ClassifierCategory) + { + case Unsupervised: + InitUnsupervisedClassifierParams(); + break; + case Supervised: + default : + InitSupervisedClassifierParams(); + } +} + +template <class TInputValue, class TOutputValue> +void +LearningApplicationBase<TInputValue,TOutputValue> +::InitSupervisedClassifierParams() +{ + //Group LibSVM -#ifdef OTB_USE_LIBSVM +#ifdef OTB_USE_LIBSVM InitLibSVMParams(); #endif @@ -78,7 +95,14 @@ LearningApplicationBase<TInputValue,TOutputValue> InitSharkRandomForestsParams(); InitSharkKMeansParams(); #endif - +} + +template <class TInputValue, class TOutputValue> +void +LearningApplicationBase<TInputValue,TOutputValue> +::InitUnsupervisedClassifierParams() +{ + } template <class TInputValue, class TOutputValue>