diff --git a/Code/BasicFilters/otbMeanShiftImageFilter2.h b/Code/BasicFilters/otbMeanShiftImageFilter2.h
index 9019c798fb6e1214063ebc5141bc3fe4da44e89e..658c1d9d39065e30cbae89497a18d555d72c346e 100644
--- a/Code/BasicFilters/otbMeanShiftImageFilter2.h
+++ b/Code/BasicFilters/otbMeanShiftImageFilter2.h
@@ -244,6 +244,10 @@ private:
   KernelType m_RangeKernel;
 
   unsigned int m_NumberOfComponentsPerPixel;
+
+  /** Input data in the joint spatial-range domain, scaled by the bandwidths */
+  typename RealVectorImageType::Pointer m_JointImage;
+
 };
 
 } // end namespace otb
diff --git a/Code/BasicFilters/otbMeanShiftImageFilter2.txx b/Code/BasicFilters/otbMeanShiftImageFilter2.txx
index a2f85d1f57d3d663322e013950adb9e1a99ac1ed..b99a9463f88269b591b607965909eb50774e489a 100644
--- a/Code/BasicFilters/otbMeanShiftImageFilter2.txx
+++ b/Code/BasicFilters/otbMeanShiftImageFilter2.txx
@@ -273,10 +273,18 @@ void
 MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>
 ::BeforeThreadedGenerateData()
 {
+  typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputIteratorWithIndexType;
+  typedef itk::ImageRegionIterator<RealVectorImageType> JointImageIteratorType;
 
   TOutputMetricImage    * outMetricPtr = this->GetMetricOutput();
   TOutputImage * outSpatialPtr   = this->GetSpatialOutput();
   TOutputImage * outRangePtr   = this->GetRangeOutput();
+  typename InputImageType::ConstPointer inputPtr = this->GetInput();
+
+  InputIndexType index;
+
+  typename InputImageType::PixelType inputPixel;
+  RealVector jointPixel;
 
   m_SpatialKernel.SetBandwidth(m_SpatialBandwidth);
   m_RangeKernel.SetBandwidth(m_RangeBandwidth);
@@ -284,6 +292,40 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
   m_SpatialRadius.Fill(m_SpatialKernel.GetRadius());
 
   m_NumberOfComponentsPerPixel = this->GetInput()->GetNumberOfComponentsPerPixel();
+
+  // Allocate the joint domain image
+  m_JointImage = RealVectorImageType::New();
+  m_JointImage->SetNumberOfComponentsPerPixel(ImageDimension + m_NumberOfComponentsPerPixel);
+  m_JointImage->SetRegions(inputPtr->GetRequestedRegion());
+  m_JointImage->Allocate();
+
+  InputIteratorWithIndexType inputIt(inputPtr, inputPtr->GetRequestedRegion());
+  JointImageIteratorType jointIt(m_JointImage, inputPtr->GetRequestedRegion());
+
+  // Initialize the joint image with scaled values
+  inputIt.GoToBegin();
+  jointIt.GoToBegin();
+
+  while (!inputIt.IsAtEnd())
+    {
+    inputPixel = inputIt.Get();
+    index = inputIt.GetIndex();
+
+    jointPixel = jointIt.Get();
+    for(unsigned int comp = 0; comp < ImageDimension; comp++)
+      {
+      jointPixel[comp] = index[comp] / m_SpatialBandwidth;
+      }
+    for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
+      {
+      jointPixel[ImageDimension + comp] = inputPixel[comp] / m_RangeBandwidth;
+      }
+    jointIt.Set(jointPixel);
+
+    ++inputIt;
+    ++jointIt;
+    }
+
 }
 
 
@@ -293,7 +335,7 @@ void
 MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>
 ::CalculateMeanShiftVector(typename RealVectorImageType::Pointer jointImage, RealVector jointPixel, const OutputRegionType& outputRegion, RealVector & meanShiftVector)
  {
-   unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel;
+  unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel;
   RealVector jointNeighbor;
 
   RealType weightSum = 0;
@@ -411,9 +453,6 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
   // Allocate output images
   this->AllocateOutputs();
 
-  RegionType inputRegionForThread;
-  this->CallCopyOutputRegionToInputRegion(inputRegionForThread, outputRegionForThread);
-
   // Retrieve output images pointers
   typename OutputImageType::Pointer spatialOutput = this->GetSpatialOutput();
   typename OutputImageType::Pointer rangeOutput = this->GetRangeOutput();
@@ -430,7 +469,6 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
   typedef itk::ImageRegionIterator<OutputIterationImageType> OutputIterationIteratorType;
 
 
-  typename InputImageType::PixelType inputPixel;
   typename OutputImageType::PixelType rangePixel;
   typename OutputImageType::PixelType spatialPixel;
   typename OutputMetricImageType::PixelType metricPixel;
@@ -443,40 +481,11 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
 
   itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
 
-
-  // Declare and allocate an image in the joint space-range domain containing
-  // scaled values
-  typename RealVectorImageType::Pointer jointImage = RealVectorImageType::New();
-  jointImage->SetNumberOfComponentsPerPixel(ImageDimension + m_NumberOfComponentsPerPixel);
-  jointImage->SetRegions(outputRegionForThread);
-  jointImage->Allocate();
+  RegionType requestedRegion;
+  requestedRegion = input->GetRequestedRegion();
 
   typedef itk::ImageRegionIterator<RealVectorImageType> JointImageIteratorType;
-  JointImageIteratorType jointIt(jointImage, outputRegionForThread);
-  InputIteratorWithIndexType inputIt(input, inputRegionForThread);
-  //Initialize the joint image with scaled values
-  inputIt.GoToBegin();
-  jointIt.GoToBegin();
-
-  while (!inputIt.IsAtEnd())
-    {
-    inputPixel = inputIt.Get();
-    index = inputIt.GetIndex();
-
-    jointPixel = jointIt.Get();
-    for(unsigned int comp = 0; comp < ImageDimension; comp++)
-      {
-      jointPixel[comp] = index[comp] / m_SpatialBandwidth;
-      }
-    for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
-      {
-      jointPixel[ImageDimension + comp] = inputPixel[comp] / m_RangeBandwidth;
-      }
-    jointIt.Set(jointPixel);
-
-    ++inputIt;
-    ++jointIt;
-    }
+  JointImageIteratorType jointIt(m_JointImage, outputRegionForThread);
 
   OutputIteratorType rangeIt(rangeOutput, outputRegionForThread);
   OutputIteratorType spatialIt(spatialOutput, outputRegionForThread);
@@ -494,7 +503,6 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
   // Mean shift vector, updating the joint pixel at each iteration
   RealVector meanShiftVector;
 
-  jointPixel.SetSize(ImageDimension + m_NumberOfComponentsPerPixel);
   meanShiftVector.SetSize(ImageDimension + m_NumberOfComponentsPerPixel);
 
   while (!jointIt.IsAtEnd())
@@ -513,7 +521,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
       double meanShiftVectorSqNorm;
 
       //Calculate meanShiftVector
-      this->CalculateMeanShiftVector(jointImage, jointPixel, outputRegionForThread, meanShiftVector);
+      this->CalculateMeanShiftVector(m_JointImage, jointPixel, requestedRegion, meanShiftVector);
 
       // Compute mean shift vector squared norm
       meanShiftVectorSqNorm = 0;