diff --git a/Code/BasicFilters/otbMeanShiftImageFilter2.h b/Code/BasicFilters/otbMeanShiftImageFilter2.h index 9019c798fb6e1214063ebc5141bc3fe4da44e89e..658c1d9d39065e30cbae89497a18d555d72c346e 100644 --- a/Code/BasicFilters/otbMeanShiftImageFilter2.h +++ b/Code/BasicFilters/otbMeanShiftImageFilter2.h @@ -244,6 +244,10 @@ private: KernelType m_RangeKernel; unsigned int m_NumberOfComponentsPerPixel; + + /** Input data in the joint spatial-range domain, scaled by the bandwidths */ + typename RealVectorImageType::Pointer m_JointImage; + }; } // end namespace otb diff --git a/Code/BasicFilters/otbMeanShiftImageFilter2.txx b/Code/BasicFilters/otbMeanShiftImageFilter2.txx index a2f85d1f57d3d663322e013950adb9e1a99ac1ed..b99a9463f88269b591b607965909eb50774e489a 100644 --- a/Code/BasicFilters/otbMeanShiftImageFilter2.txx +++ b/Code/BasicFilters/otbMeanShiftImageFilter2.txx @@ -273,10 +273,18 @@ void MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage> ::BeforeThreadedGenerateData() { + typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputIteratorWithIndexType; + typedef itk::ImageRegionIterator<RealVectorImageType> JointImageIteratorType; TOutputMetricImage * outMetricPtr = this->GetMetricOutput(); TOutputImage * outSpatialPtr = this->GetSpatialOutput(); TOutputImage * outRangePtr = this->GetRangeOutput(); + typename InputImageType::ConstPointer inputPtr = this->GetInput(); + + InputIndexType index; + + typename InputImageType::PixelType inputPixel; + RealVector jointPixel; m_SpatialKernel.SetBandwidth(m_SpatialBandwidth); m_RangeKernel.SetBandwidth(m_RangeBandwidth); @@ -284,6 +292,40 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm m_SpatialRadius.Fill(m_SpatialKernel.GetRadius()); m_NumberOfComponentsPerPixel = this->GetInput()->GetNumberOfComponentsPerPixel(); + + // Allocate the joint domain image + m_JointImage = RealVectorImageType::New(); + m_JointImage->SetNumberOfComponentsPerPixel(ImageDimension + m_NumberOfComponentsPerPixel); + m_JointImage->SetRegions(inputPtr->GetRequestedRegion()); + m_JointImage->Allocate(); + + InputIteratorWithIndexType inputIt(inputPtr, inputPtr->GetRequestedRegion()); + JointImageIteratorType jointIt(m_JointImage, inputPtr->GetRequestedRegion()); + + // Initialize the joint image with scaled values + inputIt.GoToBegin(); + jointIt.GoToBegin(); + + while (!inputIt.IsAtEnd()) + { + inputPixel = inputIt.Get(); + index = inputIt.GetIndex(); + + jointPixel = jointIt.Get(); + for(unsigned int comp = 0; comp < ImageDimension; comp++) + { + jointPixel[comp] = index[comp] / m_SpatialBandwidth; + } + for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++) + { + jointPixel[ImageDimension + comp] = inputPixel[comp] / m_RangeBandwidth; + } + jointIt.Set(jointPixel); + + ++inputIt; + ++jointIt; + } + } @@ -293,7 +335,7 @@ void MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage> ::CalculateMeanShiftVector(typename RealVectorImageType::Pointer jointImage, RealVector jointPixel, const OutputRegionType& outputRegion, RealVector & meanShiftVector) { - unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel; + unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel; RealVector jointNeighbor; RealType weightSum = 0; @@ -411,9 +453,6 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm // Allocate output images this->AllocateOutputs(); - RegionType inputRegionForThread; - this->CallCopyOutputRegionToInputRegion(inputRegionForThread, outputRegionForThread); - // Retrieve output images pointers typename OutputImageType::Pointer spatialOutput = this->GetSpatialOutput(); typename OutputImageType::Pointer rangeOutput = this->GetRangeOutput(); @@ -430,7 +469,6 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm typedef itk::ImageRegionIterator<OutputIterationImageType> OutputIterationIteratorType; - typename InputImageType::PixelType inputPixel; typename OutputImageType::PixelType rangePixel; typename OutputImageType::PixelType spatialPixel; typename OutputMetricImageType::PixelType metricPixel; @@ -443,40 +481,11 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels()); - - // Declare and allocate an image in the joint space-range domain containing - // scaled values - typename RealVectorImageType::Pointer jointImage = RealVectorImageType::New(); - jointImage->SetNumberOfComponentsPerPixel(ImageDimension + m_NumberOfComponentsPerPixel); - jointImage->SetRegions(outputRegionForThread); - jointImage->Allocate(); + RegionType requestedRegion; + requestedRegion = input->GetRequestedRegion(); typedef itk::ImageRegionIterator<RealVectorImageType> JointImageIteratorType; - JointImageIteratorType jointIt(jointImage, outputRegionForThread); - InputIteratorWithIndexType inputIt(input, inputRegionForThread); - //Initialize the joint image with scaled values - inputIt.GoToBegin(); - jointIt.GoToBegin(); - - while (!inputIt.IsAtEnd()) - { - inputPixel = inputIt.Get(); - index = inputIt.GetIndex(); - - jointPixel = jointIt.Get(); - for(unsigned int comp = 0; comp < ImageDimension; comp++) - { - jointPixel[comp] = index[comp] / m_SpatialBandwidth; - } - for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++) - { - jointPixel[ImageDimension + comp] = inputPixel[comp] / m_RangeBandwidth; - } - jointIt.Set(jointPixel); - - ++inputIt; - ++jointIt; - } + JointImageIteratorType jointIt(m_JointImage, outputRegionForThread); OutputIteratorType rangeIt(rangeOutput, outputRegionForThread); OutputIteratorType spatialIt(spatialOutput, outputRegionForThread); @@ -494,7 +503,6 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm // Mean shift vector, updating the joint pixel at each iteration RealVector meanShiftVector; - jointPixel.SetSize(ImageDimension + m_NumberOfComponentsPerPixel); meanShiftVector.SetSize(ImageDimension + m_NumberOfComponentsPerPixel); while (!jointIt.IsAtEnd()) @@ -513,7 +521,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm double meanShiftVectorSqNorm; //Calculate meanShiftVector - this->CalculateMeanShiftVector(jointImage, jointPixel, outputRegionForThread, meanShiftVector); + this->CalculateMeanShiftVector(m_JointImage, jointPixel, requestedRegion, meanShiftVector); // Compute mean shift vector squared norm meanShiftVectorSqNorm = 0;