Commit 2febfd3d authored by Sebastien Harasse's avatar Sebastien Harasse

ENH: Mean shift. Implemented mode search optimization

parent 50dc551c
......@@ -189,8 +189,8 @@ public:
itkStaticConstMacro(ImageDimension, unsigned int, InputImageType::ImageDimension);
typedef itk::VariableLengthVector<RealType> RealVector;
typedef otb::VectorImage<RealType, InputImageType::ImageDimension> RealVectorImageType;
typedef otb::Image<unsigned short, InputImageType::ImageDimension> ModeTableImageType;
/** Setters / Getters */
itkSetMacro(SpatialBandwidth, RealType);
......@@ -294,6 +294,14 @@ private:
/** Input data in the joint spatial-range domain, scaled by the bandwidths */
typename RealVectorImageType::Pointer m_JointImage;
/** Image to store the status at each pixel:
* 0 : no mode has been found yet
* 1 : a mode has been assigned to this pixel
* 2 : pixel is in the path of the currently processed pixel and a mode will
* be assigned to it
*/
typename ModeTableImageType::Pointer m_modeTable;
};
} // end namespace otb
......
......@@ -20,7 +20,6 @@
#define __otbMeanShiftImageFilter2_txx
#include "otbMeanShiftImageFilter2.h"
#include "itkImageRegionConstIteratorWithIndex.h"
#include "itkImageRegionIterator.h"
#include "otbUnaryFunctorWithIndexWithOutputSizeImageFilter.h"
......@@ -28,6 +27,8 @@
#include "itkProgressReporter.h"
#define MEAN_SHIFT_MODE_SEARCH_OPTIMIZATION
namespace otb
{
template <class TInputImage, class TOutputImage, class TKernel, class TNorm, class TOutputMetricImage, class TOutputIterationImage>
......@@ -339,6 +340,18 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
++jointIt;
}
*/
#ifdef MEAN_SHIFT_MODE_SEARCH_OPTIMIZATION
// Image to store the status at each pixel:
// 0 : no mode has been found yet
// 1 : a mode has been assigned to this pixel
// 2 : a mode will be assigned to this pixel
m_modeTable = ModeTableImageType::New();
m_modeTable->SetRegions(inputPtr->GetRequestedRegion());
m_modeTable->Allocate();
m_modeTable->FillBuffer(0);
#endif
}
......@@ -497,6 +510,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
RegionType requestedRegion;
requestedRegion = input->GetRequestedRegion();
typedef itk::ImageRegionIterator<RealVectorImageType> JointImageIteratorType;
JointImageIteratorType jointIt(m_JointImage, outputRegionForThread);
......@@ -518,25 +532,96 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
meanShiftVector.SetSize(ImageDimension + m_NumberOfComponentsPerPixel);
#ifdef MEAN_SHIFT_MODE_SEARCH_OPTIMIZATION
// List of indices where the current pixel passes through
std::vector<InputIndexType> pointList(m_MaxIterationNumber);
// Number of points currently in the pointList
unsigned int pointCount;
// Number of times an already processed candidate pixel is encountered, resulting in no
// further computation (Used for statistics only)
unsigned int numBreaks = 0;
#endif
while (!jointIt.IsAtEnd())
{
bool hasConverged = false;
rangePixel = rangeIt.Get();
spatialPixel = spatialIt.Get();
metricPixel = metricIt.Get();
jointPixel = jointIt.Get();
#ifdef MEAN_SHIFT_MODE_SEARCH_OPTIMIZATION
// index of the currently processed output pixel
InputIndexType currentIndex;
// index of the current pixel updated during the mean shift loop
InputIndexType modeCandidate;
for (unsigned int comp = 0; comp < ImageDimension; comp++)
{
currentIndex[comp] = jointPixel[comp] * m_SpatialBandwidth + 0.5;
}
pointCount = 0;
#endif
iteration = 0;
while ((iteration < m_MaxIterationNumber) && (!hasConverged))
{
double meanShiftVectorSqNorm;
#ifdef MEAN_SHIFT_MODE_SEARCH_OPTIMIZATION
// Find index of the pixel closest to the current jointPixel (not normalized by bandwidth)
for (unsigned int comp = 0; comp < ImageDimension; comp++)
{
modeCandidate[comp] = jointPixel[comp] * m_SpatialBandwidth + 0.5;
}
// Check status of candidate mode
if (m_modeTable->GetPixel(modeCandidate) != 2 && modeCandidate != currentIndex && outputRegionForThread.IsInside(modeCandidate))
{
// Obtain the data point to see if it close to jointPixel
RealVector candidatePixel;
RealType diff = 0;
candidatePixel = m_JointImage->GetPixel(modeCandidate);
for (unsigned int comp = ImageDimension; comp < ImageDimension + m_NumberOfComponentsPerPixel; comp++)
{
RealType d;
d = candidatePixel[comp] - jointPixel[comp];
diff += d*d;
}
if (diff < 0.5) // Spectral value is close enough
{
// if no mode has been associated to the candidate pixel then
// associate it to the upcoming mode
if( m_modeTable->GetPixel(modeCandidate) == 0)
{
pointList[pointCount++] = modeCandidate;
m_modeTable->SetPixel(modeCandidate, 2);
} else // == 1
{
// the candidate pixel has already been assigned to a mode
// Assign the same value
rangePixel = rangeOutput->GetPixel(modeCandidate);
for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
{
jointPixel[ImageDimension + comp] = rangePixel[comp] / m_RangeBandwidth;
}
// Update the mode table because pixel will be assigned just now
m_modeTable->SetPixel(currentIndex, 2); // Note: in multithreading, = 1 would
// not be safe
// bypass further calculation
numBreaks++;
break;
}
}
}
#endif
//Calculate meanShiftVector
this->CalculateMeanShiftVector(m_JointImage, jointPixel, requestedRegion, meanShiftVector);
// Compute mean shift vector squared norm
// Compute mean shift vector squared norm (not normalized by bandwidth)
double meanShiftVectorSqNorm;
meanShiftVectorSqNorm = 0;
for(unsigned int comp = 0; comp < ImageDimension; comp++)
{
......@@ -558,6 +643,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
{
rangePixel[comp] = jointPixel[ImageDimension + comp] * m_RangeBandwidth;
}
for(unsigned int comp = 0; comp < ImageDimension; comp++)
{
spatialPixel[comp] = jointPixel[comp] * m_SpatialBandwidth;
......@@ -575,6 +661,18 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
iterationPixel = iteration;
iterationIt.Set(iterationPixel);
#ifdef MEAN_SHIFT_MODE_SEARCH_OPTIMIZATION
// Update the mode table now that the current pixel has been assigned
m_modeTable->SetPixel(currentIndex, 1);
// Also assign all points in the list to the same mode
for(unsigned int i = 0; i < pointCount; i++)
{
rangeOutput->SetPixel(pointList[i], rangePixel);
m_modeTable->SetPixel(pointList[i], 1);
}
#endif
++jointIt;
++rangeIt;
++spatialIt;
......@@ -584,6 +682,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
progress.CompletedPixel();
}
// std::cout << "numBreaks: " << numBreaks << " Break ratio: " << numBreaks / (RealType)outputRegionForThread.GetNumberOfPixels() << std::endl;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment