From e12cfb6694de12a5fc33bf71b45c697f131dd3f6 Mon Sep 17 00:00:00 2001 From: Sebastien Harasse <sebastien.harasse@c-s.fr> Date: Tue, 24 Apr 2012 17:05:35 +0200 Subject: [PATCH] REFAC: Mean shift. Allocated image in joint space-range domain with scaled values. --- Code/BasicFilters/otbMeanShiftImageFilter2.h | 6 +- .../BasicFilters/otbMeanShiftImageFilter2.txx | 152 +++++++++++------- 2 files changed, 94 insertions(+), 64 deletions(-) diff --git a/Code/BasicFilters/otbMeanShiftImageFilter2.h b/Code/BasicFilters/otbMeanShiftImageFilter2.h index a654640674..2a901ac53c 100644 --- a/Code/BasicFilters/otbMeanShiftImageFilter2.h +++ b/Code/BasicFilters/otbMeanShiftImageFilter2.h @@ -141,7 +141,9 @@ public: itkStaticConstMacro(ImageDimension, unsigned int, InputImageType::ImageDimension); - typedef itk::VariableLengthVector<RealType> RealVector; + typedef itk::VariableLengthVector<RealType> RealVector; + + typedef itk::VectorImage<RealType, InputImageType::ImageDimension> RealVectorImageType; /** Setters / Getters */ itkSetMacro(SpatialBandwidth, RealType); @@ -215,7 +217,7 @@ protected: /** PrintSelf method */ virtual void PrintSelf(std::ostream& os, itk::Indent indent) const; - virtual void CalculateMeanShiftVector(typename InputImageType::ConstPointer inputImagePtr, RealVector jointPixel, const OutputRegionType& outputRegion, RealVector & meanShiftVector); + virtual void CalculateMeanShiftVector(typename RealVectorImageType::Pointer inputImagePtr, RealVector jointPixel, const OutputRegionType& outputRegion, RealVector & meanShiftVector); private: MeanShiftImageFilter2(const Self &); //purposely not implemented diff --git a/Code/BasicFilters/otbMeanShiftImageFilter2.txx b/Code/BasicFilters/otbMeanShiftImageFilter2.txx index bcd0ce71bc..729c0b8d0c 100644 --- a/Code/BasicFilters/otbMeanShiftImageFilter2.txx +++ b/Code/BasicFilters/otbMeanShiftImageFilter2.txx @@ -291,85 +291,83 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm template <class TInputImage, class TOutputImage, class TKernel, class TNorm, class TOutputMetricImage, class TOutputIterationImage> void MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage> -::CalculateMeanShiftVector(typename InputImageType::ConstPointer inputImage, RealVector jointPixel, const OutputRegionType& outputRegion, RealVector & meanShiftVector) +::CalculateMeanShiftVector(typename RealVectorImageType::Pointer jointImage, RealVector jointPixel, const OutputRegionType& outputRegion, RealVector & meanShiftVector) { RealVector jointNeighbor; - jointNeighbor.SetSize(ImageDimension + m_NumberOfComponentsPerPixel); - RealType weightSum = 0; - InputPixelType inputPixel; InputIndexType inputIndex; InputIndexType regionIndex; InputSizeType regionSize; RegionType neighborhoodRegion; meanShiftVector.Fill(0.); + jointNeighbor.SetSize(ImageDimension + m_NumberOfComponentsPerPixel); // Calculates current pixel neighborhood region, restricted to the output image region for(unsigned int comp = 0; comp < ImageDimension; ++comp) { long int indexRight; - inputIndex[comp] = jointPixel[comp]; + inputIndex[comp] = jointPixel[comp] * m_SpatialBandwidth; regionIndex[comp] = vcl_max(static_cast<long int>(outputRegion.GetIndex().GetElement(comp)), static_cast<long int>(inputIndex[comp] - m_SpatialRadius[comp])); indexRight = vcl_min(static_cast<long int>(outputRegion.GetIndex().GetElement(comp) + outputRegion.GetSize().GetElement(comp) - 1), static_cast<long int>(inputIndex[comp] + m_SpatialRadius[comp])); - // regionSize[comp] = vcl_max(0l, indexRight - static_cast<long int>(regionIndex[comp] + 1)); - regionSize[comp] = indexRight - regionIndex[comp] + 1; + regionSize[comp] = vcl_max(0l, indexRight - static_cast<long int>(regionIndex[comp] + 1)); } neighborhoodRegion.SetIndex(regionIndex); // TODO Handle region borders neighborhoodRegion.SetSize(regionSize); //TODO Add +1 for each dimension - // An iterator on the neighborhood of the current pixel - itk::ImageRegionConstIteratorWithIndex<InputImageType> it(inputImage, neighborhoodRegion); + // An iterator on the neighborhood of the current pixel (in joint + // spatial-range domain) + itk::ImageRegionConstIteratorWithIndex<RealVectorImageType> it(jointImage, neighborhoodRegion); - //std::cout << neighborhoodRegion << std::endl; it.GoToBegin(); while(!it.IsAtEnd()) { - inputIndex = it.GetIndex(); - inputPixel = it.Get(); RealVector diff; RealType norm2; RealType weight; - // Write the current pixel of the neighborhood in the joint spatial-range domain - for (unsigned int comp = 0; comp < ImageDimension; comp++) - { - jointNeighbor[comp] = inputIndex[comp]; - } - for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++) - { - jointNeighbor[ImageDimension + comp] = inputPixel[comp]; - } + jointNeighbor = it.Get(); // Calculate the squared norm of the difference diff = jointNeighbor - jointPixel; - // Scale diff vector elements by the bandwidth + // Compute the squared norm of the difference + // This is the L2 norm, TODO: replace by the templated norm + norm2 = diff.GetSquaredNorm(); + // Compute pixel weight from kernel + // TODO : replace by the templated kernel + weight = (norm2 <= 1.0)? 1.0 : 0.0; + + /* + // The following code is an alternative way to compute norm2 and weight + // It separates the norms of spatial and range elements + RealType spatialNorm2; + RealType rangeNorm2; + spatialNorm2 = 0; for (unsigned int comp = 0; comp < ImageDimension; comp++) { - diff[comp] /= m_SpatialBandwidth; - } - for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++) - { - diff[ImageDimension + comp] /= m_RangeBandwidth; + spatialNorm2 += diff[comp] * diff[comp]; } - // Compute the squared norm of the difference - // This is the L_inf norm, TODO: replace by the templated norm - norm2 = 0; - for (unsigned int comp = 0; comp < ImageDimension + m_NumberOfComponentsPerPixel; comp++) + if(spatialNorm2 >= 1.0) { - norm2 += vcl_max(norm2, vcl_abs(diff[comp])); + weight = 0; } - norm2 *= norm2; + else + { + rangeNorm2 = 0; + for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++) + { + rangeNorm2 += diff[ImageDimension + comp] * diff[ImageDimension + comp]; + } - // Compute pixel weight from kernel - // TODO : replace by the templated kernel - weight = (norm2 <= 1.0)? 1.0 : 0.0; + weight = (rangeNorm2 <= 1.0)? 1.0 : 0.0; + } + */ // Update sum of weights weightSum += weight; @@ -425,17 +423,54 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm typename OutputMetricImageType::PixelType metricPixel; typename OutputIterationImageType::PixelType iterationPixel; + InputIndexType index; + + // Pixel in the joint spatial-range domain + RealVector jointPixel; itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels()); + // Declare and allocate an image in the joint space-range domain containing + // scaled values + typename RealVectorImageType::Pointer jointImage = RealVectorImageType::New(); + jointImage->SetNumberOfComponentsPerPixel(ImageDimension + m_NumberOfComponentsPerPixel); + jointImage->SetRegions(outputRegionForThread); + jointImage->Allocate(); + + typedef itk::ImageRegionIterator<RealVectorImageType> JointImageIteratorType; + JointImageIteratorType jointIt(jointImage, outputRegionForThread); InputIteratorWithIndexType inputIt(input, inputRegionForThread); + //Initialize the joint image with scaled values + inputIt.GoToBegin(); + jointIt.GoToBegin(); + + while (!inputIt.IsAtEnd()) + { + inputPixel = inputIt.Get(); + index = inputIt.GetIndex(); + + jointPixel = jointIt.Get(); + for(unsigned int comp = 0; comp < ImageDimension; comp++) + { + jointPixel[comp] = index[comp] / m_SpatialBandwidth; + } + for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++) + { + jointPixel[ImageDimension + comp] = inputPixel[comp] / m_RangeBandwidth; + } + jointIt.Set(jointPixel); + + ++inputIt; + ++jointIt; + } + OutputIteratorType rangeIt(rangeOutput, outputRegionForThread); OutputIteratorType spatialIt(spatialOutput, outputRegionForThread); OutputMetricIteratorType metricIt(metricOutput, outputRegionForThread); OutputIterationIteratorType iterationIt(iterationOutput, outputRegionForThread); - inputIt.GoToBegin(); + jointIt.GoToBegin(); rangeIt.GoToBegin(); spatialIt.GoToBegin(); metricIt.GoToBegin(); @@ -443,39 +478,21 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm unsigned int iteration = 0; - - // Pixel in the joint spatial-range domain - RealVector jointPixel; - // Mean shift vector, updating the joint pixel at each iteration RealVector meanShiftVector; - jointPixel.SetSize(ImageDimension + m_NumberOfComponentsPerPixel); meanShiftVector.SetSize(ImageDimension + m_NumberOfComponentsPerPixel); - - while (!inputIt.IsAtEnd()) + while (!jointIt.IsAtEnd()) { bool hasConverged = false; - InputIndexType index = inputIt.GetIndex(); - inputPixel = inputIt.Get(); rangePixel = rangeIt.Get(); spatialPixel = spatialIt.Get(); metricPixel = metricIt.Get(); - - // Initialize pixel in the joint spatial-range domain - for(unsigned int comp = 0; comp < ImageDimension; ++comp) - { - jointPixel.SetElement(comp, index[comp]); - } - - for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; ++comp) - { - jointPixel.SetElement(ImageDimension+comp, inputPixel[comp]); - } + jointPixel = jointIt.Get(); iteration = 0; while ((iteration < m_MaxIterationNumber) && (!hasConverged)) @@ -483,8 +500,19 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm double meanShiftVectorSqNorm; //Calculate meanShiftVector - this->CalculateMeanShiftVector(input, jointPixel, outputRegionForThread, meanShiftVector); - meanShiftVectorSqNorm = meanShiftVector.GetSquaredNorm(); + this->CalculateMeanShiftVector(jointImage, jointPixel, outputRegionForThread, meanShiftVector); + + // Compute mean shift vector squared norm + meanShiftVectorSqNorm = 0; + for(unsigned int comp = 0; comp < ImageDimension; comp++) + { + meanShiftVectorSqNorm += meanShiftVector[comp] * meanShiftVector[comp] * m_SpatialBandwidth*m_SpatialBandwidth; + } + for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++) + { + meanShiftVectorSqNorm += meanShiftVector[ImageDimension + comp] * meanShiftVector[ImageDimension + comp] * m_RangeBandwidth*m_RangeBandwidth; + } + jointPixel += meanShiftVector; //TODO replace SSD Test with templated metric @@ -494,11 +522,11 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++) { - rangePixel[comp] = jointPixel[ImageDimension + comp]; + rangePixel[comp] = jointPixel[ImageDimension + comp] * m_RangeBandwidth; } for(unsigned int comp = 0; comp < ImageDimension; comp++) { - spatialPixel[comp] = jointPixel[comp]; + spatialPixel[comp] = jointPixel[comp] * m_SpatialBandwidth; } for(unsigned int comp = 0; comp < ImageDimension+m_NumberOfComponentsPerPixel; comp++) @@ -513,7 +541,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm iterationPixel = iteration; iterationIt.Set(iterationPixel); - ++inputIt; + ++jointIt; ++rangeIt; ++spatialIt; ++metricIt; -- GitLab