From 01e15a37ba1311b187ec70239a8890357f614aab Mon Sep 17 00:00:00 2001 From: Guillaume Pasero <guillaume.pasero@c-s.fr> Date: Wed, 12 Apr 2017 19:14:01 +0200 Subject: [PATCH] REFAC: fix SVM classes issues --- .../include/otbObjectDetectionClassifier.h | 3 +++ .../include/otbObjectDetectionClassifier.txx | 15 +++++++-------- .../Supervised/include/otbLabelMapClassifier.h | 2 +- .../include/otbLibSVMMachineLearningModel.h | 4 ++-- .../Supervised/include/otbSVMMarginSampler.txx | 4 ++-- 5 files changed, 15 insertions(+), 13 deletions(-) diff --git a/Modules/Detection/ObjectDetection/include/otbObjectDetectionClassifier.h b/Modules/Detection/ObjectDetection/include/otbObjectDetectionClassifier.h index 71f4f7263d..c051b33997 100644 --- a/Modules/Detection/ObjectDetection/include/otbObjectDetectionClassifier.h +++ b/Modules/Detection/ObjectDetection/include/otbObjectDetectionClassifier.h @@ -237,6 +237,9 @@ private: /** Step of the detection grid */ unsigned int m_GridStep; + /** classification model */ + ModelPointerType m_Model; + }; diff --git a/Modules/Detection/ObjectDetection/include/otbObjectDetectionClassifier.txx b/Modules/Detection/ObjectDetection/include/otbObjectDetectionClassifier.txx index fa9c4b4427..884cf3f076 100644 --- a/Modules/Detection/ObjectDetection/include/otbObjectDetectionClassifier.txx +++ b/Modules/Detection/ObjectDetection/include/otbObjectDetectionClassifier.txx @@ -36,8 +36,7 @@ PersistentObjectDetectionClassifier<TInputImage, TOutputVectorData, TLabel, TFun m_NoClassLabel(0), m_GridStep(10) { - // Need 2 inputs : a vector image, and a SVMModel - this->SetNumberOfRequiredInputs(2); + this->SetNumberOfRequiredInputs(1); // Have 2 outputs : the image created by Superclass, a vector data with points this->SetNumberOfRequiredOutputs(3); @@ -86,7 +85,11 @@ void PersistentObjectDetectionClassifier<TInputImage, TOutputVectorData, TLabel, TFunctionType> ::SetModel(ModelType* model) { - this->SetNthInput(1, model); + if (model != m_Model) + { + m_Model = model; + this->Modified(); + } } template <class TInputImage, class TOutputVectorData, class TLabel, class TFunctionType> @@ -94,11 +97,7 @@ const typename PersistentObjectDetectionClassifier<TInputImage, TOutputVectorDat PersistentObjectDetectionClassifier<TInputImage, TOutputVectorData, TLabel, TFunctionType> ::GetModel(void) const { - if(this->GetNumberOfInputs()<2) - { - return ITK_NULLPTR; - } - return static_cast<const ModelType*>(this->itk::ProcessObject::GetInput(1)); + return m_Model; } template <class TInputImage, class TOutputVectorData, class TLabel, class TFunctionType> diff --git a/Modules/Learning/Supervised/include/otbLabelMapClassifier.h b/Modules/Learning/Supervised/include/otbLabelMapClassifier.h index 9914e2a15e..75dc291f3d 100644 --- a/Modules/Learning/Supervised/include/otbLabelMapClassifier.h +++ b/Modules/Learning/Supervised/include/otbLabelMapClassifier.h @@ -35,7 +35,7 @@ namespace otb { * \sa otb::SVMModel * \sa itk::InPlaceLabelMapFilter * - * \ingroup OTBSVMLearning + * \ingroup OTBSupervised */ template<class TInputLabelMap> class ITK_EXPORT LabelMapClassifier : diff --git a/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.h b/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.h index 1e5cbb964e..fc51f72bf9 100644 --- a/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.h +++ b/Modules/Learning/Supervised/include/otbLibSVMMachineLearningModel.h @@ -238,9 +238,9 @@ public: void SetConfidenceMode(unsigned int mode) { - if (mode != m_ConfidenceMode) + if (m_ConfidenceMode != static_cast<ConfidenceMode>(mode) ) { - m_ConfidenceMode = mode; + m_ConfidenceMode = static_cast<ConfidenceMode>(mode); this->m_ConfidenceIndex = this->HasProbabilities(); this->Modified(); } diff --git a/Modules/Learning/Supervised/include/otbSVMMarginSampler.txx b/Modules/Learning/Supervised/include/otbSVMMarginSampler.txx index 04aa7ddc0e..b5ca257fad 100644 --- a/Modules/Learning/Supervised/include/otbSVMMarginSampler.txx +++ b/Modules/Learning/Supervised/include/otbSVMMarginSampler.txx @@ -89,13 +89,13 @@ SVMMarginSampler< TSample, TModel > while (iter != end && iterO != endO) { int i = 0; - typename SVMModelType::InputSampleType modelMeasurement; + typename SVMModelType::InputSampleType modelMeasurement(numberOfComponentsPerSample); measurements = iter.GetMeasurementVector(); // otbMsgDevMacro( << "Loop on components " << svm_type ); for (i=0; i<numberOfComponentsPerSample; ++i) { - modelMeasurement.PushBack(measurements[i]); + modelMeasurement[i] = measurements[i]; } // Get distances to the hyperplanes -- GitLab