Skip to content
Snippets Groups Projects
Commit fc0fffe4 authored by Sebastien Harasse's avatar Sebastien Harasse
Browse files

REFAC: Mean shift refactoring (Work in progress).

Changed pixel neighborhood array to image region iterator.
Updated image border handling.
Used variable length vectors to represent joint pixel.
parent 840efdbb
No related branches found
No related tags found
No related merge requests found
......@@ -113,12 +113,14 @@ public:
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
typedef double RealType;
/** Type macro */
itkTypeMacro(MeanShiftImageFilter2, ImageToImageFilter);
itkNewMacro(Self);
/** Template parameters typedefs */
typedef double RealType;
typedef TInputImage InputImageType;
typedef typename InputImageType::Pointer InputImagePointerType;
typedef typename InputImageType::PixelType InputPixelType;
......@@ -143,6 +145,10 @@ public:
typedef TKernel KernelType;
itkStaticConstMacro(ImageDimension, unsigned int, InputImageType::ImageDimension);
typedef itk::VariableLengthVector<RealType> RealVector;
/** Setters / Getters */
itkSetMacro(SpatialBandwidth, RealType);
itkGetMacro(SpatialBandwidth, RealType);
......@@ -218,7 +224,7 @@ protected:
//virtual void GetNeighborhood(PointType latticePosition);
virtual void GetNeighborhood(OutputPixelType **neighborhood,PointType latticePosition);
virtual OutputMetricPixelType CalculateMeanShiftVector(OutputPixelType *neighbothood,OutputPixelType spatialPixel,OutputPixelType rangePixel);
virtual void CalculateMeanShiftVector(typename InputImageType::ConstPointer inputImagePtr, RealVector jointPixel, const OutputRegionType& outputRegion, RealVector & meanShiftVector, OutputPixelType *neighborhood, OutputPixelType spatialPixel,OutputPixelType rangePixel);
// virtual void CreateUniformKernel();
private:
......@@ -259,8 +265,7 @@ private:
bool m_NeighborhoodHasTobeUpdated;
unsigned int m_NumberOfSpatialComponents;
unsigned int m_NumberOfComponentsPerPixel;
};
......
......@@ -50,8 +50,6 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
m_Threshold=1e-3;
m_NumberOfSpatialComponents=TInputImage::ImageDimension; //image lattice
m_NeighborhoodHasTobeUpdated = true;
this->SetNumberOfOutputs(4);
this->SetNthOutput(0, OutputImageType::New());
......@@ -323,18 +321,20 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
{
Superclass::GenerateOutputInformation();
unsigned int numberOfComponents= this->GetInput()->GetNumberOfComponentsPerPixel();
m_NumberOfComponentsPerPixel = this->GetInput()->GetNumberOfComponentsPerPixel();
if (this->GetSpatialOutput())
{
this->GetSpatialOutput()->SetNumberOfComponentsPerPixel(m_NumberOfSpatialComponents); // image lattice
this->GetSpatialOutput()->SetNumberOfComponentsPerPixel(ImageDimension); // image lattice
}
if (this->GetSpatialOutput())
{
this->GetRangeOutput()->SetNumberOfComponentsPerPixel(numberOfComponents);
this->GetRangeOutput()->SetNumberOfComponentsPerPixel(m_NumberOfComponentsPerPixel);
}
if (this->GetMetricOutput())
{
this->GetMetricOutput()->SetNumberOfComponentsPerPixel(numberOfComponents+m_NumberOfSpatialComponents); // Spectral Part + lattice
this->GetMetricOutput()->SetNumberOfComponentsPerPixel(ImageDimension + m_NumberOfComponentsPerPixel); // Spectral Part + lattice
}
}
......@@ -420,120 +420,135 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
m_SpatialRadius.Fill(m_SpatialKernel.GetRadius());
m_NumberOfComponentsPerPixel = this->GetInput()->GetNumberOfComponentsPerPixel();
}
// returns input spatial neighborhood, range, and binary map for boundaries
template <class TInputImage, class TOutputImage, class TKernel, class TNorm, class TOutputMetricImage, class TOutputIterationImage>
typename MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>::OutputMetricPixelType
void //typename MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>::OutputMetricPixelType
MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>
::CalculateMeanShiftVector(OutputPixelType *neighborhood,OutputPixelType spatialPixel,OutputPixelType rangePixel)
::CalculateMeanShiftVector(typename InputImageType::ConstPointer inputImage, RealVector jointPixel, const OutputRegionType& outputRegion, RealVector & meanShiftVector, OutputPixelType *neighborhood, OutputPixelType spatialPixel,OutputPixelType rangePixel)
{
//std::cout<<"calculate mean shift vector"<<std::endl;
OutputMetricPixelType meanShiftVector;
OutputMetricPixelType weightingMeanShiftVector;
// Kernel*Input //
InputSizeType kernelSize = m_SpatialRadius;
// OutputMetricPixelType meanShiftVector;
RealVector weightingMeanShiftVector;
unsigned int numberOfPixels= kernelSize[0]*kernelSize[1];
//std::cout<<"number of pix "<<numberOfPixels<<std::endl;
unsigned int spatialNumberOfComponents = spatialPixel.Size();
unsigned int rangeNumberOfComponents = rangePixel.Size();
unsigned int numberOfComponents = spatialNumberOfComponents +rangeNumberOfComponents;
// unsigned int numberOfComponents = spatialNumberOfComponents +rangeNumberOfComponents;
double sum=0;
meanShiftVector.SetSize(numberOfComponents);
meanShiftVector.SetSize(ImageDimension + m_NumberOfComponentsPerPixel);
meanShiftVector.Fill(0.);
weightingMeanShiftVector.SetSize(numberOfComponents);
weightingMeanShiftVector.SetSize(ImageDimension + m_NumberOfComponentsPerPixel);
weightingMeanShiftVector.Fill(0.);
// use only m_Kernel : need to define a concatenate output not spatial and range
OutputPixelType *it = neighborhood;
//std::cout<<"start processing"<<std::endl;
double neighborhoodValue;
double value;
unsigned int boundaryWeightIndex=numberOfComponents;
bool isInside;
for(unsigned int y=0; y<kernelSize[1]; y++)
InputIndexType inputIndex;
InputIndexType regionIndex;
InputSizeType regionSize;
RegionType neighborhoodRegion;
// Calculates current pixel neighborhood region, restricted to the output image region
for(unsigned int comp = 0; comp < ImageDimension; ++comp)
{
for (unsigned int x = 0; x < kernelSize[0]; x++)
{
unsigned int indexRight;
inputIndex[comp] = jointPixel[comp];
isInside = true;
regionIndex[comp] = vcl_max(outputRegion.GetIndex().GetElement(comp), static_cast<long int>(inputIndex[comp] - m_SpatialRadius[comp]));
indexRight = vcl_min(outputRegion.GetIndex().GetElement(comp) + outputRegion.GetSize().GetElement(comp) - 1, inputIndex[comp] + m_SpatialRadius[comp]);
double diff, el;
el = 0;
for (unsigned int comp = 0; comp < spatialNumberOfComponents; comp++)
{
neighborhoodValue = it->GetElement(comp);
el += (neighborhoodValue - spatialPixel[comp]) * (neighborhoodValue - spatialPixel[comp]);
}
diff = el / (m_SpatialBandwidth * m_SpatialBandwidth);
isInside = diff < 1.0;
if (isInside)
regionSize[comp] = indexRight - inputIndex[comp] + 1;
}
neighborhoodRegion.SetIndex(regionIndex); // TODO Handle region borders
neighborhoodRegion.SetSize(regionSize); //TODO Add +1 for each dimension
itk::ImageRegionConstIteratorWithIndex<InputImageType> it(inputImage, neighborhoodRegion);
it.GoToBegin();
while(!it.IsAtEnd())
{
isInside = true;
double diff, el;
el = 0;
for (unsigned int comp = 0; comp < ImageDimension; comp++)
{
neighborhoodValue = it.GetIndex().GetElement(comp);
el += (neighborhoodValue - jointPixel[comp]) * (neighborhoodValue - jointPixel[comp]);
}
diff = el / (m_SpatialBandwidth * m_SpatialBandwidth);
isInside = diff < 1.0;
if (isInside)
{
diff = 0;
for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
{
diff = 0;
for (unsigned int comp = 0; comp < rangeNumberOfComponents; comp++)
{
neighborhoodValue = it->GetElement(comp + spatialNumberOfComponents);
el = (neighborhoodValue - rangePixel[comp]) / m_RangeBandwidth;
neighborhoodValue = it.Get().GetElement(comp);
el = (neighborhoodValue - jointPixel[ImageDimension + comp]) / m_RangeBandwidth;
diff += el * el;
diff += el * el;
}
}
isInside = diff < 1.0;
}
isInside = diff < 1.0;
if (it->GetElement(boundaryWeightIndex) && isInside)
if (/*it->GetElement(boundaryWeightIndex) && */ isInside)
{
for (unsigned int comp = 0; comp < spatialNumberOfComponents; comp++)
for (unsigned int comp = 0; comp < ImageDimension; comp++)
{
neighborhoodValue = it->GetElement(comp);
neighborhoodValue = it.GetIndex().GetElement(comp);
value = 1;
meanShiftVector[comp] += (neighborhoodValue);
meanShiftVector[comp] += (neighborhoodValue) * value;
weightingMeanShiftVector[comp] += value;
}
for (unsigned int comp = 0; comp < rangeNumberOfComponents; comp++)
for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
{
neighborhoodValue = it->GetElement(comp + spatialNumberOfComponents);
//value=rangeIt->GetElement(comp);
neighborhoodValue = it.Get().GetElement(comp);
value = 1;
// meanShiftVector[spatialNumberOfComponents+comp]+=(neighborhoodValue-rangePixel[comp])*(neighborhoodValue-rangePixel[comp])*neighborhoodValue*value;
// weightingMeanShiftVector[spatialNumberOfComponents+comp]+=(neighborhoodValue-rangePixel[comp])*(neighborhoodValue-rangePixel[comp])*value;
meanShiftVector[spatialNumberOfComponents + comp] += (neighborhoodValue);
// std::cout<<"add value "<<neighborhoodValue<<std::endl;
weightingMeanShiftVector[spatialNumberOfComponents + comp] += value;
meanShiftVector[ImageDimension + comp] += (neighborhoodValue) * value;
weightingMeanShiftVector[ImageDimension + comp] += value;
}
}
++it;
}
}
for(unsigned int comp=0; comp<spatialNumberOfComponents; comp++)
for(unsigned int comp=0; comp < ImageDimension; comp++)
{
if( weightingMeanShiftVector[comp]>0)
meanShiftVector[comp]=meanShiftVector[comp]/weightingMeanShiftVector[comp]-spatialPixel[comp];
meanShiftVector[comp]=meanShiftVector[comp]/weightingMeanShiftVector[comp]-jointPixel[comp];
else
meanShiftVector[comp]=0;
}
for(unsigned int comp=0; comp<rangeNumberOfComponents; comp++)
for(unsigned int comp=0; comp<m_NumberOfComponentsPerPixel; comp++)
{
if( weightingMeanShiftVector[spatialNumberOfComponents+comp]>0)
meanShiftVector[spatialNumberOfComponents+comp]=meanShiftVector[spatialNumberOfComponents+comp]/weightingMeanShiftVector[spatialNumberOfComponents+comp]-rangePixel[comp];
if( weightingMeanShiftVector[ImageDimension + comp]>0)
meanShiftVector[ImageDimension + comp] = meanShiftVector[ImageDimension + comp] / weightingMeanShiftVector[ImageDimension + comp] - jointPixel[ImageDimension + comp];
else
meanShiftVector[spatialNumberOfComponents+comp]=0;
meanShiftVector[ImageDimension + comp] = 0;
}
// std::cout<<" mean shift vector val "<<meanShiftVector[2]<<" position "<<meanShiftVector[0]<<" "<<meanShiftVector[1]<<std::endl<<std::endl;
return meanShiftVector;
}
......@@ -607,20 +622,20 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
it->SetElement(1, pixelPos[1]);
for (unsigned int comp = 0; comp < numberOfComponents; comp++)
{
it->SetElement(comp + m_NumberOfSpatialComponents, inputPixel[comp]);
it->SetElement(comp + ImageDimension, inputPixel[comp]);
}
it->SetElement(numberOfComponents + m_NumberOfSpatialComponents, 1.);
it->SetElement(numberOfComponents + ImageDimension, 1.);
}
else
{
for (unsigned int comp = 0; comp < numberOfComponents; comp++)
{
it->SetElement(comp + m_NumberOfSpatialComponents, 0.);
it->SetElement(comp + ImageDimension, 0.);
}
it->SetElement(0, pixelPos[0]);
it->SetElement(1, pixelPos[1]);
it->SetElement(numberOfComponents + m_NumberOfSpatialComponents, 0.);
it->SetElement(numberOfComponents + ImageDimension, 0.);
}
++it;
......@@ -638,97 +653,110 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
{
// at the first iteration
// Allocate the output image
// Allocate output images
this->AllocateOutputs();
RegionType inputRegionForThread;
this->CallCopyOutputRegionToInputRegion(inputRegionForThread, outputRegionForThread);
// Allocate output
// Retrieve output images pointers
typename OutputImageType::Pointer spatialOutput = this->GetSpatialOutput();
typename OutputImageType::Pointer rangeOutput = this->GetRangeOutput();
typename OutputMetricImageType::Pointer metricOutput = this->GetMetricOutput();
typename OutputIterationImageType::Pointer iterationOutput = this->GetIterationOutput();
// Get input image pointer
typename InputImageType::ConstPointer input = this->GetInput();
// defines input and output iterators
//ypedef itk::ConstIterator<InputImageType> InputConstNeighborhoodIteratorType;
typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
typedef itk::ImageRegionIterator<OutputMetricImageType> OutputMetricIteratorType;
typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputIteratorWithIndexType;
typedef itk::ImageRegionIterator<OutputIterationImageType> OutputIterationIteratorType;
typename InputImageType::PixelType inputPixel;
typename OutputImageType::PixelType rangePixel;
typename OutputImageType::PixelType spatialPixel;
typename OutputMetricImageType::PixelType metricPixel;
typename OutputIterationImageType::PixelType iterationPixel;
InputIteratorWithIndexType inputIt(input, inputRegionForThread);
itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
inputIt.GoToBegin();
InputIteratorWithIndexType inputIt(input, inputRegionForThread);
OutputIteratorType rangeIt(rangeOutput, outputRegionForThread);
OutputIteratorType spatialIt(spatialOutput, outputRegionForThread);
OutputMetricIteratorType metricIt(metricOutput, outputRegionForThread);
OutputIterationIteratorType iterationIt(iterationOutput, outputRegionForThread);
// fill pixel
inputIt.GoToBegin();
rangeIt.GoToBegin();
spatialIt.GoToBegin();
metricIt.GoToBegin();
iterationIt.GoToBegin();
unsigned int spatialNumberOfComponents = spatialOutput->GetNumberOfComponentsPerPixel();
unsigned int rangeNumberOfComponents = rangeOutput->GetNumberOfComponentsPerPixel();
unsigned int numberOfComponents = spatialNumberOfComponents + rangeNumberOfComponents;
//unsigned int rangeNumberOfComponents = rangeOutput->GetNumberOfComponentsPerPixel();
//unsigned int numberOfComponents = spatialNumberOfComponents + rangeNumberOfComponents;
InputSizeType kernelSize = m_SpatialRadius;
OutputPixelType *neighborhood = new OutputPixelType[kernelSize[0] * kernelSize[1]];
unsigned int iteration = 0;
// Pixel in the joint spatial-range domain
RealVector jointPixel;
// Mean shift vector, updating the joint pixel at each iteration
RealVector meanShiftVector;
jointPixel.SetSize(ImageDimension + m_NumberOfComponentsPerPixel);
meanShiftVector.SetSize(ImageDimension + m_NumberOfComponentsPerPixel);
while (!inputIt.IsAtEnd())
{
bool neighborhoodHasTobeUpdated = true;
InputIndexType index = inputIt.GetIndex();
inputPixel = inputIt.Get();
rangePixel = rangeIt.Get();
spatialPixel = spatialIt.Get();
metricPixel = metricIt.Get();
InputIndexType index = inputIt.GetIndex();
spatialPixel.SetElement(0, index[0]);
spatialPixel.SetElement(1, index[1]);
rangePixel = inputIt.Get();
// spatialPixel.SetElement(0, index[0]);
// spatialPixel.SetElement(1, index[1]);
// TODO change the maximum value;
bool hasConverged = false;
while ((iteration < m_MaxIterationNumber) && (!hasConverged))
// Initialize pixel in the joint spatial-range domain
for(unsigned int comp = 0; comp < ImageDimension; ++comp)
{
jointPixel.SetElement(comp, index[comp]);
}
typename OutputMetricImageType::PixelType meanShiftVector;
for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; ++comp)
{
jointPixel.SetElement(ImageDimension+comp, inputPixel[comp]);
}
PointType position;
position[0] = spatialPixel[0];
position[1] = spatialPixel[1];
// use only when needed
if (neighborhoodHasTobeUpdated)
{
this->GetNeighborhood(&neighborhood, position);
neighborhoodHasTobeUpdated = false;
}
while ((iteration < m_MaxIterationNumber) && (!hasConverged))
{
//Calculate meanShiftVector
meanShiftVector = this->CalculateMeanShiftVector(neighborhood, spatialPixel, rangePixel);
this->CalculateMeanShiftVector(input, jointPixel, rangeOutput->GetLargestPossibleRegion(), meanShiftVector, neighborhood, spatialPixel, rangePixel);
double sum = 0;
for (unsigned int comp = 0; comp < spatialNumberOfComponents; comp++)
/*
for (unsigned int comp = 0; comp < ImageDimension; comp++)
{
neighborhoodHasTobeUpdated = neighborhoodHasTobeUpdated || ((vcl_floor(
spatialPixel[comp]
......@@ -740,21 +768,45 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
}
for (unsigned int comp = 0; comp < rangeNumberOfComponents; comp++)
for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
{
rangePixel[comp] += meanShiftVector[spatialNumberOfComponents + comp];
metricPixel[spatialNumberOfComponents + comp] = meanShiftVector[spatialNumberOfComponents + comp]
* meanShiftVector[spatialNumberOfComponents + comp];
sum += metricPixel[spatialNumberOfComponents+comp];
rangePixel[comp] += meanShiftVector[ImageDimension + comp];
metricPixel[ImageDimension + comp] = meanShiftVector[ImageDimension + comp]
* meanShiftVector[ImageDimension + comp];
sum += metricPixel[ImageDimension+comp];
}
*/
double meanShiftVectorSqNorm;
meanShiftVectorSqNorm = meanShiftVector.GetSquaredNorm();
jointPixel += meanShiftVector;
//TODO replace SSD Test with templated metric
hasConverged = sum < m_Threshold;
hasConverged = meanShiftVectorSqNorm < m_Threshold;
//hasConverged = sum < m_Threshold;
iteration++;
}
for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
{
rangePixel[comp] = jointPixel[ImageDimension + comp];
}
for(unsigned int comp = 0; comp < ImageDimension; comp++)
{
spatialPixel[comp] = jointPixel[comp];
}
for(unsigned int comp = 0; comp < ImageDimension+m_NumberOfComponentsPerPixel; comp++)
{
metricPixel[comp] = meanShiftVector[comp] * meanShiftVector[comp];
}
rangeIt.Set(rangePixel);
spatialIt.Set(spatialPixel);
metricIt.Set(metricPixel);
......@@ -763,7 +815,6 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
iterationIt.Set(iterationPixel);
++inputIt;
++rangeIt;
++spatialIt;
++metricIt;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment