Skip to content
Snippets Groups Projects
Commit 5277cd3f authored by Sebastien Harasse's avatar Sebastien Harasse
Browse files

ENH: Mean shift: Algorithmic optimization (Bucket image)

parent 6abca16c
No related branches found
No related tags found
No related merge requests found
......@@ -22,6 +22,7 @@
#include "otbVectorImage.h"
#include "itkImageToImageFilter.h"
#include "itkImageRegionConstIterator.h"
#include "itkImageRegionConstIteratorWithIndex.h"
#include <vcl_algorithm.h>
......@@ -159,6 +160,174 @@ public:
unsigned int m_NumberOfComponentsPerPixel;
};
template <class TImage>
class BucketImage
{
public:
typedef TImage ImageType;
typedef typename ImageType::ConstPointer ImageConstPointerType;
typedef typename ImageType::PixelType PixelType;
typedef typename ImageType::InternalPixelType InternalPixelType;
typedef typename ImageType::RegionType RegionType;
typedef typename ImageType::IndexType IndexType;
typedef double RealType;
static const unsigned int ImageDimension = ImageType::ImageDimension;
/** The bucket image has dimension N+1 (ie. usually 3D for most images) */
typedef std::vector<typename ImageType::SizeType::SizeValueType> BucketImageSizeType;
//typedef std::vector<typename ImageType::IndexType::IndexValueType> BucketImageIndexType;
typedef std::vector<long> BucketImageIndexType;
/** pixel buckets typedefs and declarations */
typedef const typename ImageType::InternalPixelType * ImageDataPointerType;
typedef std::vector<ImageDataPointerType> BucketType;
typedef std::vector<BucketType> BucketListType;
BucketImage() {}
BucketImage(ImageConstPointerType image, const RegionType & region, RealType spatialRadius, RealType rangeRadius, unsigned int spectralCoordinate)
{
m_Image = image;
m_Region = region;
m_SpatialRadius = spatialRadius;
m_RangeRadius = rangeRadius;
m_SpectralCoordinate = spectralCoordinate;
// Find max and min of the used spectral band
itk::ImageRegionConstIterator<ImageType> inputIt(m_Image, m_Region);
inputIt.GoToBegin();
InternalPixelType minValue = inputIt.Get()[spectralCoordinate];
InternalPixelType maxValue = minValue;
++inputIt;
while( !inputIt.IsAtEnd() )
{
const PixelType &p = inputIt.Get();
minValue = vcl_min(minValue, p[m_SpectralCoordinate]);
maxValue = vcl_max(maxValue, p[m_SpectralCoordinate]);
++inputIt;
}
m_MinValue = minValue;
m_MaxValue = maxValue;
std::cout << "min: " << m_MinValue << ", max: " << m_MaxValue << std::endl;
// Compute bucket image dimensions
m_DimensionVector.resize(ImageDimension+1);
for(unsigned int dim = 0; dim < ImageDimension; ++dim)
{
m_DimensionVector[dim] = m_Region.GetSize()[dim] / m_SpatialRadius + 1;
}
m_DimensionVector[ImageDimension] = (unsigned int)((maxValue - minValue + m_RangeRadius) / m_RangeRadius);
std::cout << "m_DimensionVector: " << m_DimensionVector[0] << ", "<<m_DimensionVector[1] << ", "<<m_DimensionVector[2] << std::endl;
unsigned int numBuckets = m_DimensionVector[0];
for(unsigned int dim = 1; dim <= ImageDimension; ++dim)
numBuckets *= m_DimensionVector[dim];
m_BucketList.resize(numBuckets);
// Build buckets
BucketImageIndexType bucketIndex; //(ImageDimension+1);
itk::ImageRegionConstIteratorWithIndex<ImageType> it(m_Image, m_Region);
it.GoToBegin();
// this iterator is only used to get the pixel data pointer
FastImageRegionConstIterator<ImageType> fastIt(m_Image, m_Region);
fastIt.GoToBegin();
while( !it.IsAtEnd() )
{
const IndexType & index = it.GetIndex();
const PixelType & pixel = it.Get();
// Find which bucket this pixel belongs to
bucketIndex = GetBucketIndex(pixel, index);
unsigned int bucketListIndex = BucketIndexToBucketListIndex(bucketIndex);
assert(bucketListIndex < numBuckets);
m_BucketList[bucketListIndex].push_back(fastIt.GetPixelPointer());
++it;
++fastIt;
}
}
~BucketImage() {}
BucketImageIndexType GetBucketIndex(const PixelType & pixel, const IndexType & index)
{
BucketImageIndexType bucketIndex(ImageDimension+1);
for(unsigned int dim = 0; dim < ImageDimension; ++dim)
{
bucketIndex[dim] = (index[dim] - m_Region.GetIndex()[dim]) / m_SpatialRadius;
}
bucketIndex[ImageDimension] = (pixel[m_SpectralCoordinate] - m_MinValue) / m_RangeRadius;
return bucketIndex;
}
unsigned int BucketIndexToBucketListIndex(const BucketImageIndexType & bucketIndex)
{
unsigned int bucketListIndex = bucketIndex[0];
for(unsigned int dim = 1; dim <= ImageDimension; ++dim)
{
bucketListIndex = bucketListIndex * m_DimensionVector[dim] + bucketIndex[dim];
}
return bucketListIndex;
}
std::vector<unsigned int> GetNeighborhoodBucketListIndices(const BucketImageIndexType & bucketIndex)
{
std::vector<unsigned int> indices;
BucketImageIndexType neighborIndex;
indices.push_back(BucketIndexToBucketListIndex(bucketIndex));
for(unsigned int dim = 0; dim <= ImageDimension; ++dim)
{
if(bucketIndex[dim] > 0)
{
neighborIndex = bucketIndex;
neighborIndex[dim]--;
indices.push_back(BucketIndexToBucketListIndex(neighborIndex));
}
if(static_cast<typename ImageType::SizeType::SizeValueType>(bucketIndex[dim]) < m_DimensionVector[dim]-1)
{
neighborIndex = bucketIndex;
neighborIndex[dim]++;
indices.push_back(BucketIndexToBucketListIndex(neighborIndex));
}
}
return indices;
}
const BucketType & GetBucket(unsigned int index)
{
return m_BucketList[index];
}
private:
/** Input image */
ImageConstPointerType m_Image;
/** Processed region */
RegionType m_Region;
/** Spatial radius of one bucket of pixels */
RealType m_SpatialRadius;
/** Range radius (at a single dimension) of one bucket of pixels */
RealType m_RangeRadius;
/** pixels are separated in buckets depending on their spatial position and
also their value at one coordinate */
unsigned int m_SpectralCoordinate;
/** Min and Max of selected spectral coordinate */
InternalPixelType m_MinValue;
InternalPixelType m_MaxValue;
/** the buckets are stored in this list */
BucketListType m_BucketList;
/** This vector holds the dimensions of the 3D (ND?) bucket image */
BucketImageSizeType m_DimensionVector;
};
/** \class MeanShiftImageFilter2
*
*
......@@ -285,6 +454,11 @@ public:
itkSetMacro(ModeSearchOptimization, bool);
itkGetConstMacro(ModeSearchOptimization, bool);
/** Toggle bucket optimization, which is enabled by default.
*/
itkSetMacro(BucketOptimization, bool);
itkGetConstMacro(BucketOptimization, bool);
/** Returns the const spatial image output,spatial image output is a displacement map (pixel position after convergence minus pixel index) */
const OutputSpatialImageType * GetSpatialOutput() const;
/** Returns the const spectral image output */
......@@ -346,7 +520,7 @@ protected:
virtual void CalculateMeanShiftVector(const typename RealVectorImageType::Pointer inputImagePtr,
const RealVector& jointPixel, const OutputRegionType& outputRegion,
RealVector& meanShiftVector);
virtual void CalculateMeanShiftVectorBucket(const RealVector& jointPixel, RealVector& meanShiftVector);
private:
MeanShiftImageFilter2(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
......@@ -386,12 +560,18 @@ private:
/** Boolean to enable mode search optimization */
bool m_ModeSearchOptimization;
/** Boolean to enable bucket optimization */
bool m_BucketOptimization;
/** Mode counters (local to each thread) */
itk::VariableLengthVector<LabelType> m_NumLabels;
/** Number of bits used to represent the threadId in the most significant bits
of labels */
unsigned int m_ThreadIdNumberOfBits;
typedef BucketImage<RealVectorImageType> BucketImageType;
BucketImageType m_BucketImage;
};
} // end namespace otb
......
......@@ -39,6 +39,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TOutputIterationImage>
m_RangeBandwidth=16.;
m_Threshold=1e-3;
m_ModeSearchOptimization = true;
m_BucketOptimization = true;
this->SetNumberOfOutputs(4);
this->SetNthOutput(0, OutputImageType::New());
......@@ -291,6 +292,15 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TOutputIterationImage>
jointImageFunctor->Update();
m_JointImage = jointImageFunctor->GetOutput();
if(m_BucketOptimization)
{
// Create bucket image
// Note: because values in the input m_JointImage are normalized, the
// rangeRadius argument is just 1
m_BucketImage = BucketImageType(static_cast<typename RealVectorImageType::ConstPointer>(m_JointImage),
m_JointImage->GetRequestedRegion(),
m_Kernel.GetRadius(m_SpatialBandwidth), 1, ImageDimension);
}
/*
// Allocate the joint domain image
m_JointImage = RealVectorImageType::New();
......@@ -479,6 +489,79 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TOutputIterationImage>
meanShiftVector.Fill(0);
}
// Calculates the mean shift vector at the position given by jointPixel
template <class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
void
MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TOutputIterationImage>
::CalculateMeanShiftVectorBucket(const RealVector& jointPixel, RealVector& meanShiftVector)
{
unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel;
RealVector jointNeighbor;
RealType weightSum = 0;
meanShiftVector.Fill(0.);
jointNeighbor.SetSize(ImageDimension + m_NumberOfComponentsPerPixel);
InputIndexType index;
for(unsigned int dim = 0; dim < ImageDimension; ++dim)
{
index[dim] = jointPixel[dim] * m_SpatialBandwidth + 0.5;
}
std::vector<unsigned int> neighborBuckets;
neighborBuckets = m_BucketImage.GetNeighborhoodBucketListIndices(m_BucketImage.GetBucketIndex(jointPixel, index));
while(!neighborBuckets.empty())
{
const typename BucketImageType::BucketType & bucket = m_BucketImage.GetBucket(neighborBuckets.back());
neighborBuckets.pop_back();
if(bucket.empty()) continue;
typename BucketImageType::BucketType::const_iterator it = bucket.begin();
while(it != bucket.end())
{
RealType norm2;
RealType weight;
//std::cout << bucket.size() << std::endl;
jointNeighbor.SetData(const_cast<RealType*>(*it));
// Compute the squared norm of the difference
// This is the L2 norm, TODO: replace by the templated norm
norm2 = 0;
for(unsigned int comp = 0; comp < jointDimension; comp++)
{
RealType d;
d = jointNeighbor[comp] - jointPixel[comp];
norm2 += d*d;
}
// Compute pixel weight from kernel
weight = m_Kernel(norm2);
// Update sum of weights
weightSum += weight;
// Update mean shift vector
for(unsigned int comp = 0; comp < jointDimension; comp++)
{
meanShiftVector[comp] += weight * jointNeighbor[comp];
}
++it;
}
}
if(weightSum > 0)
{
for(unsigned int comp = 0; comp < jointDimension; comp++)
{
meanShiftVector[comp] = meanShiftVector[comp] / weightSum - jointPixel[comp];
}
}
else
meanShiftVector.Fill(0);
}
template <class TInputImage, class TOutputImage, class TKernel, class TOutputIterationImage>
void
......@@ -661,7 +744,14 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TOutputIterationImage>
} // end if (m_ModeSearchOptimization)
//Calculate meanShiftVector
this->CalculateMeanShiftVector(m_JointImage, jointPixel, requestedRegion, meanShiftVector);
if(m_BucketOptimization)
{
this->CalculateMeanShiftVectorBucket(jointPixel, meanShiftVector);
}
else
{
this->CalculateMeanShiftVector(m_JointImage, jointPixel, requestedRegion, meanShiftVector);
}
// Compute mean shift vector squared norm (not normalized by bandwidth)
// and add mean shift vector to current joint pixel
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment