Skip to content
Snippets Groups Projects
Commit ef2ae905 authored by Guillaume Pasero's avatar Guillaume Pasero
Browse files

REFAC: generic LabelMapClassifier works with any ML model

parent 4997b3a0
Branches
Tags
No related merge requests found
......@@ -18,17 +18,17 @@
* limitations under the License.
*/
#ifndef otbLabelMapSVMClassifier_h
#define otbLabelMapSVMClassifier_h
#ifndef otbLabelMapClassifier_h
#define otbLabelMapClassifier_h
#include "itkInPlaceLabelMapFilter.h"
#include "otbSVMModel.h"
#include "otbMachineLearningModel.h"
#include "itkListSample.h"
#include "otbAttributesMapLabelObject.h"
namespace otb {
/** \class LabelMapSVMClassifier
/** \class LabelMapClassifier
* \brief Classify each LabelObject of the input LabelMap in place
*
* \sa otb::AttributesMapLabelObject
......@@ -38,12 +38,12 @@ namespace otb {
* \ingroup OTBSVMLearning
*/
template<class TInputLabelMap>
class ITK_EXPORT LabelMapSVMClassifier :
class ITK_EXPORT LabelMapClassifier :
public itk::InPlaceLabelMapFilter<TInputLabelMap>
{
public:
/** Standard class typedefs. */
typedef LabelMapSVMClassifier Self;
typedef LabelMapClassifier Self;
typedef itk::InPlaceLabelMapFilter<TInputLabelMap> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
......@@ -56,27 +56,26 @@ public:
typedef typename LabelObjectType::AttributesValueType AttributesValueType;
typedef typename LabelObjectType::ClassLabelType ClassLabelType;
typedef std::vector<AttributesValueType> MeasurementVectorType;
typedef Functor::AttributesMapMeasurementFunctor
<LabelObjectType, MeasurementVectorType> MeasurementFunctorType;
/** ImageDimension constants */
itkStaticConstMacro(InputImageDimension, unsigned int,
TInputLabelMap::ImageDimension);
/** Type definitions for the SVM Model. */
typedef SVMModel<AttributesValueType, ClassLabelType> SVMModelType;
typedef typename SVMModelType::Pointer SVMModelPointer;
/** Type definitions for the learning model. */
typedef MachineLearningModel<AttributesValueType, ClassLabelType> ModelType;
typedef typename ModelType::Pointer ModelPointer;
typedef typename ModelType::InputSampleType MeasurementVectorType;
typedef Functor::AttributesMapMeasurementFunctor
<LabelObjectType, MeasurementVectorType> MeasurementFunctorType;
/** Standard New method. */
itkNewMacro(Self);
/** Runtime information support. */
itkTypeMacro(LabelMapSVMClassifier,
itkTypeMacro(LabelMapClassifier,
itk::InPlaceLabelMapFilter);
itkSetObjectMacro(Model, SVMModelType);
itkSetObjectMacro(Model, ModelType);
void SetMeasurementFunctor(const MeasurementFunctorType& functor)
{
......@@ -89,8 +88,8 @@ public:
}
protected:
LabelMapSVMClassifier();
~LabelMapSVMClassifier() ITK_OVERRIDE {};
LabelMapClassifier();
~LabelMapClassifier() ITK_OVERRIDE {};
void ThreadedProcessLabelObject( LabelObjectType * labelObject ) ITK_OVERRIDE;
......@@ -98,11 +97,11 @@ protected:
private:
LabelMapSVMClassifier(const Self&); //purposely not implemented
LabelMapClassifier(const Self&); //purposely not implemented
void operator=(const Self&); //purposely not implemented
/** The SVM model used for classification */
SVMModelPointer m_Model;
/** The learning model used for classification */
ModelPointer m_Model;
/** The functor used to build the measurement vector */
MeasurementFunctorType m_MeasurementFunctor;
......@@ -112,7 +111,7 @@ private:
} // end namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbLabelMapSVMClassifier.txx"
#include "otbLabelMapClassifier.txx"
#endif
#endif
......
......@@ -18,27 +18,27 @@
* limitations under the License.
*/
#ifndef otbLabelMapSVMClassifier_txx
#define otbLabelMapSVMClassifier_txx
#ifndef otbLabelMapClassifier_txx
#define otbLabelMapClassifier_txx
#include "otbLabelMapSVMClassifier.h"
#include "otbLabelMapClassifier.h"
namespace otb {
template <class TInputImage>
LabelMapSVMClassifier<TInputImage>
::LabelMapSVMClassifier()
LabelMapClassifier<TInputImage>
::LabelMapClassifier()
{
// Force to single-threaded (SVMModel is not thread-safe)
// Force to single-threaded in case the learning model is not thread safe
// This way, we benefit of the LabelMapFilter design and only need
// to implement ThreadedProcessLabelObject
this->SetNumberOfThreads(1);
this->SetNumberOfThreads(1); // TODO : check if still needed
}
template<class TInputImage>
void
LabelMapSVMClassifier<TInputImage>
LabelMapClassifier<TInputImage>
::ReleaseInputs( )
{
// by pass itk::InPlaceLabelMapFilter::ReleaseInputs() implementation,
......@@ -48,10 +48,10 @@ LabelMapSVMClassifier<TInputImage>
template<class TInputImage>
void
LabelMapSVMClassifier<TInputImage>
LabelMapClassifier<TInputImage>
::ThreadedProcessLabelObject( LabelObjectType * labelObject )
{
ClassLabelType classLabel = m_Model->EvaluateLabel(m_MeasurementFunctor(labelObject));
ClassLabelType classLabel = (m_Model->Predict(m_MeasurementFunctor(labelObject)))[0];
labelObject->SetClassLabel(classLabel);
}
......
......@@ -29,6 +29,7 @@ ENABLE_SHARED
OTBCommon
OTBITK
OTBImageBase
OTBLabelMap
OPTIONAL_DEPENDS
OTBOpenCV
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment