Commit a632f026 authored by Jonathan Guinet's avatar Jonathan Guinet
Browse files

ENH: Add connected component step in MeanShift Segmentation filter.

parent 2811ba3e
......@@ -23,9 +23,58 @@
#include "otbMeanShiftSmoothingImageFilter.h"
#include "otbLabelImageRegionMergingFilter.h"
#include "otbLabelImageRegionPruningFilter.h"
#include "itkRelabelComponentImageFilter.h"
#include <itkConnectedComponentFunctorImageFilter.h>
namespace otb {
namespace Functor
{
template<class TInput>
class ITK_EXPORT ConnectedLabelFunctor
{
public:
typedef ConnectedLabelFunctor Self;
std::string GetNameOfClass()
{
return "ConnectedLabelFunctor";
}
inline bool operator()( TInput &p1, TInput &p2)
{
//return static_cast<bool> (0);
return static_cast<bool> (p1==p2);
}
ConnectedLabelFunctor()
{
}
~ConnectedLabelFunctor()
{
}
private:
ConnectedLabelFunctor(const Self &); //purposely not implemented
void operator =(const Self &); //purposely not implemented
};
} // end of Functor namespace
/** \class MeanShiftSegmentationFilter
*
* Performs segmentation of an image by chaining a mean shift filter and region
......@@ -35,7 +84,7 @@ namespace otb {
template <class TInputImage, class TOutputLabelImage, class TOutputClusteredImage = TInputImage, class TKernel = KernelUniform>
class MeanShiftSegmentationFilter : public itk::ImageToImageFilter<TInputImage, TOutputLabelImage>
class ITK_EXPORT MeanShiftSegmentationFilter : public itk::ImageToImageFilter<TInputImage, TOutputLabelImage>
{
public:
/** Standard Self typedef */
......@@ -67,12 +116,21 @@ public:
typedef typename MeanShiftFilterType::Pointer MeanShiftFilterPointerType;
// Region merging filter
typedef typename MeanShiftFilterType::OutputLabelImageType InputLabelImageType;
typedef typename MeanShiftFilterType::LabelType InputLabelPixelType;
typedef LabelImageRegionMergingFilter<InputLabelImageType, MeanShiftFilteredImageType,
OutputLabelImageType, OutputClusteredImageType> RegionMergingFilterType;
typedef typename RegionMergingFilterType::Pointer RegionMergingFilterPointerType;
typedef LabelImageRegionPruningFilter<OutputLabelImageType,OutputClusteredImageType,
OutputLabelImageType, OutputClusteredImageType> RegionPruningFilterType;
typedef typename RegionPruningFilterType::Pointer RegionPruningFilterPointerType;
typedef typename RegionPruningFilterType::Pointer RegionPruningFilterPointerType;
typedef Functor::ConnectedLabelFunctor<InputLabelPixelType> LabelFunctorType;
typedef itk::ConnectedComponentFunctorImageFilter<InputLabelImageType,InputLabelImageType,LabelFunctorType> RelabelComponentFilterType;
typedef typename RelabelComponentFilterType::Pointer RelabelComponentFilterPointerType;
/** Sets the spatial bandwidth (or radius in the case of a uniform kernel)
......@@ -101,6 +159,7 @@ public:
otbSetObjectMemberMacro(RegionPruningFilter,MinRegionSize,RealType);
otbGetObjectMemberMacro(RegionPruningFilter,MinRegionSize,RealType);
/** Returns the const image of region labels */
const OutputLabelImageType * GetLabelOutput() const;
/** Returns the image of region labels */
......@@ -121,9 +180,10 @@ protected:
private:
MeanShiftFilterPointerType m_MeanShiftFilter;
RegionMergingFilterPointerType m_RegionMergingFilter;
RegionPruningFilterPointerType m_RegionPruningFilter;
MeanShiftFilterPointerType m_MeanShiftFilter;
RegionMergingFilterPointerType m_RegionMergingFilter;
RegionPruningFilterPointerType m_RegionPruningFilter;
RelabelComponentFilterPointerType m_RelabelFilter;
};
......
......@@ -29,6 +29,7 @@ MeanShiftSegmentationFilter<TInputImage, TOutputLabelImage, TOutputClusteredImag
m_MeanShiftFilter = MeanShiftFilterType::New();
m_RegionMergingFilter = RegionMergingFilterType::New();
m_RegionPruningFilter = RegionPruningFilterType::New();
m_RelabelFilter = RelabelComponentFilterType::New();
this->SetMinRegionSize(100);
this->SetNumberOfOutputs(2);
this->SetNthOutput(0,TOutputLabelImage::New());
......@@ -78,12 +79,17 @@ MeanShiftSegmentationFilter<TInputImage, TOutputLabelImage, TOutputClusteredImag
::GenerateData()
{
this->m_MeanShiftFilter->SetInput(this->GetInput());
this->m_RegionMergingFilter->SetInputLabelImage(this->m_MeanShiftFilter->GetLabelOutput());
// Relabel output to avoid same label assigned to discontinuous areas
m_RelabelFilter->SetInput(this->m_MeanShiftFilter->GetLabelOutput());
this->m_RegionMergingFilter->SetInputLabelImage(this->m_RelabelFilter->GetOutput());
this->m_RegionMergingFilter->SetInputSpectralImage(this->m_MeanShiftFilter->GetRangeOutput());
this->m_RegionMergingFilter->SetRangeBandwidth(this->GetRangeBandwidth());
//std::cout << "MinRegionSize " << this->m_RegionPruningFilter->GetMinRegionSize() << std::endl;
if (this->GetMinRegionSize() == 0)
{
m_RegionMergingFilter->GraftNthOutput(0, this->GetLabelOutput());
m_RegionMergingFilter->GraftNthOutput(1, this->GetClusteredOutput());
this->m_RegionMergingFilter->Update();
......@@ -95,9 +101,9 @@ MeanShiftSegmentationFilter<TInputImage, TOutputLabelImage, TOutputClusteredImag
this->m_RegionPruningFilter->SetInputLabelImage(this->m_RegionMergingFilter->GetLabelOutput());
this->m_RegionPruningFilter->SetInputSpectralImage(this->m_RegionMergingFilter->GetClusteredOutput());
m_RegionPruningFilter->GraftNthOutput(0, this->GetLabelOutput());
m_RegionPruningFilter->GraftNthOutput(1, this->GetClusteredOutput());
this->m_RegionPruningFilter->Update();
this->GraftNthOutput(0, m_RegionPruningFilter->GetLabelOutput());
this->GraftNthOutput(1, m_RegionPruningFilter->GetClusteredOutput());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment