Skip to content
Snippets Groups Projects
Commit 4743373f authored by Sebastien Harasse's avatar Sebastien Harasse
Browse files

ENH: Mean shift. Changing spatial output to display displacement instead of absolute position

parent 847d4dab
No related branches found
No related tags found
No related merge requests found
...@@ -191,6 +191,9 @@ public: ...@@ -191,6 +191,9 @@ public:
typedef TOutputIterationImage OutputIterationImageType; typedef TOutputIterationImage OutputIterationImageType;
typedef otb::VectorImage<RealType, InputImageType::ImageDimension> OutputSpatialImageType;
typedef typename OutputSpatialImageType::PixelType OutputSpatialPixelType;
typedef TKernel KernelType; typedef TKernel KernelType;
itkStaticConstMacro(ImageDimension, unsigned int, InputImageType::ImageDimension); itkStaticConstMacro(ImageDimension, unsigned int, InputImageType::ImageDimension);
...@@ -215,7 +218,7 @@ public: ...@@ -215,7 +218,7 @@ public:
itkGetConstMacro(ModeSearchOptimization, bool); itkGetConstMacro(ModeSearchOptimization, bool);
/** Returns the const spatial image output */ /** Returns the const spatial image output */
const OutputImageType * GetSpatialOutput() const; const OutputSpatialImageType * GetSpatialOutput() const;
/** Returns the spectral image output */ /** Returns the spectral image output */
const OutputImageType * GetRangeOutput() const; const OutputImageType * GetRangeOutput() const;
/** Returns the mean shift vector computed at the last iteration for each pixel */ /** Returns the mean shift vector computed at the last iteration for each pixel */
...@@ -224,7 +227,7 @@ public: ...@@ -224,7 +227,7 @@ public:
const OutputIterationImageType * GetIterationOutput() const; const OutputIterationImageType * GetIterationOutput() const;
/** Returns the const spatial image output */ /** Returns the const spatial image output */
OutputImageType * GetSpatialOutput(); OutputSpatialImageType * GetSpatialOutput();
/** Returns the spectral image output */ /** Returns the spectral image output */
OutputImageType * GetRangeOutput(); OutputImageType * GetRangeOutput();
/** Returns the mean shift vector computed at the last iteration for each pixel */ /** Returns the mean shift vector computed at the last iteration for each pixel */
......
...@@ -56,7 +56,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm ...@@ -56,7 +56,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
} }
template <class TInputImage, class TOutputImage, class TKernel, class TNorm, class TOutputMetricImage, class TOutputIterationImage> template <class TInputImage, class TOutputImage, class TKernel, class TNorm, class TOutputMetricImage, class TOutputIterationImage>
const typename MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>::OutputImageType * const typename MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>::OutputSpatialImageType *
MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage> MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>
::GetSpatialOutput() const ::GetSpatialOutput() const
{ {
...@@ -64,11 +64,11 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm ...@@ -64,11 +64,11 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
{ {
return 0; return 0;
} }
return static_cast<const OutputImageType *>(this->itk::ProcessObject::GetOutput(0)); return static_cast<const OutputSpatialImageType *>(this->itk::ProcessObject::GetOutput(0));
} }
template <class TInputImage, class TOutputImage, class TKernel, class TNorm, class TOutputMetricImage, class TOutputIterationImage> template <class TInputImage, class TOutputImage, class TKernel, class TNorm, class TOutputMetricImage, class TOutputIterationImage>
typename MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>::OutputImageType * typename MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>::OutputSpatialImageType *
MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage> MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>
::GetSpatialOutput() ::GetSpatialOutput()
{ {
...@@ -76,7 +76,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm ...@@ -76,7 +76,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
{ {
return 0; return 0;
} }
return static_cast<OutputImageType *>(this->itk::ProcessObject::GetOutput(0)); return static_cast<OutputSpatialImageType *>(this->itk::ProcessObject::GetOutput(0));
} }
...@@ -161,9 +161,9 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm ...@@ -161,9 +161,9 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
::AllocateOutputs() ::AllocateOutputs()
{ {
typename OutputImageType::Pointer spatialOutputPtr = this->GetSpatialOutput(); typename OutputSpatialImageType::Pointer spatialOutputPtr = this->GetSpatialOutput();
typename OutputImageType::Pointer rangeOutputPtr = this->GetRangeOutput(); typename OutputImageType::Pointer rangeOutputPtr = this->GetRangeOutput();
typename OutputImageType::Pointer metricOutputPtr = this->GetMetricOutput(); typename OutputMetricImageType::Pointer metricOutputPtr = this->GetMetricOutput();
typename OutputIterationImageType::Pointer iterationOutputPtr = this->GetIterationOutput(); typename OutputIterationImageType::Pointer iterationOutputPtr = this->GetIterationOutput();
metricOutputPtr->SetBufferedRegion(metricOutputPtr->GetRequestedRegion()); metricOutputPtr->SetBufferedRegion(metricOutputPtr->GetRequestedRegion());
...@@ -194,7 +194,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm ...@@ -194,7 +194,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
{ {
this->GetSpatialOutput()->SetNumberOfComponentsPerPixel(ImageDimension); // image lattice this->GetSpatialOutput()->SetNumberOfComponentsPerPixel(ImageDimension); // image lattice
} }
if (this->GetSpatialOutput()) if (this->GetRangeOutput())
{ {
this->GetRangeOutput()->SetNumberOfComponentsPerPixel(m_NumberOfComponentsPerPixel); this->GetRangeOutput()->SetNumberOfComponentsPerPixel(m_NumberOfComponentsPerPixel);
} }
...@@ -217,7 +217,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm ...@@ -217,7 +217,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
TInputImage * inPtr = const_cast<TInputImage *>(this->GetInput()); TInputImage * inPtr = const_cast<TInputImage *>(this->GetInput());
TOutputMetricImage * outMetricPtr = this->GetMetricOutput(); TOutputMetricImage * outMetricPtr = this->GetMetricOutput();
TOutputImage * outSpatialPtr = this->GetSpatialOutput(); OutputSpatialImageType * outSpatialPtr = this->GetSpatialOutput();
TOutputImage * outRangePtr = this->GetRangeOutput(); TOutputImage * outRangePtr = this->GetRangeOutput();
OutputIterationImageType * outIterationPtr = this->GetIterationOutput(); OutputIterationImageType * outIterationPtr = this->GetIterationOutput();
...@@ -279,7 +279,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm ...@@ -279,7 +279,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
typedef itk::ImageRegionIterator<RealVectorImageType> JointImageIteratorType; typedef itk::ImageRegionIterator<RealVectorImageType> JointImageIteratorType;
TOutputMetricImage * outMetricPtr = this->GetMetricOutput(); TOutputMetricImage * outMetricPtr = this->GetMetricOutput();
TOutputImage * outSpatialPtr = this->GetSpatialOutput(); OutputSpatialImageType * outSpatialPtr = this->GetSpatialOutput();
TOutputImage * outRangePtr = this->GetRangeOutput(); TOutputImage * outRangePtr = this->GetRangeOutput();
typename InputImageType::ConstPointer inputPtr = this->GetInput(); typename InputImageType::ConstPointer inputPtr = this->GetInput();
...@@ -485,7 +485,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm ...@@ -485,7 +485,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
this->AllocateOutputs(); this->AllocateOutputs();
// Retrieve output images pointers // Retrieve output images pointers
typename OutputImageType::Pointer spatialOutput = this->GetSpatialOutput(); typename OutputSpatialImageType::Pointer spatialOutput = this->GetSpatialOutput();
typename OutputImageType::Pointer rangeOutput = this->GetRangeOutput(); typename OutputImageType::Pointer rangeOutput = this->GetRangeOutput();
typename OutputMetricImageType::Pointer metricOutput = this->GetMetricOutput(); typename OutputMetricImageType::Pointer metricOutput = this->GetMetricOutput();
typename OutputIterationImageType::Pointer iterationOutput = this->GetIterationOutput(); typename OutputIterationImageType::Pointer iterationOutput = this->GetIterationOutput();
...@@ -495,13 +495,14 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm ...@@ -495,13 +495,14 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
// defines input and output iterators // defines input and output iterators
typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType; typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType;
typedef itk::ImageRegionIterator<OutputSpatialImageType> OutputSpatialIteratorType;
typedef itk::ImageRegionIterator<OutputMetricImageType> OutputMetricIteratorType; typedef itk::ImageRegionIterator<OutputMetricImageType> OutputMetricIteratorType;
typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputIteratorWithIndexType; typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputIteratorWithIndexType;
typedef itk::ImageRegionIterator<OutputIterationImageType> OutputIterationIteratorType; typedef itk::ImageRegionIterator<OutputIterationImageType> OutputIterationIteratorType;
typename OutputImageType::PixelType rangePixel; typename OutputImageType::PixelType rangePixel;
typename OutputImageType::PixelType spatialPixel; typename OutputSpatialImageType::PixelType spatialPixel;
typename OutputMetricImageType::PixelType metricPixel; typename OutputMetricImageType::PixelType metricPixel;
typename OutputIterationImageType::PixelType iterationPixel; typename OutputIterationImageType::PixelType iterationPixel;
...@@ -520,7 +521,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm ...@@ -520,7 +521,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
JointImageIteratorType jointIt(m_JointImage, outputRegionForThread); JointImageIteratorType jointIt(m_JointImage, outputRegionForThread);
OutputIteratorType rangeIt(rangeOutput, outputRegionForThread); OutputIteratorType rangeIt(rangeOutput, outputRegionForThread);
OutputIteratorType spatialIt(spatialOutput, outputRegionForThread); OutputSpatialIteratorType spatialIt(spatialOutput, outputRegionForThread);
OutputMetricIteratorType metricIt(metricOutput, outputRegionForThread); OutputMetricIteratorType metricIt(metricOutput, outputRegionForThread);
OutputIterationIteratorType iterationIt(iterationOutput, outputRegionForThread); OutputIterationIteratorType iterationIt(iterationOutput, outputRegionForThread);
...@@ -561,14 +562,11 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm ...@@ -561,14 +562,11 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
// index of the currently processed output pixel // index of the currently processed output pixel
InputIndexType currentIndex; InputIndexType currentIndex;
if(m_ModeSearchOptimization) for (unsigned int comp = 0; comp < ImageDimension; comp++)
{ {
for (unsigned int comp = 0; comp < ImageDimension; comp++) currentIndex[comp] = jointPixel[comp] * m_SpatialBandwidth + 0.5;
{
currentIndex[comp] = jointPixel[comp] * m_SpatialBandwidth + 0.5;
}
pointCount = 0;
} }
pointCount = 0; // Note: used only in mode search optimization
iteration = 0; iteration = 0;
while ((iteration < m_MaxIterationNumber) && (!hasConverged)) while ((iteration < m_MaxIterationNumber) && (!hasConverged))
...@@ -661,7 +659,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm ...@@ -661,7 +659,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
for(unsigned int comp = 0; comp < ImageDimension; comp++) for(unsigned int comp = 0; comp < ImageDimension; comp++)
{ {
spatialPixel[comp] = jointPixel[comp] * m_SpatialBandwidth; spatialPixel[comp] = jointPixel[comp] * m_SpatialBandwidth - currentIndex[comp];
} }
for(unsigned int comp = 0; comp < ImageDimension+m_NumberOfComponentsPerPixel; comp++) for(unsigned int comp = 0; comp < ImageDimension+m_NumberOfComponentsPerPixel; comp++)
......
...@@ -44,14 +44,16 @@ int otbMeanShiftImageFilter2(int argc, char * argv[]) ...@@ -44,14 +44,16 @@ int otbMeanShiftImageFilter2(int argc, char * argv[])
/* maxit - threshold */ /* maxit - threshold */
const unsigned int Dimension = 2; const unsigned int Dimension = 2;
typedef float PixelType; typedef float PixelType;
typedef double KernelType; typedef double KernelType;
typedef otb::VectorImage<PixelType, Dimension> ImageType; typedef otb::VectorImage<PixelType, Dimension> ImageType;
typedef otb::ImageFileReader<ImageType> ReaderType; typedef otb::ImageFileReader<ImageType> ReaderType;
typedef otb::ImageFileWriter<ImageType> WriterType; typedef otb::ImageFileWriter<ImageType> WriterType;
typedef otb::MeanShiftImageFilter2<ImageType, ImageType> FilterType; typedef otb::MeanShiftImageFilter2<ImageType, ImageType> FilterType;
typedef FilterType::OutputIterationImageType IterationImageType; typedef FilterType::OutputIterationImageType IterationImageType;
typedef otb::ImageFileWriter<IterationImageType> IterationWriterType; typedef otb::ImageFileWriter<IterationImageType> IterationWriterType;
typedef FilterType::OutputSpatialImageType SpatialImageType;
typedef otb::ImageFileWriter<SpatialImageType> SpatialWriterType;
// Instantiating object // Instantiating object
FilterType::Pointer filter = FilterType::New(); FilterType::Pointer filter = FilterType::New();
...@@ -66,7 +68,7 @@ int otbMeanShiftImageFilter2(int argc, char * argv[]) ...@@ -66,7 +68,7 @@ int otbMeanShiftImageFilter2(int argc, char * argv[])
filter->SetMaxIterationNumber(maxiterationnumber); filter->SetMaxIterationNumber(maxiterationnumber);
filter->SetInput(reader->GetOutput()); filter->SetInput(reader->GetOutput());
//filter->SetNumberOfThreads(1); //filter->SetNumberOfThreads(1);
WriterType::Pointer writer1 = WriterType::New(); SpatialWriterType::Pointer writer1 = SpatialWriterType::New();
WriterType::Pointer writer2 = WriterType::New(); WriterType::Pointer writer2 = WriterType::New();
WriterType::Pointer writer3 = WriterType::New(); WriterType::Pointer writer3 = WriterType::New();
IterationWriterType::Pointer writer4 = IterationWriterType::New(); IterationWriterType::Pointer writer4 = IterationWriterType::New();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment