diff --git a/Code/BasicFilters/otbMeanShiftImageFilter2.h b/Code/BasicFilters/otbMeanShiftImageFilter2.h index 203e1ac4947187c6ffd7eb297bc40039bd8e5447..878a5d8db5f6f1c73e3aaf933ffdb351fdf733e7 100644 --- a/Code/BasicFilters/otbMeanShiftImageFilter2.h +++ b/Code/BasicFilters/otbMeanShiftImageFilter2.h @@ -278,9 +278,9 @@ protected: /** PrintSelf method */ virtual void PrintSelf(std::ostream& os, itk::Indent indent) const; - virtual void CalculateMeanShiftVector(typename RealVectorImageType::Pointer inputImagePtr, - RealVector jointPixel, const OutputRegionType& outputRegion, - RealVector & meanShiftVector); + virtual void CalculateMeanShiftVector(const typename RealVectorImageType::Pointer inputImagePtr, + const RealVector& jointPixel, const OutputRegionType& outputRegion, + RealVector& meanShiftVector); private: MeanShiftImageFilter2(const Self &); //purposely not implemented diff --git a/Code/BasicFilters/otbMeanShiftImageFilter2.txx b/Code/BasicFilters/otbMeanShiftImageFilter2.txx index 14a2905def50b84215e1fc48b45dc943169ccec4..1287f72723e9b0a23ae1b9fed223140985067fc3 100644 --- a/Code/BasicFilters/otbMeanShiftImageFilter2.txx +++ b/Code/BasicFilters/otbMeanShiftImageFilter2.txx @@ -228,10 +228,6 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm // region for outHDispPtr) RegionType outputRequestedRegion = outRangePtr->GetRequestedRegion(); - // spatial and range radius may differ, padding must be done with the largest. - //unsigned int largestRadius= this->GetLargestRadius(); - // SHE: commented out, only the spatial radius has an effect on the input region size - //InputSizeType largestRadius= this->GetLargestRadius(); // Pad by the appropriate radius RegionType inputRequestedRegion = outputRequestedRegion; @@ -271,7 +267,7 @@ void MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage> ::BeforeThreadedGenerateData() { - typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputIteratorWithIndexType; + // typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputIteratorWithIndexType; typedef itk::ImageRegionIterator<RealVectorImageType> JointImageIteratorType; OutputMetricImagePointerType outMetricPtr = this->GetMetricOutput(); @@ -360,7 +356,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm template <class TInputImage, class TOutputImage, class TKernel, class TNorm, class TOutputMetricImage, class TOutputIterationImage> void MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage> -::CalculateMeanShiftVector(typename RealVectorImageType::Pointer jointImage, RealVector jointPixel, const OutputRegionType& outputRegion, RealVector & meanShiftVector) +::CalculateMeanShiftVector(const typename RealVectorImageType::Pointer jointImage, const RealVector& jointPixel, const OutputRegionType& outputRegion, RealVector& meanShiftVector) { unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel; RealVector jointNeighbor; @@ -386,12 +382,13 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm regionSize[comp] = vcl_max(0l, indexRight - static_cast<long int>(regionIndex[comp]) + 1); } - neighborhoodRegion.SetIndex(regionIndex); // TODO Handle region borders - neighborhoodRegion.SetSize(regionSize); //TODO Add +1 for each dimension + neighborhoodRegion.SetIndex(regionIndex); + neighborhoodRegion.SetSize(regionSize); + // An iterator on the neighborhood of the current pixel (in joint // spatial-range domain) - itk::ImageRegionConstIteratorWithIndex<RealVectorImageType> it(jointImage, neighborhoodRegion); + itk::ImageRegionConstIterator<RealVectorImageType> it(jointImage, neighborhoodRegion); it.GoToBegin(); while(!it.IsAtEnd()) @@ -415,7 +412,6 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm // TODO : replace by the templated kernel weight = (norm2 <= 1.0)? 1.0 : 0.0; - /* // The following code is an alternative way to compute norm2 and weight // It separates the norms of spatial and range elements @@ -462,8 +458,10 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm if(weightSum > 0) { - meanShiftVector /= weightSum; - meanShiftVector -= jointPixel; + for(unsigned int comp = 0; comp < jointDimension; comp++) + { + meanShiftVector[comp] = meanShiftVector[comp] / weightSum - jointPixel[comp]; + } } else meanShiftVector.Fill(0); @@ -493,20 +491,35 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType; typedef itk::ImageRegionIterator<OutputSpatialImageType> OutputSpatialIteratorType; typedef itk::ImageRegionIterator<OutputMetricImageType> OutputMetricIteratorType; - typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputIteratorWithIndexType; typedef itk::ImageRegionIterator<OutputIterationImageType> OutputIterationIteratorType; + unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel; + typename OutputImageType::PixelType rangePixel; + rangePixel.SetSize(m_NumberOfComponentsPerPixel); + typename OutputSpatialImageType::PixelType spatialPixel; + spatialPixel.SetSize(ImageDimension); + typename OutputMetricImageType::PixelType metricPixel; + metricPixel.SetSize(jointDimension); + typename OutputIterationImageType::PixelType iterationPixel; + metricPixel.SetSize(1); InputIndexType index; // Pixel in the joint spatial-range domain RealVector jointPixel; + RealVector bandwidth; + bandwidth.SetSize(jointDimension); + for (unsigned int comp = 0; comp < ImageDimension; comp++) bandwidth[comp] = m_SpatialBandwidth; + for (unsigned int comp = ImageDimension; comp < jointDimension; comp++) bandwidth[comp] = m_RangeBandwidth; + + + itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels()); RegionType requestedRegion; @@ -531,8 +544,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm // Mean shift vector, updating the joint pixel at each iteration RealVector meanShiftVector; - - meanShiftVector.SetSize(ImageDimension + m_NumberOfComponentsPerPixel); + meanShiftVector.SetSize(jointDimension); // Variables used by mode search optimization // List of indices where the current pixel passes through @@ -548,14 +560,8 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm { bool hasConverged = false; - - rangePixel = rangeIt.Get(); - spatialPixel = spatialIt.Get(); - metricPixel = metricIt.Get(); - jointPixel = jointIt.Get(); - // index of the currently processed output pixel InputIndexType currentIndex; for (unsigned int comp = 0; comp < ImageDimension; comp++) @@ -583,13 +589,13 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm // but not 2 (pixel in current search path), and pixel has actually moved // from its initial position, and pixel candidate is inside the output // region, then perform optimization tasks - if (m_ModeTable->GetPixel(modeCandidate) != 2 && modeCandidate != currentIndex && outputRegionForThread.IsInside(modeCandidate)) + if (modeCandidate != currentIndex && m_ModeTable->GetPixel(modeCandidate) != 2 && outputRegionForThread.IsInside(modeCandidate)) { // Obtain the data point to see if it close to jointPixel RealVector candidatePixel; RealType diff = 0; candidatePixel = m_JointImage->GetPixel(modeCandidate); - for (unsigned int comp = ImageDimension; comp < ImageDimension + m_NumberOfComponentsPerPixel; comp++) + for (unsigned int comp = ImageDimension; comp < jointDimension; comp++) { RealType d; d = candidatePixel[comp] - jointPixel[comp]; @@ -630,19 +636,17 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm this->CalculateMeanShiftVector(m_JointImage, jointPixel, requestedRegion, meanShiftVector); // Compute mean shift vector squared norm (not normalized by bandwidth) + // and add mean shift vector to current joint pixel double meanShiftVectorSqNorm; meanShiftVectorSqNorm = 0; - for(unsigned int comp = 0; comp < ImageDimension; comp++) - { - meanShiftVectorSqNorm += meanShiftVector[comp] * meanShiftVector[comp] * m_SpatialBandwidth*m_SpatialBandwidth; - } - for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++) + for (unsigned int comp = 0; comp < jointDimension; comp++) { - meanShiftVectorSqNorm += meanShiftVector[ImageDimension + comp] * meanShiftVector[ImageDimension + comp] * m_RangeBandwidth*m_RangeBandwidth; + double v; + v = meanShiftVector[comp] * bandwidth[comp]; + meanShiftVectorSqNorm += v*v; + jointPixel[comp] += meanShiftVector[comp]; } - jointPixel += meanShiftVector; - //TODO replace SSD Test with templated metric hasConverged = meanShiftVectorSqNorm < m_Threshold; iteration++; @@ -658,7 +662,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm spatialPixel[comp] = jointPixel[comp] * m_SpatialBandwidth - currentIndex[comp]; } - for(unsigned int comp = 0; comp < ImageDimension+m_NumberOfComponentsPerPixel; comp++) + for(unsigned int comp = 0; comp < jointDimension; comp++) { metricPixel[comp] = meanShiftVector[comp] * meanShiftVector[comp]; }