diff --git a/Code/UtilitiesAdapters/OpenCV/otbSVMMachineLearningModel.h b/Code/UtilitiesAdapters/OpenCV/otbSVMMachineLearningModel.h index f10a448154950590c6aa7d8d0d4f6a85e65f0083..8108ce0f82c67ae3b340ada954d0e7b7b604220a 100644 --- a/Code/UtilitiesAdapters/OpenCV/otbSVMMachineLearningModel.h +++ b/Code/UtilitiesAdapters/OpenCV/otbSVMMachineLearningModel.h @@ -115,6 +115,9 @@ public: itkGetMacro(P, double); itkSetMacro(P, double); + itkGetMacro(ParameterOptimization, bool); + itkSetMacro(ParameterOptimization, bool); + protected: /** Constructor */ SVMMachineLearningModel(); @@ -141,6 +144,7 @@ private: double m_C; double m_Nu; double m_P; + bool m_ParameterOptimization; }; } // end namespace otb diff --git a/Code/UtilitiesAdapters/OpenCV/otbSVMMachineLearningModel.txx b/Code/UtilitiesAdapters/OpenCV/otbSVMMachineLearningModel.txx index b4563da2fbb796691c4caba2941ba92b4683345e..53712378ff356968b6a2b4e6e3a8a99e578000f5 100644 --- a/Code/UtilitiesAdapters/OpenCV/otbSVMMachineLearningModel.txx +++ b/Code/UtilitiesAdapters/OpenCV/otbSVMMachineLearningModel.txx @@ -63,7 +63,18 @@ SVMMachineLearningModel<TInputValue,TOutputValue> CvSVMParams params( m_SVMType, m_KernelType, m_Degree, m_Gamma, m_Coef0, m_C, m_Nu, m_P, NULL , term_crit ); // Train the SVM - m_SVMModel->train(samples, labels, cv::Mat(), cv::Mat(), params); + if (!m_ParameterOptimization) + m_SVMModel->train(samples, labels, cv::Mat(), cv::Mat(), params); + else + //Trains SVM with optimal parameters. + //train_auto(const Mat& trainData, const Mat& responses, const Mat& varIdx, const Mat& sampleIdx, + //CvSVMParams params, int k_fold=10, CvParamGrid Cgrid=CvSVM::get_default_grid(CvSVM::C), + //CvParamGrid gammaGrid=CvSVM::get_default_grid(CvSVM::GAMMA), + //CvParamGrid pGrid=CvSVM::get_default_grid(CvSVM::P), CvParamGrid nuGrid=CvSVM::get_default_grid(CvSVM::NU), + //CvParamGrid coeffGrid=CvSVM::get_default_grid(CvSVM::COEF), CvParamGrid degreeGrid=CvSVM::get_default_grid(CvSVM::DEGREE), + //bool balanced=false) + //We used default parameters grid. If not enough, those grids should be expose to the user. + m_SVMModel->train_auto(samples, labels, cv::Mat(), cv::Mat(), params); } template <class TInputValue, class TOutputValue>