Commit 4743373f authored by Sebastien Harasse's avatar Sebastien Harasse

ENH: Mean shift. Changing spatial output to display displacement instead of absolute position

parent 847d4dab
......@@ -191,6 +191,9 @@ public:
typedef TOutputIterationImage OutputIterationImageType;
typedef otb::VectorImage<RealType, InputImageType::ImageDimension> OutputSpatialImageType;
typedef typename OutputSpatialImageType::PixelType OutputSpatialPixelType;
typedef TKernel KernelType;
itkStaticConstMacro(ImageDimension, unsigned int, InputImageType::ImageDimension);
......@@ -215,7 +218,7 @@ public:
itkGetConstMacro(ModeSearchOptimization, bool);
/** Returns the const spatial image output */
const OutputImageType * GetSpatialOutput() const;
const OutputSpatialImageType * GetSpatialOutput() const;
/** Returns the spectral image output */
const OutputImageType * GetRangeOutput() const;
/** Returns the mean shift vector computed at the last iteration for each pixel */
......@@ -224,7 +227,7 @@ public:
const OutputIterationImageType * GetIterationOutput() const;
/** Returns the const spatial image output */
OutputImageType * GetSpatialOutput();
OutputSpatialImageType * GetSpatialOutput();
/** Returns the spectral image output */
OutputImageType * GetRangeOutput();
/** Returns the mean shift vector computed at the last iteration for each pixel */
......
......@@ -56,7 +56,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
}
template <class TInputImage, class TOutputImage, class TKernel, class TNorm, class TOutputMetricImage, class TOutputIterationImage>
const typename MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>::OutputImageType *
const typename MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>::OutputSpatialImageType *
MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>
::GetSpatialOutput() const
{
......@@ -64,11 +64,11 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
{
return 0;
}
return static_cast<const OutputImageType *>(this->itk::ProcessObject::GetOutput(0));
return static_cast<const OutputSpatialImageType *>(this->itk::ProcessObject::GetOutput(0));
}
template <class TInputImage, class TOutputImage, class TKernel, class TNorm, class TOutputMetricImage, class TOutputIterationImage>
typename MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>::OutputImageType *
typename MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>::OutputSpatialImageType *
MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>
::GetSpatialOutput()
{
......@@ -76,7 +76,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
{
return 0;
}
return static_cast<OutputImageType *>(this->itk::ProcessObject::GetOutput(0));
return static_cast<OutputSpatialImageType *>(this->itk::ProcessObject::GetOutput(0));
}
......@@ -161,9 +161,9 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
::AllocateOutputs()
{
typename OutputImageType::Pointer spatialOutputPtr = this->GetSpatialOutput();
typename OutputImageType::Pointer rangeOutputPtr = this->GetRangeOutput();
typename OutputImageType::Pointer metricOutputPtr = this->GetMetricOutput();
typename OutputSpatialImageType::Pointer spatialOutputPtr = this->GetSpatialOutput();
typename OutputImageType::Pointer rangeOutputPtr = this->GetRangeOutput();
typename OutputMetricImageType::Pointer metricOutputPtr = this->GetMetricOutput();
typename OutputIterationImageType::Pointer iterationOutputPtr = this->GetIterationOutput();
metricOutputPtr->SetBufferedRegion(metricOutputPtr->GetRequestedRegion());
......@@ -194,7 +194,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
{
this->GetSpatialOutput()->SetNumberOfComponentsPerPixel(ImageDimension); // image lattice
}
if (this->GetSpatialOutput())
if (this->GetRangeOutput())
{
this->GetRangeOutput()->SetNumberOfComponentsPerPixel(m_NumberOfComponentsPerPixel);
}
......@@ -217,7 +217,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
TInputImage * inPtr = const_cast<TInputImage *>(this->GetInput());
TOutputMetricImage * outMetricPtr = this->GetMetricOutput();
TOutputImage * outSpatialPtr = this->GetSpatialOutput();
OutputSpatialImageType * outSpatialPtr = this->GetSpatialOutput();
TOutputImage * outRangePtr = this->GetRangeOutput();
OutputIterationImageType * outIterationPtr = this->GetIterationOutput();
......@@ -279,7 +279,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
typedef itk::ImageRegionIterator<RealVectorImageType> JointImageIteratorType;
TOutputMetricImage * outMetricPtr = this->GetMetricOutput();
TOutputImage * outSpatialPtr = this->GetSpatialOutput();
OutputSpatialImageType * outSpatialPtr = this->GetSpatialOutput();
TOutputImage * outRangePtr = this->GetRangeOutput();
typename InputImageType::ConstPointer inputPtr = this->GetInput();
......@@ -485,7 +485,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
this->AllocateOutputs();
// Retrieve output images pointers
typename OutputImageType::Pointer spatialOutput = this->GetSpatialOutput();
typename OutputSpatialImageType::Pointer spatialOutput = this->GetSpatialOutput();
typename OutputImageType::Pointer rangeOutput = this->GetRangeOutput();
typename OutputMetricImageType::Pointer metricOutput = this->GetMetricOutput();
typename OutputIterationImageType::Pointer iterationOutput = this->GetIterationOutput();
......@@ -495,13 +495,14 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
// defines input and output iterators
typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
typedef itk::ImageRegionIterator<OutputSpatialImageType> OutputSpatialIteratorType;
typedef itk::ImageRegionIterator<OutputMetricImageType> OutputMetricIteratorType;
typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputIteratorWithIndexType;
typedef itk::ImageRegionIterator<OutputIterationImageType> OutputIterationIteratorType;
typename OutputImageType::PixelType rangePixel;
typename OutputImageType::PixelType spatialPixel;
typename OutputSpatialImageType::PixelType spatialPixel;
typename OutputMetricImageType::PixelType metricPixel;
typename OutputIterationImageType::PixelType iterationPixel;
......@@ -520,7 +521,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
JointImageIteratorType jointIt(m_JointImage, outputRegionForThread);
OutputIteratorType rangeIt(rangeOutput, outputRegionForThread);
OutputIteratorType spatialIt(spatialOutput, outputRegionForThread);
OutputSpatialIteratorType spatialIt(spatialOutput, outputRegionForThread);
OutputMetricIteratorType metricIt(metricOutput, outputRegionForThread);
OutputIterationIteratorType iterationIt(iterationOutput, outputRegionForThread);
......@@ -561,14 +562,11 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
// index of the currently processed output pixel
InputIndexType currentIndex;
if(m_ModeSearchOptimization)
for (unsigned int comp = 0; comp < ImageDimension; comp++)
{
for (unsigned int comp = 0; comp < ImageDimension; comp++)
{
currentIndex[comp] = jointPixel[comp] * m_SpatialBandwidth + 0.5;
}
pointCount = 0;
currentIndex[comp] = jointPixel[comp] * m_SpatialBandwidth + 0.5;
}
pointCount = 0; // Note: used only in mode search optimization
iteration = 0;
while ((iteration < m_MaxIterationNumber) && (!hasConverged))
......@@ -661,7 +659,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
for(unsigned int comp = 0; comp < ImageDimension; comp++)
{
spatialPixel[comp] = jointPixel[comp] * m_SpatialBandwidth;
spatialPixel[comp] = jointPixel[comp] * m_SpatialBandwidth - currentIndex[comp];
}
for(unsigned int comp = 0; comp < ImageDimension+m_NumberOfComponentsPerPixel; comp++)
......
......@@ -44,14 +44,16 @@ int otbMeanShiftImageFilter2(int argc, char * argv[])
/* maxit - threshold */
const unsigned int Dimension = 2;
typedef float PixelType;
typedef double KernelType;
typedef otb::VectorImage<PixelType, Dimension> ImageType;
typedef otb::ImageFileReader<ImageType> ReaderType;
typedef otb::ImageFileWriter<ImageType> WriterType;
typedef float PixelType;
typedef double KernelType;
typedef otb::VectorImage<PixelType, Dimension> ImageType;
typedef otb::ImageFileReader<ImageType> ReaderType;
typedef otb::ImageFileWriter<ImageType> WriterType;
typedef otb::MeanShiftImageFilter2<ImageType, ImageType> FilterType;
typedef FilterType::OutputIterationImageType IterationImageType;
typedef otb::ImageFileWriter<IterationImageType> IterationWriterType;
typedef FilterType::OutputIterationImageType IterationImageType;
typedef otb::ImageFileWriter<IterationImageType> IterationWriterType;
typedef FilterType::OutputSpatialImageType SpatialImageType;
typedef otb::ImageFileWriter<SpatialImageType> SpatialWriterType;
// Instantiating object
FilterType::Pointer filter = FilterType::New();
......@@ -66,7 +68,7 @@ int otbMeanShiftImageFilter2(int argc, char * argv[])
filter->SetMaxIterationNumber(maxiterationnumber);
filter->SetInput(reader->GetOutput());
//filter->SetNumberOfThreads(1);
WriterType::Pointer writer1 = WriterType::New();
SpatialWriterType::Pointer writer1 = SpatialWriterType::New();
WriterType::Pointer writer2 = WriterType::New();
WriterType::Pointer writer3 = WriterType::New();
IterationWriterType::Pointer writer4 = IterationWriterType::New();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment