diff --git a/Code/BasicFilters/otbMeanShiftImageFilter2.h b/Code/BasicFilters/otbMeanShiftImageFilter2.h
index a654640674bd84d3789387122aee4c06d29a46b6..2a901ac53c1505bc8292b57a021b6e0efe8b8e05 100644
--- a/Code/BasicFilters/otbMeanShiftImageFilter2.h
+++ b/Code/BasicFilters/otbMeanShiftImageFilter2.h
@@ -141,7 +141,9 @@ public:
 
   itkStaticConstMacro(ImageDimension, unsigned int, InputImageType::ImageDimension);
 
-  typedef itk::VariableLengthVector<RealType>       RealVector;
+  typedef itk::VariableLengthVector<RealType>         RealVector;
+
+  typedef itk::VectorImage<RealType, InputImageType::ImageDimension> RealVectorImageType;
 
   /** Setters / Getters */
   itkSetMacro(SpatialBandwidth, RealType);
@@ -215,7 +217,7 @@ protected:
   /** PrintSelf method */
   virtual void PrintSelf(std::ostream& os, itk::Indent indent) const;
 
-  virtual void CalculateMeanShiftVector(typename InputImageType::ConstPointer inputImagePtr, RealVector jointPixel, const OutputRegionType& outputRegion, RealVector & meanShiftVector);
+  virtual void CalculateMeanShiftVector(typename RealVectorImageType::Pointer inputImagePtr, RealVector jointPixel, const OutputRegionType& outputRegion, RealVector & meanShiftVector);
 
 private:
   MeanShiftImageFilter2(const Self &); //purposely not implemented
diff --git a/Code/BasicFilters/otbMeanShiftImageFilter2.txx b/Code/BasicFilters/otbMeanShiftImageFilter2.txx
index bcd0ce71bcdd77edd4daf2569561ce9cf62c5ae0..729c0b8d0c13229fcdc136cd638b6b7c695906b3 100644
--- a/Code/BasicFilters/otbMeanShiftImageFilter2.txx
+++ b/Code/BasicFilters/otbMeanShiftImageFilter2.txx
@@ -291,85 +291,83 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
 template <class TInputImage, class TOutputImage, class TKernel, class TNorm, class TOutputMetricImage, class TOutputIterationImage>
 void
 MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>
-::CalculateMeanShiftVector(typename InputImageType::ConstPointer inputImage, RealVector jointPixel, const OutputRegionType& outputRegion, RealVector & meanShiftVector)
+::CalculateMeanShiftVector(typename RealVectorImageType::Pointer jointImage, RealVector jointPixel, const OutputRegionType& outputRegion, RealVector & meanShiftVector)
  {
   RealVector jointNeighbor;
 
-  jointNeighbor.SetSize(ImageDimension + m_NumberOfComponentsPerPixel);
-
   RealType weightSum = 0;
-  InputPixelType inputPixel;
   InputIndexType inputIndex;
   InputIndexType regionIndex;
   InputSizeType  regionSize;
   RegionType neighborhoodRegion;
 
   meanShiftVector.Fill(0.);
+  jointNeighbor.SetSize(ImageDimension + m_NumberOfComponentsPerPixel);
 
   // Calculates current pixel neighborhood region, restricted to the output image region
   for(unsigned int comp = 0; comp < ImageDimension; ++comp)
     {
     long int indexRight;
-    inputIndex[comp] = jointPixel[comp];
+    inputIndex[comp] = jointPixel[comp] * m_SpatialBandwidth;
 
     regionIndex[comp] = vcl_max(static_cast<long int>(outputRegion.GetIndex().GetElement(comp)), static_cast<long int>(inputIndex[comp] - m_SpatialRadius[comp]));
     indexRight = vcl_min(static_cast<long int>(outputRegion.GetIndex().GetElement(comp) + outputRegion.GetSize().GetElement(comp) - 1), static_cast<long int>(inputIndex[comp] + m_SpatialRadius[comp]));
 
-    // regionSize[comp] = vcl_max(0l, indexRight - static_cast<long int>(regionIndex[comp] + 1));
-    regionSize[comp] = indexRight - regionIndex[comp] + 1;
+    regionSize[comp] = vcl_max(0l, indexRight - static_cast<long int>(regionIndex[comp] + 1));
     }
 
   neighborhoodRegion.SetIndex(regionIndex); // TODO Handle region borders
   neighborhoodRegion.SetSize(regionSize); //TODO Add +1 for each dimension
 
-  // An iterator on the neighborhood of the current pixel
-  itk::ImageRegionConstIteratorWithIndex<InputImageType> it(inputImage, neighborhoodRegion);
+  // An iterator on the neighborhood of the current pixel (in joint
+  // spatial-range domain)
+  itk::ImageRegionConstIteratorWithIndex<RealVectorImageType> it(jointImage, neighborhoodRegion);
 
-  //std::cout << neighborhoodRegion << std::endl;
   it.GoToBegin();
   while(!it.IsAtEnd())
     {
-    inputIndex = it.GetIndex();
-    inputPixel = it.Get();
     RealVector diff;
     RealType norm2;
     RealType weight;
 
-    // Write the current pixel of the neighborhood in the joint spatial-range domain
-    for (unsigned int comp = 0; comp < ImageDimension; comp++)
-      {
-      jointNeighbor[comp] = inputIndex[comp];
-      }
-    for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
-      {
-      jointNeighbor[ImageDimension + comp] = inputPixel[comp];
-      }
+    jointNeighbor = it.Get();
 
     // Calculate the squared norm of the difference
     diff = jointNeighbor - jointPixel;
 
-    // Scale diff vector elements by the bandwidth
+    // Compute the squared norm of the difference
+    // This is the L2 norm, TODO: replace by the templated norm
+    norm2 = diff.GetSquaredNorm();
+    // Compute pixel weight from kernel
+    // TODO : replace by the templated kernel
+    weight = (norm2 <= 1.0)? 1.0 : 0.0;
+
+    /*
+    // The following code is an alternative way to compute norm2 and weight
+    // It separates the norms of spatial and range elements
+    RealType spatialNorm2;
+    RealType rangeNorm2;
+    spatialNorm2 = 0;
     for (unsigned int comp = 0; comp < ImageDimension; comp++)
       {
-      diff[comp] /= m_SpatialBandwidth;
-      }
-    for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
-      {
-      diff[ImageDimension + comp] /= m_RangeBandwidth;
+      spatialNorm2 += diff[comp] * diff[comp];
       }
 
-    // Compute the squared norm of the difference
-    // This is the L_inf norm, TODO: replace by the templated norm
-    norm2 = 0;
-    for (unsigned int comp = 0; comp < ImageDimension + m_NumberOfComponentsPerPixel; comp++)
+    if(spatialNorm2 >= 1.0)
       {
-      norm2 += vcl_max(norm2, vcl_abs(diff[comp]));
+      weight = 0;
       }
-    norm2 *= norm2;
+    else
+      {
+      rangeNorm2 = 0;
+      for (unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
+        {
+        rangeNorm2 += diff[ImageDimension + comp] * diff[ImageDimension + comp];
+        }
 
-    // Compute pixel weight from kernel
-    // TODO : replace by the templated kernel
-    weight = (norm2 <= 1.0)? 1.0 : 0.0;
+      weight = (rangeNorm2 <= 1.0)? 1.0 : 0.0;
+      }
+    */
 
     // Update sum of weights
     weightSum += weight;
@@ -425,17 +423,54 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
   typename OutputMetricImageType::PixelType metricPixel;
   typename OutputIterationImageType::PixelType iterationPixel;
 
+  InputIndexType index;
+
+  // Pixel in the joint spatial-range domain
+  RealVector jointPixel;
 
   itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
 
 
+  // Declare and allocate an image in the joint space-range domain containing
+  // scaled values
+  typename RealVectorImageType::Pointer jointImage = RealVectorImageType::New();
+  jointImage->SetNumberOfComponentsPerPixel(ImageDimension + m_NumberOfComponentsPerPixel);
+  jointImage->SetRegions(outputRegionForThread);
+  jointImage->Allocate();
+
+  typedef itk::ImageRegionIterator<RealVectorImageType> JointImageIteratorType;
+  JointImageIteratorType jointIt(jointImage, outputRegionForThread);
   InputIteratorWithIndexType inputIt(input, inputRegionForThread);
+  //Initialize the joint image with scaled values
+  inputIt.GoToBegin();
+  jointIt.GoToBegin();
+
+  while (!inputIt.IsAtEnd())
+    {
+    inputPixel = inputIt.Get();
+    index = inputIt.GetIndex();
+
+    jointPixel = jointIt.Get();
+    for(unsigned int comp = 0; comp < ImageDimension; comp++)
+      {
+      jointPixel[comp] = index[comp] / m_SpatialBandwidth;
+      }
+    for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
+      {
+      jointPixel[ImageDimension + comp] = inputPixel[comp] / m_RangeBandwidth;
+      }
+    jointIt.Set(jointPixel);
+
+    ++inputIt;
+    ++jointIt;
+    }
+
   OutputIteratorType rangeIt(rangeOutput, outputRegionForThread);
   OutputIteratorType spatialIt(spatialOutput, outputRegionForThread);
   OutputMetricIteratorType metricIt(metricOutput, outputRegionForThread);
   OutputIterationIteratorType iterationIt(iterationOutput, outputRegionForThread);
 
-  inputIt.GoToBegin();
+  jointIt.GoToBegin();
   rangeIt.GoToBegin();
   spatialIt.GoToBegin();
   metricIt.GoToBegin();
@@ -443,39 +478,21 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
 
   unsigned int iteration = 0;
 
-
-  // Pixel in the joint spatial-range domain
-  RealVector jointPixel;
-
   // Mean shift vector, updating the joint pixel at each iteration
   RealVector meanShiftVector;
 
-
   jointPixel.SetSize(ImageDimension + m_NumberOfComponentsPerPixel);
   meanShiftVector.SetSize(ImageDimension + m_NumberOfComponentsPerPixel);
 
-
-  while (!inputIt.IsAtEnd())
+  while (!jointIt.IsAtEnd())
     {
     bool hasConverged = false;
 
-    InputIndexType index = inputIt.GetIndex();
-    inputPixel = inputIt.Get();
     rangePixel = rangeIt.Get();
     spatialPixel = spatialIt.Get();
     metricPixel = metricIt.Get();
 
-
-    // Initialize pixel in the joint spatial-range domain
-    for(unsigned int comp = 0; comp < ImageDimension; ++comp)
-      {
-      jointPixel.SetElement(comp, index[comp]);
-      }
-
-    for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; ++comp)
-      {
-      jointPixel.SetElement(ImageDimension+comp, inputPixel[comp]);
-      }
+    jointPixel = jointIt.Get();
 
     iteration = 0;
     while ((iteration < m_MaxIterationNumber) && (!hasConverged))
@@ -483,8 +500,19 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
       double meanShiftVectorSqNorm;
 
       //Calculate meanShiftVector
-      this->CalculateMeanShiftVector(input, jointPixel, outputRegionForThread, meanShiftVector);
-      meanShiftVectorSqNorm = meanShiftVector.GetSquaredNorm();
+      this->CalculateMeanShiftVector(jointImage, jointPixel, outputRegionForThread, meanShiftVector);
+
+      // Compute mean shift vector squared norm
+      meanShiftVectorSqNorm = 0;
+      for(unsigned int comp = 0; comp < ImageDimension; comp++)
+        {
+        meanShiftVectorSqNorm += meanShiftVector[comp] * meanShiftVector[comp] * m_SpatialBandwidth*m_SpatialBandwidth;
+        }
+      for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
+        {
+        meanShiftVectorSqNorm += meanShiftVector[ImageDimension + comp] * meanShiftVector[ImageDimension + comp] * m_RangeBandwidth*m_RangeBandwidth;
+        }
+
       jointPixel += meanShiftVector;
 
       //TODO replace SSD Test with templated metric
@@ -494,11 +522,11 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
 
     for(unsigned int comp = 0; comp < m_NumberOfComponentsPerPixel; comp++)
       {
-      rangePixel[comp] = jointPixel[ImageDimension + comp];
+      rangePixel[comp] = jointPixel[ImageDimension + comp] * m_RangeBandwidth;
       }
     for(unsigned int comp = 0; comp < ImageDimension; comp++)
       {
-      spatialPixel[comp] = jointPixel[comp];
+      spatialPixel[comp] = jointPixel[comp] * m_SpatialBandwidth;
       }
 
     for(unsigned int comp = 0; comp < ImageDimension+m_NumberOfComponentsPerPixel; comp++)
@@ -513,7 +541,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
     iterationPixel = iteration;
     iterationIt.Set(iterationPixel);
 
-    ++inputIt;
+    ++jointIt;
     ++rangeIt;
     ++spatialIt;
     ++metricIt;