Skip to content
Snippets Groups Projects
Commit 27de89c9 authored by Sebastien Harasse's avatar Sebastien Harasse
Browse files

REFAC: Mean shift. Cleaned up CalculateMeanShiftVector.

parent 2aff87a9
No related branches found
No related tags found
No related merge requests found
......@@ -293,24 +293,19 @@ void
MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>
::CalculateMeanShiftVector(typename InputImageType::ConstPointer inputImage, RealVector jointPixel, const OutputRegionType& outputRegion, RealVector & meanShiftVector)
{
RealVector weightingMeanShiftVector;
double sum=0;
RealVector jointNeighbor;
meanShiftVector.SetSize(ImageDimension + m_NumberOfComponentsPerPixel);
meanShiftVector.Fill(0.);
weightingMeanShiftVector.SetSize(ImageDimension + m_NumberOfComponentsPerPixel);
weightingMeanShiftVector.Fill(0.);
double neighborhoodValue;
double value;
bool isInside;
jointNeighbor.SetSize(ImageDimension + m_NumberOfComponentsPerPixel);
RealType weightSum = 0;
InputPixelType inputPixel;
InputIndexType inputIndex;
InputIndexType regionIndex;
InputSizeType regionSize;
RegionType neighborhoodRegion;
meanShiftVector.Fill(0.);
// Calculates current pixel neighborhood region, restricted to the output image region
for(unsigned int comp = 0; comp < ImageDimension; ++comp)
{
......@@ -320,6 +315,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
regionIndex[comp] = vcl_max(static_cast<long int>(outputRegion.GetIndex().GetElement(comp)), static_cast<long int>(inputIndex[comp] - m_SpatialRadius[comp]));
indexRight = vcl_min(static_cast<long int>(outputRegion.GetIndex().GetElement(comp) + outputRegion.GetSize().GetElement(comp) - 1), static_cast<long int>(inputIndex[comp] + m_SpatialRadius[comp]));
// regionSize[comp] = vcl_max(0l, indexRight - static_cast<long int>(regionIndex[comp] + 1));
regionSize[comp] = indexRight - regionIndex[comp] + 1;
}
......@@ -329,79 +325,68 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
// An iterator on the neighborhood of the current pixel
itk::ImageRegionConstIteratorWithIndex<InputImageType> it(inputImage, neighborhoodRegion);
//std::cout << neighborhoodRegion << std::endl;
it.GoToBegin();
while(!it.IsAtEnd())
{
isInside = true;
inputIndex = it.GetIndex();
inputPixel = it.Get();
RealVector diff;
RealType norm2;
RealType weight;
double diff, el;
el = 0;
// Write the current pixel of the neighborhood in the joint spatial-range domain
for (unsigned int comp = 0; comp < ImageDimension; comp++)
{
neighborhoodValue = it.GetIndex().GetElement(comp);
el += (neighborhoodValue - jointPixel[comp]) * (neighborhoodValue - jointPixel[comp]);
jointNeighbor[comp] = inputIndex[comp];
}
diff = el / (m_SpatialBandwidth * m_SpatialBandwidth);
isInside = diff < 1.0;
if (isInside)
for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
{
diff = 0;
for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
{
jointNeighbor[ImageDimension + comp] = inputPixel[comp];
}
neighborhoodValue = it.Get().GetElement(comp);
el = (neighborhoodValue - jointPixel[ImageDimension + comp]) / m_RangeBandwidth;
// Calculate the squared norm of the difference
diff = jointNeighbor - jointPixel;
diff += el * el;
// Scale diff vector elements by the bandwidth
for (unsigned int comp = 0; comp < ImageDimension; comp++)
{
diff[comp] /= m_SpatialBandwidth;
}
for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
{
diff[ImageDimension + comp] /= m_RangeBandwidth;
}
}
// Compute the squared norm of the difference
// This is the L_inf norm, TODO: replace by the templated norm
norm2 = 0;
for (unsigned int comp = 0; comp < ImageDimension + m_NumberOfComponentsPerPixel; comp++)
{
norm2 += vcl_max(norm2, vcl_abs(diff[comp]));
}
isInside = diff < 1.0;
norm2 *= norm2;
// Compute pixel weight from kernel
// TODO : replace by the templated kernel
weight = (norm2 <= 1.0)? 1.0 : 0.0;
if (isInside)
{
// Update sum of weights
weightSum += weight;
for (unsigned int comp = 0; comp < ImageDimension; comp++)
{
neighborhoodValue = it.GetIndex().GetElement(comp);
value = 1;
meanShiftVector[comp] += (neighborhoodValue) * value;
weightingMeanShiftVector[comp] += value;
}
for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
{
neighborhoodValue = it.Get().GetElement(comp);
value = 1;
meanShiftVector[ImageDimension + comp] += (neighborhoodValue) * value;
weightingMeanShiftVector[ImageDimension + comp] += value;
}
}
// Update mean shift vector
meanShiftVector += weight * jointNeighbor;
++it;
}
//Normalize vector by kernel total weight
for(unsigned int comp=0; comp < ImageDimension; comp++)
if(weightSum > 0)
{
if( weightingMeanShiftVector[comp]>0)
meanShiftVector[comp]=meanShiftVector[comp]/weightingMeanShiftVector[comp]-jointPixel[comp];
else
meanShiftVector[comp]=0;
meanShiftVector /= weightSum;
meanShiftVector -= jointPixel;
}
for(unsigned int comp=0; comp<m_NumberOfComponentsPerPixel; comp++)
{
if( weightingMeanShiftVector[ImageDimension + comp]>0)
meanShiftVector[ImageDimension + comp] = meanShiftVector[ImageDimension + comp] / weightingMeanShiftVector[ImageDimension + comp] - jointPixel[ImageDimension + comp];
else
meanShiftVector[ImageDimension + comp] = 0;
}
else
meanShiftVector.Fill(0);
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment