From 53960c19970ff59806725635c93aebfd9462f3cb Mon Sep 17 00:00:00 2001
From: Arnaud Jaen <arnaud.jaen@c-s.fr>
Date: Thu, 10 May 2012 17:56:53 +0200
Subject: [PATCH] ENH: Add a mask in otbStreamingVectorizedSegmentation. Pixel
 with values of 0 in the mask will not be suitable for vectorization.

---
 .../OBIA/otbLabelImageToOGRDataSourceFilter.h |   6 +-
 .../otbLabelImageToOGRDataSourceFilter.txx    | 109 ++++++++++++++++--
 .../otbPersistentImageToOGRDataFilter.txx     |   2 +
 .../otbStreamingVectorizedSegmentationOGR.h   |  14 +++
 .../otbStreamingVectorizedSegmentationOGR.txx |  48 ++++++--
 Testing/Code/OBIA/CMakeLists.txt              |   3 +-
 .../otbStreamingVectorizedSegmentationOGR.cxx |  23 ++--
 7 files changed, 176 insertions(+), 29 deletions(-)

diff --git a/Code/OBIA/otbLabelImageToOGRDataSourceFilter.h b/Code/OBIA/otbLabelImageToOGRDataSourceFilter.h
index 83dbae6951..20cccb2f87 100644
--- a/Code/OBIA/otbLabelImageToOGRDataSourceFilter.h
+++ b/Code/OBIA/otbLabelImageToOGRDataSourceFilter.h
@@ -24,7 +24,7 @@
 namespace otb
 {
 
-class OGRDataSourceWrapper;
+// class OGRDataSourceWrapper;
 
 /** \class LabelImageToOGRDataSourceFilter
  *  \brief this class uses GDALPolygonize method to transform a Label image into
@@ -72,6 +72,10 @@ public:
   virtual void SetInput(const InputImageType *input);
   virtual const InputImageType * GetInput(void);
   
+  /** Set/Get the input mask image. All pixels in the mask with a value of 0 will not be considered suitable for collection as polygons */
+  virtual void SetInputMask(const InputImageType *input);
+  virtual const InputImageType * GetInputMask(void);
+  
   itkSetMacro(FieldName, std::string);
   itkGetMacro(FieldName, std::string);
   
diff --git a/Code/OBIA/otbLabelImageToOGRDataSourceFilter.txx b/Code/OBIA/otbLabelImageToOGRDataSourceFilter.txx
index c1818424bd..4109ea053e 100644
--- a/Code/OBIA/otbLabelImageToOGRDataSourceFilter.txx
+++ b/Code/OBIA/otbLabelImageToOGRDataSourceFilter.txx
@@ -38,6 +38,7 @@ template <class TInputImage>
 LabelImageToOGRDataSourceFilter<TInputImage>
 ::LabelImageToOGRDataSourceFilter() : m_FieldName("DN"), m_Use8Connected(false)
 {
+   this->SetNumberOfInputs(2);
    this->SetNumberOfRequiredInputs(1);
    this->SetNumberOfRequiredOutputs(1);
    
@@ -86,6 +87,28 @@ LabelImageToOGRDataSourceFilter<TInputImage>
   return static_cast<const InputImageType *>(this->Superclass::GetInput(0));
 }
 
+template <class TInputImage>
+void
+LabelImageToOGRDataSourceFilter<TInputImage>
+::SetInputMask(const InputImageType *input)
+{
+  this->Superclass::SetNthInput(1, const_cast<InputImageType *>(input));
+}
+
+template <class TInputImage>
+const typename LabelImageToOGRDataSourceFilter<TInputImage>
+::InputImageType *
+LabelImageToOGRDataSourceFilter<TInputImage>
+::GetInputMask(void)
+{
+  if (this->GetNumberOfInputs() < 2)
+    {
+    return 0;
+    }
+
+  return static_cast<const InputImageType *>(this->Superclass::GetInput(1));
+}
+
 template <class TInputImage>
 void
 LabelImageToOGRDataSourceFilter<TInputImage>
@@ -102,10 +125,17 @@ LabelImageToOGRDataSourceFilter<TInputImage>
     {
     return;
     }
-
   // The input is necessarily the largest possible region.
-  // For a streamed implementation, use the StreamingLineSegmentDetector filter
   input->SetRequestedRegionToLargestPossibleRegion();
+  
+  typename InputImageType::Pointer mask  =
+    const_cast<InputImageType *> (this->GetInputMask());
+  if(!mask)
+  {
+   return;
+  }
+  // The input is necessarily the largest possible region.
+  mask->SetRequestedRegionToLargestPossibleRegion();
 }
 
 
@@ -118,15 +148,16 @@ LabelImageToOGRDataSourceFilter<TInputImage>
     {
     itkExceptionMacro(<< "Not streamed filter. ERROR : requested region is not the largest possible region.");
     }
-    
-    typename InputImageType::Pointer inImage = const_cast<InputImageType *>(this->GetInput());
 
-    SizeType size = this->GetInput()->GetLargestPossibleRegion().GetSize();
-    
-    unsigned int nbBands = this->GetInput()->GetNumberOfComponentsPerPixel();
-    unsigned int bytePerPixel = sizeof(InputPixelType);
+    SizeType size;
+    unsigned int nbBands = 0;
+    unsigned int bytePerPixel = 0;
 
-    /** Convert Input image into a OGRLayer using GDALPolygonize */
+    /* Convert the input image into a GDAL raster needed by GDALPolygonize */
+    size = this->GetInput()->GetLargestPossibleRegion().GetSize();
+    nbBands = this->GetInput()->GetNumberOfComponentsPerPixel();
+    bytePerPixel = sizeof(InputPixelType);
+    
     // buffer casted in unsigned long cause under Win32 the adress
     // don't begin with 0x, the adress in not interpreted as
     // hexadecimal but alpha numeric value, then the conversion to
@@ -192,8 +223,64 @@ LabelImageToOGRDataSourceFilter<TInputImage>
       options=option;
     }
     
-    GDALPolygonize(dataset->GetRasterBand(1), NULL, &outputLayer.ogr(), 0, options, NULL, NULL);
-    
+    /* Convert the mask input into a GDAL raster needed by GDALPolygonize */
+    typename InputImageType::ConstPointer inputMask = this->GetInputMask();
+    if (!inputMask.IsNull())
+    {
+      size = this->GetInputMask()->GetLargestPossibleRegion().GetSize();
+      nbBands = this->GetInputMask()->GetNumberOfComponentsPerPixel();
+      bytePerPixel = sizeof(InputPixelType);
+      // buffer casted in unsigned long cause under Win32 the adress
+      // don't begin with 0x, the adress in not interpreted as
+      // hexadecimal but alpha numeric value, then the conversion to
+      // integer make us pointing to an non allowed memory block => Crash.
+      std::ostringstream maskstream;
+      maskstream << "MEM:::"
+            <<  "DATAPOINTER=" << (unsigned long)(this->GetInputMask()->GetBufferPointer()) << ","
+            <<  "PIXELS=" << size[0] << ","
+            <<  "LINES=" << size[1] << ","
+            <<  "BANDS=" << nbBands << ","
+            <<  "DATATYPE=" << GDALGetDataTypeName(GdalDataTypeBridge::GetGDALDataType<InputPixelType>()) << ","
+            <<  "PIXELOFFSET=" << bytePerPixel * nbBands << ","
+            <<  "LINEOFFSET=" << bytePerPixel * nbBands * size[0] << ","
+            <<  "BANDOFFSET=" << bytePerPixel;
+      
+      GDALDataset * maskDataset = static_cast<GDALDataset *> (GDALOpen(maskstream.str().c_str(), GA_ReadOnly));
+      
+      //Set input Projection ref and Geo transform to the dataset.
+      maskDataset->SetProjection(this->GetInputMask()->GetProjectionRef().c_str());
+      
+      projSize = this->GetInputMask()->GetGeoTransform().size();
+      
+      //Set the geo transform of the input mask image (if any)
+      // Reporting origin and spacing of the buffered region
+      // the spacing is unchanged, the origin is relative to the buffered region
+      bufferIndexOrigin = this->GetInputMask()->GetBufferedRegion().GetIndex();
+      this->GetInputMask()->TransformIndexToPhysicalPoint(bufferIndexOrigin, bufferOrigin);
+      geoTransform[0] = bufferOrigin[0];
+      geoTransform[3] = bufferOrigin[1];
+      geoTransform[1] = this->GetInput()->GetSpacing()[0];
+      geoTransform[5] = this->GetInput()->GetSpacing()[1];
+      // FIXME: Here component 1 and 4 should be replaced by the orientation parameters
+      if (projSize == 0)
+      {
+         geoTransform[2] = 0.;
+         geoTransform[4] = 0.;
+      }
+      else
+      {
+         geoTransform[2] = this->GetInput()->GetGeoTransform()[2];
+         geoTransform[4] = this->GetInput()->GetGeoTransform()[4];
+      }
+      maskDataset->SetGeoTransform(geoTransform);
+      
+      GDALPolygonize(dataset->GetRasterBand(1), maskDataset->GetRasterBand(1), &outputLayer.ogr(), 0, options, NULL, NULL);
+      GDALClose(maskDataset);
+    }
+    else
+    {
+      GDALPolygonize(dataset->GetRasterBand(1), NULL, &outputLayer.ogr(), 0, options, NULL, NULL);
+    }
     
     this->SetNthOutput(0,ogrDS);
     
diff --git a/Code/OBIA/otbPersistentImageToOGRDataFilter.txx b/Code/OBIA/otbPersistentImageToOGRDataFilter.txx
index ccaf2b7435..5abec263dd 100644
--- a/Code/OBIA/otbPersistentImageToOGRDataFilter.txx
+++ b/Code/OBIA/otbPersistentImageToOGRDataFilter.txx
@@ -33,6 +33,8 @@ template<class TImage>
 PersistentImageToOGRDataFilter<TImage>
 ::PersistentImageToOGRDataFilter() : m_FieldName("DN"), m_LayerName("Layer"), m_GeometryType(wkbMultiPolygon)
 {
+   this->SetNumberOfInputs(2);
+   this->SetNumberOfRequiredInputs(2);
    m_StreamSize.Fill(0);
 }
 
diff --git a/Code/OBIA/otbStreamingVectorizedSegmentationOGR.h b/Code/OBIA/otbStreamingVectorizedSegmentationOGR.h
index 36bb5a81fd..a683793f77 100644
--- a/Code/OBIA/otbStreamingVectorizedSegmentationOGR.h
+++ b/Code/OBIA/otbStreamingVectorizedSegmentationOGR.h
@@ -133,6 +133,9 @@ public:
   
   itkSetMacro(Use8Connected, bool);
   itkGetMacro(Use8Connected, bool);
+  
+  virtual void SetInputMask(const LabelImageType *mask);
+  virtual const LabelImageType * GetInputMask(void);
 
 protected:
   PersistentStreamingLabelImageToOGRDataFilter();
@@ -181,6 +184,8 @@ public:
   typedef TImageType                               InputImageType;
   typedef typename PersistentStreamingLabelImageToOGRDataFilter<TImageType,
                TSegmentationFilter>::LabelPixelType                           LabelPixelType;
+  typedef typename PersistentStreamingLabelImageToOGRDataFilter<TImageType,
+               TSegmentationFilter>::LabelImageType                           LabelImageType;
   typedef typename PersistentStreamingLabelImageToOGRDataFilter<TImageType,
                TSegmentationFilter>::OGRDataSourcePointerType                 OGRDataSourcePointerType;
 
@@ -193,6 +198,15 @@ public:
     return this->GetFilter()->GetInput();
   }
   
+  void SetInputMask(LabelImageType * mask)
+  {
+    this->GetFilter()->SetInputMask(mask);
+  }
+  const LabelImageType * GetInputMask()
+  {
+    return this->GetFilter()->GetInputMask();
+  }
+  
   void SetOGRDataSource( OGRDataSourcePointerType ogrDS )
   {
     this->GetFilter()->SetOGRDataSource(ogrDS);
diff --git a/Code/OBIA/otbStreamingVectorizedSegmentationOGR.txx b/Code/OBIA/otbStreamingVectorizedSegmentationOGR.txx
index 819bd26a6e..a37796b99e 100644
--- a/Code/OBIA/otbStreamingVectorizedSegmentationOGR.txx
+++ b/Code/OBIA/otbStreamingVectorizedSegmentationOGR.txx
@@ -35,6 +35,8 @@ template <class TImageType, class TSegmentationFilter>
 PersistentStreamingLabelImageToOGRDataFilter<TImageType, TSegmentationFilter>
 ::PersistentStreamingLabelImageToOGRDataFilter() : m_TileMaxLabel(0), m_StartLabel(0), m_Use8Connected(false)
 {
+   this->SetNumberOfInputs(3);
+   this->SetNumberOfRequiredInputs(2);
    m_SegmentationFilter = SegmentationFilterType::New();
    m_TileNumber = 1;
 }
@@ -45,6 +47,28 @@ PersistentStreamingLabelImageToOGRDataFilter<TImageType, TSegmentationFilter>
 {
 }
 
+template <class TImageType, class TSegmentationFilter>
+void
+PersistentStreamingLabelImageToOGRDataFilter<TImageType, TSegmentationFilter>
+::SetInputMask(const LabelImageType *mask)
+{
+  this->itk::ProcessObject::SetNthInput(2, const_cast<LabelImageType *>(mask));
+}
+
+template <class TImageType, class TSegmentationFilter>
+const typename PersistentStreamingLabelImageToOGRDataFilter<TImageType, TSegmentationFilter>
+::LabelImageType *
+PersistentStreamingLabelImageToOGRDataFilter<TImageType, TSegmentationFilter>
+::GetInputMask(void)
+{
+  if (this->GetNumberOfInputs() < 3)
+    {
+    return 0;
+    }
+
+  return static_cast<const LabelImageType *>(this->itk::ProcessObject::GetInput(2));
+}
+
 template <class TImageType, class TSegmentationFilter>
 void
 PersistentStreamingLabelImageToOGRDataFilter<TImageType, TSegmentationFilter>
@@ -63,18 +87,12 @@ PersistentStreamingLabelImageToOGRDataFilter<TImageType, TSegmentationFilter>
   itk::TimeProbe tileChrono;
   tileChrono.Start();
   
-  
-  itk::TimeProbe chrono;
-  chrono.Start();
   // Apply an ExtractImageFilter to avoid problems with filters asking for the LargestPossibleRegion
   typedef itk::ExtractImageFilter<InputImageType, InputImageType> ExtractImageFilterType;
   typename ExtractImageFilterType::Pointer extract = ExtractImageFilterType::New();
   extract->SetInput( this->GetInput() );
   extract->SetExtractionRegion( this->GetInput()->GetRequestedRegion() );
   extract->Update();
-  
-  chrono.Stop();
-  //std::cout<< "extract took " << chrono.GetTotal() << " sec"<<std::endl;
 
   // WARNING: itk::ExtractImageFilter does not copy the MetadataDictionnary
   extract->GetOutput()->SetMetaDataDictionary(this->GetInput()->GetMetaDataDictionary());
@@ -84,7 +102,6 @@ PersistentStreamingLabelImageToOGRDataFilter<TImageType, TSegmentationFilter>
   typename LabelImageToOGRDataSourceFilterType::Pointer labelImageToOGRDataFilter =
                                               LabelImageToOGRDataSourceFilterType::New();
   
-  
   itk::TimeProbe chrono1;
   chrono1.Start();
   m_SegmentationFilter->SetInput(extract->GetOutput());
@@ -97,7 +114,22 @@ PersistentStreamingLabelImageToOGRDataFilter<TImageType, TSegmentationFilter>
   
   itk::TimeProbe chrono2;
   chrono2.Start();
-
+  
+  typename LabelImageType::ConstPointer inputMask = this->GetInputMask();
+  if (!inputMask.IsNull())
+  {
+     // Apply an ExtractImageFilter to avoid problems with filters asking for the LargestPossibleRegion
+     typedef itk::ExtractImageFilter<LabelImageType, LabelImageType> ExtractLabelImageFilterType;
+     typename ExtractLabelImageFilterType::Pointer maskExtract = ExtractLabelImageFilterType::New();
+     maskExtract->SetInput( this->GetInputMask() );
+     maskExtract->SetExtractionRegion( this->GetInputMask()->GetRequestedRegion() );
+     maskExtract->Update();
+      
+     // WARNING: itk::ExtractImageFilter does not copy the MetadataDictionnary
+     maskExtract->GetOutput()->SetMetaDataDictionary(this->GetInputMask()->GetMetaDataDictionary());
+     
+     labelImageToOGRDataFilter->SetInputMask(maskExtract->GetOutput());
+  }
   labelImageToOGRDataFilter->SetInput(dynamic_cast<LabelImageType *>(m_SegmentationFilter->GetOutputs().at(labelImageIndex).GetPointer()));
   labelImageToOGRDataFilter->SetFieldName(this->GetFieldName());
   labelImageToOGRDataFilter->SetUse8Connected(m_Use8Connected);
diff --git a/Testing/Code/OBIA/CMakeLists.txt b/Testing/Code/OBIA/CMakeLists.txt
index 0875354b42..8507eb8667 100644
--- a/Testing/Code/OBIA/CMakeLists.txt
+++ b/Testing/Code/OBIA/CMakeLists.txt
@@ -251,7 +251,8 @@ ADD_TEST(obTvStreamingVectorizedSegmentationOGR ${OBIA_TESTS1}
      #${TEMP}/obTvStreamingVectorizedSegmentationOutput.sqlite
      otbStreamingVectorizedSegmentationOGR
      ${INPUTDATA}/QB_Toulouse_Ortho_PAN.tif
-     ${TEMP}/obTvStreamingVectorizedSegmentationOGR.shp
+     ${INPUTDATA}/QB_Toulouse_Ortho_PAN_Mask.tif
+     ${TEMP}/obTvStreamingVectorizedSegmentationOGR.sqlite
      NewLayer
      100
      5
diff --git a/Testing/Code/OBIA/otbStreamingVectorizedSegmentationOGR.cxx b/Testing/Code/OBIA/otbStreamingVectorizedSegmentationOGR.cxx
index 52542f7f48..2f5aebcf54 100644
--- a/Testing/Code/OBIA/otbStreamingVectorizedSegmentationOGR.cxx
+++ b/Testing/Code/OBIA/otbStreamingVectorizedSegmentationOGR.cxx
@@ -49,20 +49,21 @@ int otbStreamingVectorizedSegmentationOGRNew(int argc, char * argv[])
 int otbStreamingVectorizedSegmentationOGR(int argc, char * argv[])
 {
 
-  if (argc != 8)
+  if (argc != 9)
     {
       std::cerr << "Usage: " << argv[0];
-      std::cerr << " inputImage outputVec layerName TileDimension spatialRadius rangeRadius minObjectSize" << std::endl;
+      std::cerr << " inputImage maskImage outputVec layerName TileDimension spatialRadius rangeRadius minObjectSize" << std::endl;
       return EXIT_FAILURE;
     }
 
   const char * imageName                    = argv[1];
-  const char * dataSourceName               = argv[2];
-  const char * layerName                    = argv[3];
-  const unsigned int tileSize               = atoi(argv[4]);
-  const unsigned int spatialRadiusOldMS     = atoi(argv[5]);
-  const double rangeRadiusOldMS             = atof(argv[6]);
-  const unsigned int minimumObjectSizeOldMS = atoi(argv[7]);
+  const char * maskName                     = argv[2];
+  const char * dataSourceName               = argv[3];
+  const char * layerName                    = argv[4];
+  const unsigned int tileSize               = atoi(argv[5]);
+  const unsigned int spatialRadiusOldMS     = atoi(argv[6]);
+  const double rangeRadiusOldMS             = atof(argv[7]);
+  const unsigned int minimumObjectSizeOldMS = atoi(argv[8]);
 
 
   typedef float InputPixelType;
@@ -77,16 +78,22 @@ int otbStreamingVectorizedSegmentationOGR(int argc, char * argv[])
   typedef otb::MeanShiftVectorImageFilter<ImageType, ImageType, LabelImageType> SegmentationFilterType;
   typedef otb::StreamingVectorizedSegmentationOGR<ImageType, SegmentationFilterType> StreamingVectorizedSegmentationOGRType;
   typedef otb::ImageFileReader<ImageType>                      ReaderType;
+  typedef otb::ImageFileReader<LabelImageType>                 MaskReaderType;
 
   ReaderType::Pointer             reader = ReaderType::New();
+  MaskReaderType::Pointer         maskReader = MaskReaderType::New();
   StreamingVectorizedSegmentationOGRType::Pointer filter = StreamingVectorizedSegmentationOGRType::New();  
 
   reader->SetFileName(imageName);
   reader->UpdateOutputInformation();
   
+  maskReader->SetFileName(maskName);
+  maskReader->UpdateOutputInformation();
+  
   otb::ogr::DataSource::Pointer ogrDS = otb::ogr::DataSource::New(dataSourceName, otb::ogr::DataSource::Modes::write);
 
   filter->SetInput(reader->GetOutput());
+  filter->SetInputMask(maskReader->GetOutput());
   filter->SetOGRDataSource(ogrDS);
   //filter->GetStreamer()->SetNumberOfLinesStrippedStreaming(atoi(argv[3]));
   filter->GetStreamer()->SetTileDimensionTiledStreaming(tileSize);
-- 
GitLab