diff --git a/Modules/Learning/LearningBase/include/otbMachineLearningModel.h b/Modules/Learning/LearningBase/include/otbMachineLearningModel.h index 1de069ab6d0fad826d03c3774b9b3ff304e28d0e..7da0aaad12dedfcb7bfebcba807b0d2b6fef94d0 100644 --- a/Modules/Learning/LearningBase/include/otbMachineLearningModel.h +++ b/Modules/Learning/LearningBase/include/otbMachineLearningModel.h @@ -116,8 +116,10 @@ public: */ TargetSampleType Predict(const InputSampleType& input, ConfidenceValueType *quality = ITK_NULLPTR) const; - - virtual unsigned int GetDimension() {return 1;}; /// This method is used to determine the output vector size after dimensionality reduction, and should be overrided for all machine learning models used for dimensionality reduction. This method is not used for classification and regression + itkSetMacro(Dimension,unsigned int); + itkGetMacro(Dimension,unsigned int); + + // virtual unsigned int GetDimension() {return 1;}; /// This method is used to determine the output vector size after dimensionality reduction, and should be overrided for all machine learning models used for dimensionality reduction. This method is not used for classification and regression /** Predict a batch of samples (InputListSampleType) @@ -186,7 +188,9 @@ protected: /** Input list sample */ typename InputListSampleType::Pointer m_InputListSample; - + + typename InputListSampleType::Pointer m_ValidationListSample; + /** Target list sample */ typename TargetListSampleType::Pointer m_TargetListSample; @@ -195,6 +199,9 @@ protected: /** flag to choose between classification and regression modes */ bool m_RegressionMode; + /** Output Dimension of the model, used by Dimensionality Reduction models*/ + + /** flag that indicates if the model supports regression, child * classes should modify it in their constructor if they support * regression mode */ @@ -205,7 +212,8 @@ protected: /** Is DoPredictBatch multi-threaded ? */ bool m_IsDoPredictBatchMultiThreaded; - + unsigned int m_Dimension; + private: /** Actual implementation of BatchPredicition * Default implementation will call DoPredict iteratively