Commit 7dd1809b authored by Sebastien Harasse's avatar Sebastien Harasse

ENH: Mean shift. output image of region labels (Work in progress)

parent 91b293b9
......@@ -193,6 +193,9 @@ public:
typedef TOutputIterationImage OutputIterationImageType;
typedef unsigned long LabelType;
typedef otb::Image<LabelType, InputImageType::ImageDimension> OutputLabelImageType;
typedef otb::VectorImage<RealType, InputImageType::ImageDimension> OutputSpatialImageType;
typedef typename OutputSpatialImageType::Pointer OutputSpatialImagePointerType;
typedef typename OutputSpatialImageType::PixelType OutputSpatialPixelType;
......@@ -228,6 +231,9 @@ public:
const OutputMetricImageType * GetMetricOutput() const;
/** Returns the number of iterations done at each pixel */
const OutputIterationImageType * GetIterationOutput() const;
/** Returns the image of region labels */
const OutputLabelImageType * GetLabelOutput() const;
/** Returns the const spatial image output */
OutputSpatialImageType * GetSpatialOutput();
......@@ -237,6 +243,8 @@ public:
OutputMetricImageType * GetMetricOutput();
/** Returns the number of iterations done at each pixel */
OutputIterationImageType * GetIterationOutput();
/** Returns the image of region labels */
OutputLabelImageType * GetLabelOutput();
protected:
......
......@@ -40,11 +40,12 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
m_Threshold=1e-3;
m_ModeSearchOptimization = true;
this->SetNumberOfOutputs(4);
this->SetNumberOfOutputs(5);
this->SetNthOutput(0, OutputSpatialImageType::New());
this->SetNthOutput(1, OutputImageType::New());
this->SetNthOutput(2, OutputMetricImageType::New());
this->SetNthOutput(3, OutputIterationImageType::New());
this->SetNthOutput(4, OutputLabelImageType::New());
}
......@@ -154,6 +155,30 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
return static_cast<OutputIterationImageType *>(this->itk::ProcessObject::GetOutput(3));
}
template <class TInputImage, class TOutputImage, class TKernel, class TNorm, class TOutputMetricImage, class TOutputIterationImage>
typename MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>::OutputLabelImageType *
MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>
::GetLabelOutput()
{
if (this->GetNumberOfOutputs() < 5)
{
return 0;
}
return static_cast<OutputLabelImageType *>(this->itk::ProcessObject::GetOutput(4));
}
template <class TInputImage, class TOutputImage, class TKernel, class TNorm, class TOutputMetricImage, class TOutputIterationImage>
const typename MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>::OutputLabelImageType *
MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricImage, TOutputIterationImage>
::GetLabelOutput() const
{
if (this->GetNumberOfOutputs() < 5)
{
return 0;
}
return static_cast<OutputLabelImageType *>(this->itk::ProcessObject::GetOutput(4));
}
template <class TInputImage, class TOutputImage, class TKernel, class TNorm, class TOutputMetricImage, class TOutputIterationImage>
void
......@@ -165,6 +190,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
typename OutputImageType::Pointer rangeOutputPtr = this->GetRangeOutput();
typename OutputMetricImageType::Pointer metricOutputPtr = this->GetMetricOutput();
typename OutputIterationImageType::Pointer iterationOutputPtr = this->GetIterationOutput();
typename OutputLabelImageType::Pointer labelOutputPtr = this->GetLabelOutput();
metricOutputPtr->SetBufferedRegion(metricOutputPtr->GetRequestedRegion());
metricOutputPtr->Allocate();
......@@ -178,6 +204,8 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
iterationOutputPtr->SetBufferedRegion(iterationOutputPtr->GetRequestedRegion());
iterationOutputPtr->Allocate();
labelOutputPtr->SetBufferedRegion(labelOutputPtr->GetRequestedRegion());
labelOutputPtr->Allocate();
}
......@@ -483,6 +511,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
typename OutputImageType::Pointer rangeOutput = this->GetRangeOutput();
typename OutputMetricImageType::Pointer metricOutput = this->GetMetricOutput();
typename OutputIterationImageType::Pointer iterationOutput = this->GetIterationOutput();
typename OutputLabelImageType::Pointer labelOutput = this->GetLabelOutput();
// Get input image pointer
typename InputImageType::ConstPointer input = this->GetInput();
......@@ -492,6 +521,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
typedef itk::ImageRegionIterator<OutputSpatialImageType> OutputSpatialIteratorType;
typedef itk::ImageRegionIterator<OutputMetricImageType> OutputMetricIteratorType;
typedef itk::ImageRegionIterator<OutputIterationImageType> OutputIterationIteratorType;
typedef itk::ImageRegionIterator<OutputLabelImageType> OutputLabelIteratorType;
unsigned int jointDimension = ImageDimension + m_NumberOfComponentsPerPixel;
......@@ -532,6 +562,7 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
OutputSpatialIteratorType spatialIt(spatialOutput, outputRegionForThread);
OutputMetricIteratorType metricIt(metricOutput, outputRegionForThread);
OutputIterationIteratorType iterationIt(iterationOutput, outputRegionForThread);
OutputLabelIteratorType labelIt(labelOutput, outputRegionForThread);
typedef itk::ImageRegionIterator<ModeTableImageType> ModeTableImageIteratorType;
ModeTableImageIteratorType modeTableIt(m_ModeTable, outputRegionForThread);
......@@ -559,16 +590,24 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
// Number of times an already processed candidate pixel is encountered, resulting in no
// further computation (Used for statistics only)
unsigned int numBreaks = 0;
// index of the current pixel updated during the mean shift loop
InputIndexType modeCandidate;
// Number of labels already assigned. This is also used to find the next unused label
LabelType numLabels = 0;
for (; !jointIt.IsAtEnd();
++jointIt, ++rangeIt, ++spatialIt, ++metricIt, ++iterationIt, ++modeTableIt, progress.CompletedPixel())
++jointIt, ++rangeIt, ++spatialIt, ++metricIt, ++iterationIt,
++modeTableIt, ++labelIt, progress.CompletedPixel())
{
// if pixel has been already processed (by mode search optimization), skip
typename ModeTableImageType::InternalPixelType currentPixelMode;
currentPixelMode = modeTableIt.Get();
if(m_ModeSearchOptimization && currentPixelMode == 1)
{
numBreaks++;
continue;
}
bool hasConverged = false;
......@@ -587,8 +626,6 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
if (m_ModeSearchOptimization)
{
// index of the current pixel updated during the mean shift loop
InputIndexType modeCandidate;
// Find index of the pixel closest to the current jointPixel (not normalized by bandwidth)
for (unsigned int comp = 0; comp < ImageDimension; comp++)
{
......@@ -690,11 +727,24 @@ MeanShiftImageFilter2<TInputImage, TOutputImage, TKernel, TNorm, TOutputMetricIm
// Update the mode table now that the current pixel has been assigned
modeTableIt.Set(1); // m_ModeTable->SetPixel(currentIndex, 1);
// If the loop exited with hasConverged, then we have a new mode
LabelType label;
if (hasConverged)
{
numLabels++;
label = numLabels;
} else // the loop exited through a break. Use the already assigned mode label
{
label = labelOutput->GetPixel(modeCandidate);
}
labelIt.Set(label);
// Also assign all points in the list to the same mode
for (unsigned int i = 0; i < pointCount; i++)
{
rangeOutput->SetPixel(pointList[i], rangePixel);
m_ModeTable->SetPixel(pointList[i], 1);
labelOutput->SetPixel(pointList[i], label);
}
}
......
......@@ -1141,6 +1141,7 @@ ADD_TEST(bfTvMeanShiftImageFilter2 ${BASICFILTERS_TESTS9}
${TEMP}/bfMeanShiftImageFilterSpectralOutput.tif
${TEMP}/bfMeanShiftImageFilterMetricOutput.tif
${TEMP}/bfMeanShiftImageFilterIterationOutput.tif
${TEMP}/bfMeanShiftImageFilterLabelOutput.tif
4 50 0.1 100
)
......@@ -1162,6 +1163,7 @@ ADD_TEST(bfTvMeanShiftImageFilter2Mul ${BASICFILTERS_TESTS9}
${TEMP}/bfMeanShiftImageFilterSpectralOutputMul.tif
${TEMP}/bfMeanShiftImageFilterMetricOutputMul.tif
${TEMP}/bfMeanShiftImageFilterIterationOutputMul.tif
${TEMP}/bfMeanShiftImageFilterLabelOutputMul.tif
4 50 0.1 100
)
......@@ -2652,8 +2654,8 @@ ADD_TEST(bfTvTileImageFilter ${BASICFILTERS_TESTS15}
2 2
${TEMP}/bfTvTileImageFilterOutput.tif
${INPUTDATA}/ROI_QB_MUL_4.tif
${INPUTDATA}/ROI_QB_MUL_4.tif
${INPUTDATA}/ROI_QB_MUL_4.tif
${INPUTDATA}/ROI_QB_MUL_4.tif
${INPUTDATA}/ROI_QB_MUL_4.tif
${INPUTDATA}/ROI_QB_MUL_4.tif
)
......
......@@ -24,10 +24,10 @@
int otbMeanShiftImageFilter2(int argc, char * argv[])
{
if (argc != 10)
if (argc != 11)
{
std::cerr << "Usage: " << argv[0] <<
" infname spatialfname spectralfname metricfname iterationfname spatialBandwidth rangeBandwidth threshold maxiterationnumber"
" infname spatialfname spectralfname metricfname iterationfname labelfname spatialBandwidth rangeBandwidth threshold maxiterationnumber"
<< std::endl;
return EXIT_FAILURE;
}
......@@ -37,10 +37,11 @@ int otbMeanShiftImageFilter2(int argc, char * argv[])
const char * spectralfname = argv[3];
const char * metricfname = argv[4];
const char * iterationfname = argv[5];
const double spatialBandwidth = atof(argv[6]);
const double rangeBandwidth = atof(argv[7]);
const double threshold = atof(argv[8]);
const unsigned int maxiterationnumber = atoi(argv[9]);
const char * labelfname = argv[6];
const double spatialBandwidth = atof(argv[7]);
const double rangeBandwidth = atof(argv[8]);
const double threshold = atof(argv[9]);
const unsigned int maxiterationnumber = atoi(argv[10]);
/* maxit - threshold */
const unsigned int Dimension = 2;
......@@ -54,6 +55,8 @@ int otbMeanShiftImageFilter2(int argc, char * argv[])
typedef otb::ImageFileWriter<IterationImageType> IterationWriterType;
typedef FilterType::OutputSpatialImageType SpatialImageType;
typedef otb::ImageFileWriter<SpatialImageType> SpatialWriterType;
typedef FilterType::OutputLabelImageType LabelImageType;
typedef otb::ImageFileWriter<LabelImageType> LabelWriterType;
// Instantiating object
FilterType::Pointer filter = FilterType::New();
......@@ -72,22 +75,25 @@ int otbMeanShiftImageFilter2(int argc, char * argv[])
WriterType::Pointer writer2 = WriterType::New();
WriterType::Pointer writer3 = WriterType::New();
IterationWriterType::Pointer writer4 = IterationWriterType::New();
LabelWriterType::Pointer writer5 = LabelWriterType::New();
writer1->SetFileName(spatialfname);
writer2->SetFileName(spectralfname);
writer3->SetFileName(metricfname);
writer4->SetFileName(iterationfname);
writer5->SetFileName(labelfname);
writer1->SetInput(filter->GetSpatialOutput());
writer2->SetInput(filter->GetRangeOutput());
writer3->SetInput(filter->GetMetricOutput());
writer4->SetInput(filter->GetIterationOutput());
writer5->SetInput(filter->GetLabelOutput());
writer1->Update();
writer2->Update();
writer3->Update();
writer4->Update();
writer5->Update();
return EXIT_SUCCESS;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment