Commit c2862e27 authored by Sebastien Harasse's avatar Sebastien Harasse

ENH: Mean shift. Various (not so effective) code optimizations. +removed outdated comments

parent 0ab9234f
......@@ -278,9 +278,9 @@ protected:
/** PrintSelf method */
virtual void PrintSelf(std::ostream& os, itk::Indent indent) const;
virtual void CalculateMeanShiftVector(typename RealVectorImageType::Pointer inputImagePtr,
RealVector jointPixel, const OutputRegionType& outputRegion,
RealVector & meanShiftVector);
virtual void CalculateMeanShiftVector(const typename RealVectorImageType::Pointer inputImagePtr,
const RealVector& jointPixel, const OutputRegionType& outputRegion,
RealVector& meanShiftVector);
private:
MeanShiftImageFilter2(const Self &); //purposely not implemented
......
......@@ -228,10 +228,6 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
// region for outHDispPtr)
RegionType outputRequestedRegion = outRangePtr->GetRequestedRegion();
// spatial and range radius may differ, padding must be done with the largest.
//unsigned int largestRadius= this->GetLargestRadius();
// SHE: commented out, only the spatial radius has an effect on the input region size
//InputSizeType largestRadius= this->GetLargestRadius();
// Pad by the appropriate radius
RegionType inputRequestedRegion = outputRequestedRegion;
......@@ -271,7 +267,7 @@ void
MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>
::BeforeThreadedGenerateData()
{
typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputIteratorWithIndexType;
// typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputIteratorWithIndexType;
typedef itk::ImageRegionIterator<RealVectorImageType> JointImageIteratorType;
OutputMetricImagePointerType outMetricPtr = this->GetMetricOutput();
......@@ -360,7 +356,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
template <class TInputImage, class TOutputImage, class TKernel, class TNorm, class TOutputMetricImage, class TOutputIterationImage>
void
MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>
::CalculateMeanShiftVector(typename RealVectorImageType::Pointer jointImage, RealVector jointPixel, const OutputRegionType& outputRegion, RealVector & meanShiftVector)
::CalculateMeanShiftVector(const typename RealVectorImageType::Pointer jointImage, const RealVector& jointPixel, const OutputRegionType& outputRegion, RealVector& meanShiftVector)
{
unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel;
RealVector jointNeighbor;
......@@ -386,12 +382,13 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
regionSize[comp] = vcl_max(0l, indexRight - static_cast<long int>(regionIndex[comp]) + 1);
}
neighborhoodRegion.SetIndex(regionIndex); // TODO Handle region borders
neighborhoodRegion.SetSize(regionSize); //TODO Add +1 for each dimension
neighborhoodRegion.SetIndex(regionIndex);
neighborhoodRegion.SetSize(regionSize);
// An iterator on the neighborhood of the current pixel (in joint
// spatial-range domain)
itk::ImageRegionConstIteratorWithIndex<RealVectorImageType> it(jointImage, neighborhoodRegion);
itk::ImageRegionConstIterator<RealVectorImageType> it(jointImage, neighborhoodRegion);
it.GoToBegin();
while(!it.IsAtEnd())
......@@ -415,7 +412,6 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
// TODO : replace by the templated kernel
weight = (norm2 <= 1.0)? 1.0 : 0.0;
/*
// The following code is an alternative way to compute norm2 and weight
// It separates the norms of spatial and range elements
......@@ -462,8 +458,10 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
if(weightSum > 0)
{
meanShiftVector /= weightSum;
meanShiftVector -= jointPixel;
for(unsigned int comp = 0; comp < jointDimension; comp++)
{
meanShiftVector[comp] = meanShiftVector[comp] / weightSum - jointPixel[comp];
}
}
else
meanShiftVector.Fill(0);
......@@ -493,20 +491,35 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
typedef itk::ImageRegionIterator<OutputSpatialImageType> OutputSpatialIteratorType;
typedef itk::ImageRegionIterator<OutputMetricImageType> OutputMetricIteratorType;
typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputIteratorWithIndexType;
typedef itk::ImageRegionIterator<OutputIterationImageType> OutputIterationIteratorType;
unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel;
typename OutputImageType::PixelType rangePixel;
rangePixel.SetSize(m_NumberOfComponentsPerPixel);
typename OutputSpatialImageType::PixelType spatialPixel;
spatialPixel.SetSize(ImageDimension);
typename OutputMetricImageType::PixelType metricPixel;
metricPixel.SetSize(jointDimension);
typename OutputIterationImageType::PixelType iterationPixel;
metricPixel.SetSize(1);
InputIndexType index;
// Pixel in the joint spatial-range domain
RealVector jointPixel;
RealVector bandwidth;
bandwidth.SetSize(jointDimension);
for (unsigned int comp = 0; comp < ImageDimension; comp++) bandwidth[comp] = m_SpatialBandwidth;
for (unsigned int comp = ImageDimension; comp < jointDimension; comp++) bandwidth[comp] = m_RangeBandwidth;
itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
RegionType requestedRegion;
......@@ -531,8 +544,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
// Mean shift vector, updating the joint pixel at each iteration
RealVector meanShiftVector;
meanShiftVector.SetSize(ImageDimension + m_NumberOfComponentsPerPixel);
meanShiftVector.SetSize(jointDimension);
// Variables used by mode search optimization
// List of indices where the current pixel passes through
......@@ -548,14 +560,8 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
{
bool hasConverged = false;
rangePixel = rangeIt.Get();
spatialPixel = spatialIt.Get();
metricPixel = metricIt.Get();
jointPixel = jointIt.Get();
// index of the currently processed output pixel
InputIndexType currentIndex;
for (unsigned int comp = 0; comp < ImageDimension; comp++)
......@@ -583,13 +589,13 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
// but not 2 (pixel in current search path), and pixel has actually moved
// from its initial position, and pixel candidate is inside the output
// region, then perform optimization tasks
if (m_ModeTable->GetPixel(modeCandidate) != 2 && modeCandidate != currentIndex && outputRegionForThread.IsInside(modeCandidate))
if (modeCandidate != currentIndex && m_ModeTable->GetPixel(modeCandidate) != 2 && outputRegionForThread.IsInside(modeCandidate))
{
// Obtain the data point to see if it close to jointPixel
RealVector candidatePixel;
RealType diff = 0;
candidatePixel = m_JointImage->GetPixel(modeCandidate);
for (unsigned int comp = ImageDimension; comp < ImageDimension + m_NumberOfComponentsPerPixel; comp++)
for (unsigned int comp = ImageDimension; comp < jointDimension; comp++)
{
RealType d;
d = candidatePixel[comp] - jointPixel[comp];
......@@ -630,19 +636,17 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
this->CalculateMeanShiftVector(m_JointImage, jointPixel, requestedRegion, meanShiftVector);
// Compute mean shift vector squared norm (not normalized by bandwidth)
// and add mean shift vector to current joint pixel
double meanShiftVectorSqNorm;
meanShiftVectorSqNorm = 0;
for(unsigned int comp = 0; comp < ImageDimension; comp++)
{
meanShiftVectorSqNorm += meanShiftVector[comp] * meanShiftVector[comp] * m_SpatialBandwidth*m_SpatialBandwidth;
}
for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
for (unsigned int comp = 0; comp < jointDimension; comp++)
{
meanShiftVectorSqNorm += meanShiftVector[ImageDimension + comp] * meanShiftVector[ImageDimension + comp] * m_RangeBandwidth*m_RangeBandwidth;
double v;
v = meanShiftVector[comp] * bandwidth[comp];
meanShiftVectorSqNorm += v*v;
jointPixel[comp] += meanShiftVector[comp];
}
jointPixel += meanShiftVector;
//TODO replace SSD Test with templated metric
hasConverged = meanShiftVectorSqNorm < m_Threshold;
iteration++;
......@@ -658,7 +662,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
spatialPixel[comp] = jointPixel[comp] * m_SpatialBandwidth - currentIndex[comp];
}
for(unsigned int comp = 0; comp < ImageDimension+m_NumberOfComponentsPerPixel; comp++)
for(unsigned int comp = 0; comp < jointDimension; comp++)
{
metricPixel[comp] = meanShiftVector[comp] * meanShiftVector[comp];
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment