Skip to content
Snippets Groups Projects
Commit 63d69d6c authored by Ludovic Hussonnois's avatar Ludovic Hussonnois
Browse files

ENH: Add Supervised or Unsupervised classifier selection for learning.

parent bec02c1f
No related branches found
No related tags found
No related merge requests found
......@@ -160,15 +160,31 @@ protected:
std::string modelPath);
/** Init method that creates all the parameters for machine learning models */
void DoInit();
void DoInit() ITK_OVERRIDE;
/** Flag to switch between classification and regression mode.
* False by default, child classes may change it in their constructor */
bool m_RegressionFlag;
private:
/** enum use to selected classifier category */
enum ClassifierCategory {
Supervised,
Unsupervised
};
/** Enum to switch between unsupervised or supervised classification.
* Supervised by default, child classes may change it in their constructor */
ClassifierCategory m_ClassifierCategory;
private:
/** Specific Init and Train methods for each machine learning model */
/** Init Parameters for Supervised Classifier */
void InitSupervisedClassifierParams();
/** Init Parameters for Unsupervised Classifier */
void InitUnsupervisedClassifierParams();
//@{
#ifdef OTB_USE_LIBSVM
void InitLibSVMParams();
......
......@@ -28,7 +28,7 @@ namespace Wrapper
template <class TInputValue, class TOutputValue>
LearningApplicationBase<TInputValue,TOutputValue>
::LearningApplicationBase() : m_RegressionFlag(false)
::LearningApplicationBase() : m_RegressionFlag(false), m_ClassifierCategory(Supervised)
{
}
......@@ -50,8 +50,25 @@ LearningApplicationBase<TInputValue,TOutputValue>
AddParameter(ParameterType_Choice, "classifier", "Classifier to use for the training");
SetParameterDescription("classifier", "Choice of the classifier to use for the training.");
switch(m_ClassifierCategory)
{
case Unsupervised:
InitUnsupervisedClassifierParams();
break;
case Supervised:
default :
InitSupervisedClassifierParams();
}
}
template <class TInputValue, class TOutputValue>
void
LearningApplicationBase<TInputValue,TOutputValue>
::InitSupervisedClassifierParams()
{
//Group LibSVM
#ifdef OTB_USE_LIBSVM
#ifdef OTB_USE_LIBSVM
InitLibSVMParams();
#endif
......@@ -78,7 +95,14 @@ LearningApplicationBase<TInputValue,TOutputValue>
InitSharkRandomForestsParams();
InitSharkKMeansParams();
#endif
}
template <class TInputValue, class TOutputValue>
void
LearningApplicationBase<TInputValue,TOutputValue>
::InitUnsupervisedClassifierParams()
{
}
template <class TInputValue, class TOutputValue>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment