Commit b6eab4d3 authored by Guillaume Pasero's avatar Guillaume Pasero

ENH: factorize code using the base sampling class

parent 13f04772
......@@ -18,9 +18,10 @@
#ifndef __otbImageSampleExtractorFilter_h
#define __otbImageSampleExtractorFilter_h
#include "otbPersistentImageToOGRDataFilter.h"
#include "otbPersistentSamplingFilterBase.h"
#include "otbPersistentFilterStreamingDecorator.h"
#include "otbOGRDataSourceWrapper.h"
#include "otbImage.h"
namespace otb
{
......@@ -34,12 +35,12 @@ namespace otb
*/
template<class TInputImage>
class ITK_EXPORT PersistentImageSampleExtractorFilter :
public PersistentImageToOGRDataFilter<TInputImage>
public PersistentSamplingFilterBase<TInputImage>
{
public:
/** Standard Self typedef */
typedef PersistentImageSampleExtractorFilter Self;
typedef PersistentImageToOGRDataFilter<TInputImage> Superclass;
typedef PersistentSamplingFilterBase<TInputImage> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
......@@ -51,8 +52,8 @@ public:
typedef typename InputImageType::PixelType PixelType;
typedef typename InputImageType::InternalPixelType InternalPixelType;
typedef typename Superclass::OGRDataSourceType OGRDataSourceType;
typedef typename Superclass::OGRDataSourcePointerType OGRDataSourcePointerType;
typedef ogr::DataSource OGRDataType;
typedef ogr::DataSource::Pointer OGRDataPointer;
typedef itk::DataObject::DataObjectPointerArraySizeType DataObjectPointerArraySizeType;
......@@ -60,11 +61,15 @@ public:
itkNewMacro(Self);
/** Runtime information support. */
itkTypeMacro(PersistentImageSampleExtractorFilter, PersistentImageToOGRDataFilter);
itkTypeMacro(PersistentImageSampleExtractorFilter, PersistentSamplingFilterBase);
void SetSamplePositions(const otb::ogr::DataSource* vector);
const otb::ogr::DataSource* GetSamplePositions();
/** Set the output samples OGR container
* (shall be equal to the input container for an 'update' mode) */
void SetOutputSamples(ogr::DataSource* data);
/** Get the output samples OGR container */
ogr::DataSource* GetOutputSamples();
virtual void Synthetize(void);
/** Reset method called before starting the streaming*/
......@@ -72,9 +77,6 @@ public:
itkSetMacro(SampleFieldPrefix, std::string);
itkGetMacro(SampleFieldPrefix, std::string);
itkSetMacro(LayerIndex, int);
itkGetMacro(LayerIndex, int);
protected:
/** Constructor */
......@@ -82,21 +84,18 @@ protected:
/** Destructor */
virtual ~PersistentImageSampleExtractorFilter() {}
virtual void GenerateInputRequestedRegion();
/** process only points */
virtual void ThreadedGenerateData(const RegionType&, itk::ThreadIdType threadid);
private:
PersistentImageSampleExtractorFilter(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
virtual OGRDataSourcePointerType ProcessTile();
/** Apply a spatial filtering on the OGRDataSource corresponding to the processed tile */
void ApplyPolygonsSpatialFilter();
/** Initialize fields to store extracted values (Real type) */
void InitializeFields(ogr::Layer &layer, unsigned int size);
void InitializeFields();
/** Layer to use in the shape file, default to 0 */
int m_LayerIndex;
/** Prefix to generate field names for each input channel */
std::string m_SampleFieldPrefix;
};
......@@ -142,8 +141,8 @@ public:
void SetSamplePositions(const otb::ogr::DataSource* data);
const otb::ogr::DataSource* GetSamplePositions();
void SetOutputOGRData(OGRDataType::Pointer data);
const otb::ogr::DataSource* GetOutputOGRData();
void SetOutputSamples(OGRDataType::Pointer data);
const otb::ogr::DataSource* GetOutputSamples();
void SetOutputFieldPrefix(const std::string &key);
std::string GetOutputFieldPrefix();
......
......@@ -19,6 +19,7 @@
#define __otbImageSampleExtractorFilter_txx
#include "itkDefaultConvertPixelTraits.h"
#include "itkProgressReporter.h"
namespace otb
{
......@@ -27,34 +28,30 @@ namespace otb
template<class TInputImage>
PersistentImageSampleExtractorFilter<TInputImage>
::PersistentImageSampleExtractorFilter() :
m_LayerIndex(0),
m_SampleFieldPrefix(std::string("band_"))
{
this->SetNumberOfRequiredInputs(3);
this->SetNumberOfRequiredOutputs(1);
this->SetNumberOfRequiredOutputs(2);
this->SetNthOutput(0,TInputImage::New());
this->SetGeometryType(wkbPoint);
}
template<class TInputImage>
void
PersistentImageSampleExtractorFilter<TInputImage>
::SetSamplePositions(const otb::ogr::DataSource* vector)
::SetOutputSamples(ogr::DataSource* data)
{
this->SetNthInput(2, const_cast<otb::ogr::DataSource *>( vector ));
this->SetNthOutput(1,data);
}
template<class TInputImage>
const otb::ogr::DataSource*
ogr::DataSource*
PersistentImageSampleExtractorFilter<TInputImage>
::GetSamplePositions()
::GetOutputSamples()
{
if (this->GetNumberOfInputs()<3)
if (this->GetNumberOfOutputs() < 2)
{
return 0;
}
return static_cast<const ogr::DataSource *>(this->itk::ProcessObject::GetInput(2));
return static_cast<ogr::DataSource *>(this->itk::ProcessObject::GetOutput(1));
}
template<class TInputImage>
......@@ -62,8 +59,8 @@ void
PersistentImageSampleExtractorFilter<TInputImage>
::Synthetize(void)
{
ogr::DataSource* vectors = const_cast<ogr::DataSource*>(this->GetSamplePositions());
vectors->GetLayer(m_LayerIndex).SetSpatialFilter(NULL);
// clear temporary outputs
this->m_InMemoryOutputs.clear();
}
template<class TInputImage>
......@@ -71,92 +68,54 @@ void
PersistentImageSampleExtractorFilter<TInputImage>
::Reset(void)
{
// initialize output DataSource if copy mode
const ogr::DataSource* inVectors = this->GetSamplePositions();
OGRFeatureDefn &inLayerDefn = inVectors->GetLayer(m_LayerIndex).GetLayerDefn();
ogr::DataSource* outVectors = const_cast<ogr::DataSource*>(this->GetOGRDataSource());
bool updateMode = bool(inVectors == outVectors);
ogr::Layer outLayer = inVectors->GetLayer(m_LayerIndex); // has to be initialized with a real layer
if (updateMode)
{
outLayer = outVectors->GetLayer(m_LayerIndex);
if (!outLayer.ogr().TestCapability(OLCRandomWrite))
{
itkExceptionMacro(<< "Output layer doesn't support OLCRandomWrite.");
}
}
else
{
// use the same field type for class label
int inCFieldIndex = inLayerDefn.GetFieldIndex(this->GetFieldName().c_str());
this->SetFieldType(inLayerDefn.GetFieldDefn(inCFieldIndex)->GetType());
// Create layer
this->Initialize();
// Get created layer. Handle the case of shapefile, which is a layer and not a datasource.
//The layer name in a shapefile is the shapefile's name.
//This is not the case for a database as sqlite or PG.
if (outVectors->GetLayersCount() == 1)
{
outLayer = outVectors->GetLayer(0);
}
else
{
outLayer = outVectors->GetLayer(this->GetLayerName());
}
}
// initialize additional fields for output
this->InitializeFields();
// initialize fields in output DataSource
TInputImage* inputImage = const_cast<TInputImage*>(this->GetInput());
inputImage->UpdateOutputInformation();
unsigned int nbBand = inputImage->GetNumberOfComponentsPerPixel();
this->InitializeFields(outLayer,nbBand);
// initialize output DataSource
ogr::DataSource* inputDS = const_cast<ogr::DataSource*>(this->GetOGRData());
ogr::DataSource* output = this->GetOutputSamples();
this->InitializeOutputDataSource(inputDS,output);
}
template<class TInputImage>
void
PersistentImageSampleExtractorFilter<TInputImage>
::GenerateInputRequestedRegion()
{
InputImageType *input = const_cast<InputImageType*>(this->GetInput());
RegionType requested = this->GetOutput()->GetRequestedRegion();
input->SetRequestedRegion(requested);
}
template<class TInputImage>
typename PersistentImageSampleExtractorFilter<TInputImage>::OGRDataSourcePointerType
void
PersistentImageSampleExtractorFilter<TInputImage>
::ProcessTile()
::ThreadedGenerateData(const RegionType&, itk::ThreadIdType threadid)
{
// Retrieve inputs
TInputImage* inputImage = const_cast<TInputImage*>(this->GetInput());
TInputImage* outputImage = this->GetOutput();
ogr::DataSource* inVectors = const_cast<ogr::DataSource*>(this->GetSamplePositions());
ogr::DataSource* outVectors = const_cast<ogr::DataSource*>(this->GetOGRDataSource());
bool updateMode = bool(inVectors == outVectors);
unsigned int nbBand = inputImage->GetNumberOfComponentsPerPixel();
ogr::Layer layer = this->m_InMemoryInputs[threadid]->GetLayerChecked(0);
if (! layer)
{
return;
}
ogr::Layer outputLayer = this->m_InMemoryOutputs[threadid][0]->GetLayerChecked(0);
itk::ProgressReporter progress( this, threadid, layer.GetFeatureCount(true) );
// Loop across the features in the layer (filtered by requested region in BeforeTGD already)
OGRGeometry *geom;
PointType imgPoint;
IndexType imgIndex;
PixelType imgPixel;
double imgComp;
RegionType requestedRegion = outputImage->GetRequestedRegion();
unsigned int nbBand = inputImage->GetNumberOfComponentsPerPixel();
std::ostringstream oss;
std::string fieldName;
ogr::Layer inLayer = inVectors->GetLayer(m_LayerIndex);
// Apply spatial filter on input sample positions
this->ApplyPolygonsSpatialFilter();
float featCount = static_cast<float>(inLayer.GetFeatureCount(true));
if (featCount == 0.0) featCount=1.0;
int currentCount = 0;
// Prepare temporary output data source
OGRDataSourcePointerType tmpDS = ogr::DataSource::New();
OGRSpatialReference * oSRS = NULL;
if (inLayer.GetSpatialRef()) oSRS = inLayer.GetSpatialRef()->Clone();
ogr::Layer dstLayer = tmpDS->CreateLayer(
this->GetLayerName(),
oSRS,
this->GetGeometryType(),
this->GetOGRLayerCreationOptions());
OGRFieldDefn labelField(this->GetFieldName().c_str(),this->GetFieldType());
dstLayer.CreateField(labelField, true);
this->InitializeFields(dstLayer,nbBand);
// Loop across the features in the layer
OGRGeometry *geom;
otb::ogr::Layer::iterator featIt = inLayer.begin();
for(; featIt!=inLayer.end(); ++featIt)
ogr::Layer::const_iterator featIt = layer.begin();
for(; featIt!=layer.end(); ++featIt)
{
geom = featIt->ogr().GetGeometryRef();
switch (geom->getGeometryType())
......@@ -175,7 +134,7 @@ PersistentImageSampleExtractorFilter<TInputImage>
inputImage->TransformPhysicalPointToIndex(imgPoint,imgIndex);
imgPixel = inputImage->GetPixel(imgIndex);
typename Superclass::OGRFeatureType dstFeature(dstLayer.GetLayerDefn());
ogr::Feature dstFeature(outputLayer.GetLayerDefn());
dstFeature.SetFrom( *featIt, TRUE );
dstFeature.SetFID(featIt->GetFID());
for (unsigned int i=0 ; i<nbBand ; ++i)
......@@ -187,7 +146,7 @@ PersistentImageSampleExtractorFilter<TInputImage>
// Fill the ouptut OGRDataSource
dstFeature[fieldName].SetValue(imgComp);
}
dstLayer.CreateFeature( dstFeature );
outputLayer.CreateFeature( dstFeature );
break;
}
default:
......@@ -196,79 +155,34 @@ PersistentImageSampleExtractorFilter<TInputImage>
break;
}
}
currentCount++;
this->UpdateProgress(static_cast<float>(currentCount)/featCount);
// TODO : multi-points ?
progress.CompletedPixel();
}
if (updateMode)
{
inLayer.ogr().StartTransaction();
otb::ogr::Layer::iterator outIt = dstLayer.begin();
for (; outIt!=dstLayer.end(); ++outIt)
{
inLayer.SetFeature( *outIt );
}
const OGRErr err = inLayer.ogr().CommitTransaction();
if (err != OGRERR_NONE)
{
itkExceptionMacro(<< "Unable to commit transaction for OGR layer " << inLayer.ogr().GetName() << ".");
}
// empty the output dataset
tmpDS->DeleteLayer(0);
tmpDS->CreateLayer(
this->GetLayerName(),
oSRS,
this->GetGeometryType(),
this->GetOGRLayerCreationOptions());
}
return tmpDS;
}
template<class TInputImage>
void
PersistentImageSampleExtractorFilter<TInputImage>
::ApplyPolygonsSpatialFilter()
{
TInputImage* outputImage = this->GetOutput();
otb::ogr::DataSource* vectors = const_cast<otb::ogr::DataSource*>(this->GetSamplePositions());
const RegionType& requestedRegion = outputImage->GetRequestedRegion();
typename TInputImage::IndexType startIndex = requestedRegion.GetIndex();
typename TInputImage::IndexType endIndex = requestedRegion.GetUpperIndex();
itk::Point<double, 2> startPoint;
itk::Point<double, 2> endPoint;
outputImage->TransformIndexToPhysicalPoint(startIndex, startPoint);
outputImage->TransformIndexToPhysicalPoint(endIndex, endPoint);
vectors->GetLayer(m_LayerIndex).SetSpatialFilterRect(
std::min(startPoint[0],endPoint[0]),
std::min(startPoint[1],endPoint[1]),
std::max(startPoint[0],endPoint[0]),
std::max(startPoint[1],endPoint[1]));
}
template<class TInputImage>
void
PersistentImageSampleExtractorFilter<TInputImage>
::InitializeFields(ogr::Layer &layer, unsigned int size)
::InitializeFields()
{
OGRFeatureDefn &outFeatureDefn = layer.GetLayerDefn();
TInputImage* inputImage = const_cast<TInputImage*>(this->GetInput());
inputImage->UpdateOutputInformation();
unsigned int nbBand = inputImage->GetNumberOfComponentsPerPixel();
this->m_AdditionalFields.clear();
std::ostringstream oss;
std::string fieldName;
for (unsigned int i=0 ; i<size ; ++i)
for (unsigned int i=0 ; i<nbBand ; ++i)
{
oss.str(std::string(""));
oss << this->GetSampleFieldPrefix() << i;
fieldName = oss.str();
if (outFeatureDefn.GetFieldIndex(fieldName.c_str()) < 0)
{
OGRFieldDefn sampleField(fieldName.c_str(),OFTReal);
layer.CreateField(sampleField, true);
}
// DEBUG
std::cout << "new field "<< fieldName << std::endl;
OGRFieldDefn sampleField(fieldName.c_str(),OFTReal);
sampleField.SetWidth(12);
sampleField.SetPrecision(10);
this->m_AdditionalFields.push_back(ogr::FieldDefn(sampleField));
}
}
......@@ -295,7 +209,7 @@ void
ImageSampleExtractorFilter<TInputImage>
::SetSamplePositions(const otb::ogr::DataSource* data)
{
this->GetFilter()->SetSamplePositions(data);
this->GetFilter()->SetOGRData(data);
}
template<class TInputImage>
......@@ -303,23 +217,23 @@ const otb::ogr::DataSource*
ImageSampleExtractorFilter<TInputImage>
::GetSamplePositions()
{
return this->GetFilter()->GetSamplePositions();
return this->GetFilter()->GetOGRData();
}
template<class TInputImage>
void
ImageSampleExtractorFilter<TInputImage>
::SetOutputOGRData(OGRDataType::Pointer data)
::SetOutputSamples(OGRDataType::Pointer data)
{
this->GetFilter()->SetOGRDataSource(data);
this->GetFilter()->SetOutputSamples(data);
}
template<class TInputImage>
const otb::ogr::DataSource*
ImageSampleExtractorFilter<TInputImage>
::GetOutputOGRData()
::GetOutputSamples()
{
return this->GetFilter()->GetOGRDataSource();
return this->GetFilter()->GetOutputSamples();
}
template<class TInputImage>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment