diff --git a/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.h b/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.h index 86b2d5e4e303b1997e0ebe48889761ba31f16624..50555eaf5b44689e60ded7a154af4e1423a647dd 100644 --- a/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.h +++ b/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.h @@ -23,6 +23,7 @@ #include "itkFixedArray.h" #include "itkListSample.h" #include "otbMachineLearningModel.h" +#include "itkVariableSizeMatrix.h" class CvRTrees; @@ -48,6 +49,10 @@ public: typedef TTargetValue TargetValueType; typedef itk::FixedArray<TargetValueType,1> TargetSampleType; typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType; + + // Other + typedef itk::VariableSizeMatrix<float> VariableImportanceMatrixType; + //opencv typedef typedef CvRTrees RFType; @@ -159,11 +164,9 @@ public: itkGetMacro(TerminationCriteria, int); itkSetMacro(TerminationCriteria, int); - // cv::Mat GetVariableImportance() - // { - // return m_RFModel->getVarImportance(); - // } - + /** Returns a matrix containing variable importance */ + VariableImportanceMatrixType GetVariableImportance(); + float GetTrainError(); protected: diff --git a/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.txx b/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.txx index 5654b2ba602e151c12d62aaead79a682f534090c..001096c81b3a56efc05217f9f207cda7f74572d6 100644 --- a/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.txx +++ b/Code/UtilitiesAdapters/OpenCVAdapters/otbRandomForestsMachineLearningModel.txx @@ -183,6 +183,24 @@ RandomForestsMachineLearningModel<TInputValue,TOutputValue> return false; } +template <class TInputValue, class TOutputValue> +typename RandomForestsMachineLearningModel<TInputValue,TOutputValue> +::VariableImportanceMatrixType +RandomForestsMachineLearningModel<TInputValue,TOutputValue> +::GetVariableImportance() +{ + cv::Mat cvMat = m_RFModel->getVarImportance(); + VariableImportanceMatrixType itkMat(cvMat.rows,cvMat.cols); + for(unsigned int i =0; i<cvMat.rows; i++) + { + for(unsigned int j =0; j<cvMat.cols; j++) + { + itkMat(i,j)=cvMat.at<float>(i,j); + } + } + return itkMat; +} + template <class TInputValue, class TOutputValue> void