diff --git a/Code/BasicFilters/otbMeanShiftImageFilter2.h b/Code/BasicFilters/otbMeanShiftImageFilter2.h index 055291b0a1e6104c800082a2d342105980e0d2bf..bed4761b761100c20aa41bcad0a2a64ae20d8a09 100644 --- a/Code/BasicFilters/otbMeanShiftImageFilter2.h +++ b/Code/BasicFilters/otbMeanShiftImageFilter2.h @@ -189,8 +189,8 @@ public: itkStaticConstMacro(ImageDimension, unsigned int, InputImageType::ImageDimension); typedef itk::VariableLengthVector<RealType> RealVector; - typedef otb::VectorImage<RealType, InputImageType::ImageDimension> RealVectorImageType; + typedef otb::Image<unsigned short, InputImageType::ImageDimension> ModeTableImageType; /** Setters / Getters */ itkSetMacro(SpatialBandwidth, RealType); @@ -294,6 +294,14 @@ private: /** Input data in the joint spatial-range domain, scaled by the bandwidths */ typename RealVectorImageType::Pointer m_JointImage; + /** Image to store the status at each pixel: + * 0 : no mode has been found yet + * 1 : a mode has been assigned to this pixel + * 2 : pixel is in the path of the currently processed pixel and a mode will + * be assigned to it + */ + typename ModeTableImageType::Pointer m_modeTable; + }; } // end namespace otb diff --git a/Code/BasicFilters/otbMeanShiftImageFilter2.txx b/Code/BasicFilters/otbMeanShiftImageFilter2.txx index 2bd5d7a4c3815b10a0640956c98fd87c947f31b4..cddc13e22eb1a4c60d42df97f041af84553d3096 100644 --- a/Code/BasicFilters/otbMeanShiftImageFilter2.txx +++ b/Code/BasicFilters/otbMeanShiftImageFilter2.txx @@ -20,7 +20,6 @@ #define __otbMeanShiftImageFilter2_txx #include "otbMeanShiftImageFilter2.h" - #include "itkImageRegionConstIteratorWithIndex.h" #include "itkImageRegionIterator.h" #include "otbUnaryFunctorWithIndexWithOutputSizeImageFilter.h" @@ -28,6 +27,8 @@ #include "itkProgressReporter.h" +#define MEAN_SHIFT_MODE_SEARCH_OPTIMIZATION + namespace otb { template <class TInputImage, class TOutputImage, class TKernel, class TNorm, class TOutputMetricImage, class TOutputIterationImage> @@ -339,6 +340,18 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm ++jointIt; } */ + +#ifdef MEAN_SHIFT_MODE_SEARCH_OPTIMIZATION + // Image to store the status at each pixel: + // 0 : no mode has been found yet + // 1 : a mode has been assigned to this pixel + // 2 : a mode will be assigned to this pixel + m_modeTable = ModeTableImageType::New(); + m_modeTable->SetRegions(inputPtr->GetRequestedRegion()); + m_modeTable->Allocate(); + m_modeTable->FillBuffer(0); +#endif + } @@ -497,6 +510,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm RegionType requestedRegion; requestedRegion = input->GetRequestedRegion(); + typedef itk::ImageRegionIterator<RealVectorImageType> JointImageIteratorType; JointImageIteratorType jointIt(m_JointImage, outputRegionForThread); @@ -518,25 +532,96 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm meanShiftVector.SetSize(ImageDimension + m_NumberOfComponentsPerPixel); +#ifdef MEAN_SHIFT_MODE_SEARCH_OPTIMIZATION + // List of indices where the current pixel passes through + std::vector<InputIndexType> pointList(m_MaxIterationNumber); + // Number of points currently in the pointList + unsigned int pointCount; + // Number of times an already processed candidate pixel is encountered, resulting in no + // further computation (Used for statistics only) + unsigned int numBreaks = 0; +#endif + while (!jointIt.IsAtEnd()) { bool hasConverged = false; + rangePixel = rangeIt.Get(); spatialPixel = spatialIt.Get(); metricPixel = metricIt.Get(); jointPixel = jointIt.Get(); +#ifdef MEAN_SHIFT_MODE_SEARCH_OPTIMIZATION + // index of the currently processed output pixel + InputIndexType currentIndex; + // index of the current pixel updated during the mean shift loop + InputIndexType modeCandidate; + for (unsigned int comp = 0; comp < ImageDimension; comp++) + { + currentIndex[comp] = jointPixel[comp] * m_SpatialBandwidth + 0.5; + } + pointCount = 0; + +#endif + iteration = 0; while ((iteration < m_MaxIterationNumber) && (!hasConverged)) { - double meanShiftVectorSqNorm; +#ifdef MEAN_SHIFT_MODE_SEARCH_OPTIMIZATION + // Find index of the pixel closest to the current jointPixel (not normalized by bandwidth) + for (unsigned int comp = 0; comp < ImageDimension; comp++) + { + modeCandidate[comp] = jointPixel[comp] * m_SpatialBandwidth + 0.5; + } + // Check status of candidate mode + if (m_modeTable->GetPixel(modeCandidate) != 2 && modeCandidate != currentIndex && outputRegionForThread.IsInside(modeCandidate)) + { + // Obtain the data point to see if it close to jointPixel + RealVector candidatePixel; + RealType diff = 0; + candidatePixel = m_JointImage->GetPixel(modeCandidate); + for (unsigned int comp = ImageDimension; comp < ImageDimension + m_NumberOfComponentsPerPixel; comp++) + { + RealType d; + d = candidatePixel[comp] - jointPixel[comp]; + diff += d*d; + } + + if (diff < 0.5) // Spectral value is close enough + { + // if no mode has been associated to the candidate pixel then + // associate it to the upcoming mode + if( m_modeTable->GetPixel(modeCandidate) == 0) + { + pointList[pointCount++] = modeCandidate; + m_modeTable->SetPixel(modeCandidate, 2); + } else // == 1 + { + // the candidate pixel has already been assigned to a mode + // Assign the same value + rangePixel = rangeOutput->GetPixel(modeCandidate); + for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++) + { + jointPixel[ImageDimension + comp] = rangePixel[comp] / m_RangeBandwidth; + } + // Update the mode table because pixel will be assigned just now + m_modeTable->SetPixel(currentIndex, 2); // Note: in multithreading, = 1 would + // not be safe + // bypass further calculation + numBreaks++; + break; + } + } + } +#endif //Calculate meanShiftVector this->CalculateMeanShiftVector(m_JointImage, jointPixel, requestedRegion, meanShiftVector); - // Compute mean shift vector squared norm + // Compute mean shift vector squared norm (not normalized by bandwidth) + double meanShiftVectorSqNorm; meanShiftVectorSqNorm = 0; for(unsigned int comp = 0; comp < ImageDimension; comp++) { @@ -558,6 +643,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm { rangePixel[comp] = jointPixel[ImageDimension + comp] * m_RangeBandwidth; } + for(unsigned int comp = 0; comp < ImageDimension; comp++) { spatialPixel[comp] = jointPixel[comp] * m_SpatialBandwidth; @@ -575,6 +661,18 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm iterationPixel = iteration; iterationIt.Set(iterationPixel); +#ifdef MEAN_SHIFT_MODE_SEARCH_OPTIMIZATION + // Update the mode table now that the current pixel has been assigned + m_modeTable->SetPixel(currentIndex, 1); + + // Also assign all points in the list to the same mode + for(unsigned int i = 0; i < pointCount; i++) + { + rangeOutput->SetPixel(pointList[i], rangePixel); + m_modeTable->SetPixel(pointList[i], 1); + } +#endif + ++jointIt; ++rangeIt; ++spatialIt; @@ -584,6 +682,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm progress.CompletedPixel(); } + // std::cout << "numBreaks: " << numBreaks << " Break ratio: " << numBreaks / (RealType)outputRegionForThread.GetNumberOfPixels() << std::endl; }