Skip to content
Snippets Groups Projects
Commit 4a0e8a57 authored by Julien Michel's avatar Julien Michel
Browse files

BUG: Attempt to fix mean-shift stability bug

parent 079423b6
No related branches found
No related tags found
No related merge requests found
......@@ -66,7 +66,7 @@ public:
for (unsigned int comp = 0; comp < m_ImageDimension; comp++)
{
jointPixel[comp] = index[comp] / m_SpatialBandwidth;
jointPixel[comp] = (index[comp] + m_GlobalShift[comp]) / m_SpatialBandwidth;
}
for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
{
......@@ -76,13 +76,14 @@ public:
}
void Initialize(unsigned int _ImageDimension, unsigned int numberOfComponentsPerPixel_, RealType spatialBandwidth_,
RealType rangeBandwidth_)
RealType rangeBandwidth_, typename TInputImage::IndexType globalShift_)
{
m_ImageDimension = _ImageDimension;
m_NumberOfComponentsPerPixel = numberOfComponentsPerPixel_;
m_SpatialBandwidth = spatialBandwidth_;
m_RangeBandwidth = rangeBandwidth_;
m_OutputSize = m_ImageDimension + m_NumberOfComponentsPerPixel;
m_GlobalShift = globalShift_;
}
unsigned int GetOutputSize() const
......@@ -96,6 +97,7 @@ private:
unsigned int m_OutputSize;
RealType m_SpatialBandwidth;
RealType m_RangeBandwidth;
typename TInputImage::IndexType m_GlobalShift;
};
class KernelUniform
......@@ -108,7 +110,7 @@ public:
RealType operator()(RealType x) const
{
return (x <= 1) ? 1.0 : 0.0;
return (x < 1) ? 1.0 : 0.0;
}
RealType GetRadius(RealType bandwidth) const
......@@ -537,6 +539,8 @@ public:
;
#endif
itkSetMacro(GlobalShift,InputIndexType);
/** Returns the const spatial image output,spatial image output is a displacement map (pixel position after convergence minus pixel index) */
const OutputSpatialImageType * GetSpatialOutput() const;
/** Returns the const spectral image output */
......@@ -656,6 +660,8 @@ private:
BucketImageType m_BucketImage;
#endif
InputIndexType m_GlobalShift;
};
} // end namespace otb
......
......@@ -49,6 +49,7 @@ MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIterati
this->SetNthOutput(1, OutputSpatialImageType::New());
this->SetNthOutput(2, OutputIterationImageType::New());
this->SetNthOutput(3, OutputLabelImageType::New());
m_GlobalShift.Fill(0);
}
template<class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
......@@ -176,7 +177,14 @@ void MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIt
// Initializes the spatial radius from kernel bandwidth
m_SpatialRadius.Fill(m_Kernel.GetRadius(m_SpatialBandwidth));
inputRequestedRegion.PadByRadius(m_SpatialRadius);
InputSizeType margin;
for(unsigned int comp = 0; comp < ImageDimension; ++comp)
{
margin[comp] = m_MaxIterationNumber * m_SpatialRadius[comp];
}
inputRequestedRegion.PadByRadius(margin);
// Crop the input requested region at the input's largest possible region
if (inputRequestedRegion.Crop(inPtr->GetLargestPossibleRegion()))
......@@ -242,7 +250,7 @@ void MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIt
jointImageFunctor->SetInput(inputPtr);
jointImageFunctor->GetFunctor().Initialize(ImageDimension, m_NumberOfComponentsPerPixel, m_SpatialBandwidth,
m_RangeBandwidth);
m_RangeBandwidth, m_GlobalShift);
jointImageFunctor->Update();
m_JointImage = jointImageFunctor->GetOutput();
......@@ -350,14 +358,14 @@ void MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIt
// Calculates current pixel neighborhood region, restricted to the output image region
for (unsigned int comp = 0; comp < ImageDimension; ++comp)
{
inputIndex[comp] = jointPixel[comp] * m_SpatialBandwidth;
inputIndex[comp] = vcl_floor(jointPixel[comp] * m_SpatialBandwidth+ 0.5) - m_GlobalShift[comp];
regionIndex[comp] = vcl_max(static_cast<long int> (outputRegion.GetIndex().GetElement(comp)),
static_cast<long int> (inputIndex[comp] - m_SpatialRadius[comp]));
static_cast<long int> (inputIndex[comp] - m_SpatialRadius[comp] - 1));
const long int indexRight = vcl_min(
static_cast<long int> (outputRegion.GetIndex().GetElement(comp)
+ outputRegion.GetSize().GetElement(comp) - 1),
static_cast<long int> (inputIndex[comp] + m_SpatialRadius[comp]));
static_cast<long int> (inputIndex[comp] + m_SpatialRadius[comp] + 1));
regionSize[comp] = vcl_max(0l, indexRight - static_cast<long int> (regionIndex[comp]) + 1);
}
......@@ -367,7 +375,7 @@ void MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIt
neighborhoodRegion.SetSize(regionSize);
RealType weightSum = 0;
RealVector jointNeighbor(ImageDimension + m_NumberOfComponentsPerPixel);
RealVector jointNeighbor(ImageDimension + m_NumberOfComponentsPerPixel), shifts(ImageDimension + m_NumberOfComponentsPerPixel);
// An iterator on the neighborhood of the current pixel (in joint
// spatial-range domain)
......@@ -377,15 +385,15 @@ void MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIt
it.GoToBegin();
while (!it.IsAtEnd())
{
jointNeighbor.SetData(const_cast<RealType*> (it.GetPixelPointer()));
jointNeighbor = it.Get();
// Compute the squared norm of the difference
// This is the L2 norm, TODO: replace by the templated norm
RealType norm2 = 0;
for (unsigned int comp = 0; comp < jointDimension; comp++)
{
const RealType d = jointNeighbor[comp] - jointPixel[comp];
norm2 += d * d;
shifts[comp] = jointNeighbor[comp] - jointPixel[comp];
norm2 += shifts[comp] * shifts[comp];
}
// Compute pixel weight from kernel
......@@ -427,7 +435,7 @@ void MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIt
// Update mean shift vector
for (unsigned int comp = 0; comp < jointDimension; comp++)
{
meanShiftVector[comp] += weight * jointNeighbor[comp];
meanShiftVector[comp] += weight * shifts[comp];
}
++it;
......@@ -437,7 +445,7 @@ void MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIt
{
for (unsigned int comp = 0; comp < jointDimension; comp++)
{
meanShiftVector[comp] = meanShiftVector[comp] / weightSum - jointPixel[comp];
meanShiftVector[comp] = meanShiftVector[comp] / weightSum;
}
}
}
......@@ -609,7 +617,7 @@ void MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIt
jointPixel = jointIt.Get(); // Pixel in the joint spatial-range domain
// index of the currently processed output pixel
InputIndexType const& currentIndex = jointIt.GetIndex();
InputIndexType currentIndex = jointIt.GetIndex();
// Number of points currently in the pointList
unsigned int pointCount = 0; // Note: used only in mode search optimization
......@@ -710,7 +718,7 @@ void MeanShiftSmoothingImageFilter<TInputImage, TOutputImage, TKernel, TOutputIt
for (unsigned int comp = 0; comp < ImageDimension; comp++)
{
spatialPixel[comp] = jointPixel[comp] * m_SpatialBandwidth - currentIndex[comp];
spatialPixel[comp] = jointPixel[comp] * m_SpatialBandwidth - currentIndex[comp] - m_GlobalShift[comp];
}
rangeIt.Set(rangePixel);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment