Skip to content
Snippets Groups Projects
Commit e12cfb66 authored by Sebastien Harasse's avatar Sebastien Harasse
Browse files

REFAC: Mean shift. Allocated image in joint space-range domain with scaled values.

parent 27de89c9
No related branches found
No related tags found
No related merge requests found
......@@ -141,7 +141,9 @@ public:
itkStaticConstMacro(ImageDimension, unsigned int, InputImageType::ImageDimension);
typedef itk::VariableLengthVector<RealType> RealVector;
typedef itk::VariableLengthVector<RealType> RealVector;
typedef itk::VectorImage<RealType, InputImageType::ImageDimension> RealVectorImageType;
/** Setters / Getters */
itkSetMacro(SpatialBandwidth, RealType);
......@@ -215,7 +217,7 @@ protected:
/** PrintSelf method */
virtual void PrintSelf(std::ostream& os, itk::Indent indent) const;
virtual void CalculateMeanShiftVector(typename InputImageType::ConstPointer inputImagePtr, RealVector jointPixel, const OutputRegionType& outputRegion, RealVector & meanShiftVector);
virtual void CalculateMeanShiftVector(typename RealVectorImageType::Pointer inputImagePtr, RealVector jointPixel, const OutputRegionType& outputRegion, RealVector & meanShiftVector);
private:
MeanShiftImageFilter2(const Self &); //purposely not implemented
......
......@@ -291,85 +291,83 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
template <class TInputImage, class TOutputImage, class TKernel, class TNorm, class TOutputMetricImage, class TOutputIterationImage>
void
MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>
::CalculateMeanShiftVector(typename InputImageType::ConstPointer inputImage, RealVector jointPixel, const OutputRegionType& outputRegion, RealVector & meanShiftVector)
::CalculateMeanShiftVector(typename RealVectorImageType::Pointer jointImage, RealVector jointPixel, const OutputRegionType& outputRegion, RealVector & meanShiftVector)
{
RealVector jointNeighbor;
jointNeighbor.SetSize(ImageDimension + m_NumberOfComponentsPerPixel);
RealType weightSum = 0;
InputPixelType inputPixel;
InputIndexType inputIndex;
InputIndexType regionIndex;
InputSizeType regionSize;
RegionType neighborhoodRegion;
meanShiftVector.Fill(0.);
jointNeighbor.SetSize(ImageDimension + m_NumberOfComponentsPerPixel);
// Calculates current pixel neighborhood region, restricted to the output image region
for(unsigned int comp = 0; comp < ImageDimension; ++comp)
{
long int indexRight;
inputIndex[comp] = jointPixel[comp];
inputIndex[comp] = jointPixel[comp] * m_SpatialBandwidth;
regionIndex[comp] = vcl_max(static_cast<long int>(outputRegion.GetIndex().GetElement(comp)), static_cast<long int>(inputIndex[comp] - m_SpatialRadius[comp]));
indexRight = vcl_min(static_cast<long int>(outputRegion.GetIndex().GetElement(comp) + outputRegion.GetSize().GetElement(comp) - 1), static_cast<long int>(inputIndex[comp] + m_SpatialRadius[comp]));
// regionSize[comp] = vcl_max(0l, indexRight - static_cast<long int>(regionIndex[comp] + 1));
regionSize[comp] = indexRight - regionIndex[comp] + 1;
regionSize[comp] = vcl_max(0l, indexRight - static_cast<long int>(regionIndex[comp] + 1));
}
neighborhoodRegion.SetIndex(regionIndex); // TODO Handle region borders
neighborhoodRegion.SetSize(regionSize); //TODO Add +1 for each dimension
// An iterator on the neighborhood of the current pixel
itk::ImageRegionConstIteratorWithIndex<InputImageType> it(inputImage, neighborhoodRegion);
// An iterator on the neighborhood of the current pixel (in joint
// spatial-range domain)
itk::ImageRegionConstIteratorWithIndex<RealVectorImageType> it(jointImage, neighborhoodRegion);
//std::cout << neighborhoodRegion << std::endl;
it.GoToBegin();
while(!it.IsAtEnd())
{
inputIndex = it.GetIndex();
inputPixel = it.Get();
RealVector diff;
RealType norm2;
RealType weight;
// Write the current pixel of the neighborhood in the joint spatial-range domain
for (unsigned int comp = 0; comp < ImageDimension; comp++)
{
jointNeighbor[comp] = inputIndex[comp];
}
for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
{
jointNeighbor[ImageDimension + comp] = inputPixel[comp];
}
jointNeighbor = it.Get();
// Calculate the squared norm of the difference
diff = jointNeighbor - jointPixel;
// Scale diff vector elements by the bandwidth
// Compute the squared norm of the difference
// This is the L2 norm, TODO: replace by the templated norm
norm2 = diff.GetSquaredNorm();
// Compute pixel weight from kernel
// TODO : replace by the templated kernel
weight = (norm2 <= 1.0)? 1.0 : 0.0;
/*
// The following code is an alternative way to compute norm2 and weight
// It separates the norms of spatial and range elements
RealType spatialNorm2;
RealType rangeNorm2;
spatialNorm2 = 0;
for (unsigned int comp = 0; comp < ImageDimension; comp++)
{
diff[comp] /= m_SpatialBandwidth;
}
for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
{
diff[ImageDimension + comp] /= m_RangeBandwidth;
spatialNorm2 += diff[comp] * diff[comp];
}
// Compute the squared norm of the difference
// This is the L_inf norm, TODO: replace by the templated norm
norm2 = 0;
for (unsigned int comp = 0; comp < ImageDimension + m_NumberOfComponentsPerPixel; comp++)
if(spatialNorm2 >= 1.0)
{
norm2 += vcl_max(norm2, vcl_abs(diff[comp]));
weight = 0;
}
norm2 *= norm2;
else
{
rangeNorm2 = 0;
for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
{
rangeNorm2 += diff[ImageDimension + comp] * diff[ImageDimension + comp];
}
// Compute pixel weight from kernel
// TODO : replace by the templated kernel
weight = (norm2 <= 1.0)? 1.0 : 0.0;
weight = (rangeNorm2 <= 1.0)? 1.0 : 0.0;
}
*/
// Update sum of weights
weightSum += weight;
......@@ -425,17 +423,54 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
typename OutputMetricImageType::PixelType metricPixel;
typename OutputIterationImageType::PixelType iterationPixel;
InputIndexType index;
// Pixel in the joint spatial-range domain
RealVector jointPixel;
itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
// Declare and allocate an image in the joint space-range domain containing
// scaled values
typename RealVectorImageType::Pointer jointImage = RealVectorImageType::New();
jointImage->SetNumberOfComponentsPerPixel(ImageDimension + m_NumberOfComponentsPerPixel);
jointImage->SetRegions(outputRegionForThread);
jointImage->Allocate();
typedef itk::ImageRegionIterator<RealVectorImageType> JointImageIteratorType;
JointImageIteratorType jointIt(jointImage, outputRegionForThread);
InputIteratorWithIndexType inputIt(input, inputRegionForThread);
//Initialize the joint image with scaled values
inputIt.GoToBegin();
jointIt.GoToBegin();
while (!inputIt.IsAtEnd())
{
inputPixel = inputIt.Get();
index = inputIt.GetIndex();
jointPixel = jointIt.Get();
for(unsigned int comp = 0; comp < ImageDimension; comp++)
{
jointPixel[comp] = index[comp] / m_SpatialBandwidth;
}
for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
{
jointPixel[ImageDimension + comp] = inputPixel[comp] / m_RangeBandwidth;
}
jointIt.Set(jointPixel);
++inputIt;
++jointIt;
}
OutputIteratorType rangeIt(rangeOutput, outputRegionForThread);
OutputIteratorType spatialIt(spatialOutput, outputRegionForThread);
OutputMetricIteratorType metricIt(metricOutput, outputRegionForThread);
OutputIterationIteratorType iterationIt(iterationOutput, outputRegionForThread);
inputIt.GoToBegin();
jointIt.GoToBegin();
rangeIt.GoToBegin();
spatialIt.GoToBegin();
metricIt.GoToBegin();
......@@ -443,39 +478,21 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
unsigned int iteration = 0;
// Pixel in the joint spatial-range domain
RealVector jointPixel;
// Mean shift vector, updating the joint pixel at each iteration
RealVector meanShiftVector;
jointPixel.SetSize(ImageDimension + m_NumberOfComponentsPerPixel);
meanShiftVector.SetSize(ImageDimension + m_NumberOfComponentsPerPixel);
while (!inputIt.IsAtEnd())
while (!jointIt.IsAtEnd())
{
bool hasConverged = false;
InputIndexType index = inputIt.GetIndex();
inputPixel = inputIt.Get();
rangePixel = rangeIt.Get();
spatialPixel = spatialIt.Get();
metricPixel = metricIt.Get();
// Initialize pixel in the joint spatial-range domain
for(unsigned int comp = 0; comp < ImageDimension; ++comp)
{
jointPixel.SetElement(comp, index[comp]);
}
for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; ++comp)
{
jointPixel.SetElement(ImageDimension+comp, inputPixel[comp]);
}
jointPixel = jointIt.Get();
iteration = 0;
while ((iteration < m_MaxIterationNumber) && (!hasConverged))
......@@ -483,8 +500,19 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
double meanShiftVectorSqNorm;
//Calculate meanShiftVector
this->CalculateMeanShiftVector(input, jointPixel, outputRegionForThread, meanShiftVector);
meanShiftVectorSqNorm = meanShiftVector.GetSquaredNorm();
this->CalculateMeanShiftVector(jointImage, jointPixel, outputRegionForThread, meanShiftVector);
// Compute mean shift vector squared norm
meanShiftVectorSqNorm = 0;
for(unsigned int comp = 0; comp < ImageDimension; comp++)
{
meanShiftVectorSqNorm += meanShiftVector[comp] * meanShiftVector[comp] * m_SpatialBandwidth*m_SpatialBandwidth;
}
for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
{
meanShiftVectorSqNorm += meanShiftVector[ImageDimension + comp] * meanShiftVector[ImageDimension + comp] * m_RangeBandwidth*m_RangeBandwidth;
}
jointPixel += meanShiftVector;
//TODO replace SSD Test with templated metric
......@@ -494,11 +522,11 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
{
rangePixel[comp] = jointPixel[ImageDimension + comp];
rangePixel[comp] = jointPixel[ImageDimension + comp] * m_RangeBandwidth;
}
for(unsigned int comp = 0; comp < ImageDimension; comp++)
{
spatialPixel[comp] = jointPixel[comp];
spatialPixel[comp] = jointPixel[comp] * m_SpatialBandwidth;
}
for(unsigned int comp = 0; comp < ImageDimension+m_NumberOfComponentsPerPixel; comp++)
......@@ -513,7 +541,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
iterationPixel = iteration;
iterationIt.Set(iterationPixel);
++inputIt;
++jointIt;
++rangeIt;
++spatialIt;
++metricIt;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment