Commit 56729071 authored by Jordi Inglada's avatar Jordi Inglada
Browse files

ENH: Enable the use of a functor for on the fly feature extraction

parent 77f80161
......@@ -26,6 +26,16 @@
namespace otb
{
template <typename TPixel>
struct DummyFexFunctor
{
TPixel operator()(const TPixel& p)
{
return p;
}
};
/**
* \class PersistentImageSampleExtractorFilter
*
......@@ -33,37 +43,40 @@ namespace otb
*
* \ingroup OTBSampling
*/
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor=
DummyFexFunctor<typename TInputImage::PixelType> >
class ITK_EXPORT PersistentImageSampleExtractorFilter :
public PersistentSamplingFilterBase<TInputImage>
{
public:
/** Standard Self typedef */
typedef PersistentImageSampleExtractorFilter Self;
typedef PersistentSamplingFilterBase<TInputImage> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
public PersistentSamplingFilterBase<TInputImage>
{
public:
/** Standard Self typedef */
typedef PersistentImageSampleExtractorFilter Self;
typedef PersistentSamplingFilterBase<TInputImage> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
typedef TInputImage InputImageType;
typedef typename InputImageType::Pointer InputImagePointer;
typedef typename InputImageType::RegionType RegionType;
typedef typename InputImageType::PointType PointType;
typedef typename InputImageType::IndexType IndexType;
typedef typename InputImageType::PixelType PixelType;
typedef typename InputImageType::InternalPixelType InternalPixelType;
typedef TInputImage InputImageType;
typedef typename InputImageType::Pointer InputImagePointer;
typedef typename InputImageType::RegionType RegionType;
typedef typename InputImageType::PointType PointType;
typedef typename InputImageType::IndexType IndexType;
typedef typename InputImageType::PixelType PixelType;
typedef typename InputImageType::InternalPixelType InternalPixelType;
typedef ogr::DataSource OGRDataType;
typedef ogr::DataSource::Pointer OGRDataPointer;
typedef TFeatureExtractionFunctor FeatureExtractionFunctorType;
typedef itk::DataObject::DataObjectPointerArraySizeType DataObjectPointerArraySizeType;
typedef ogr::DataSource OGRDataType;
typedef ogr::DataSource::Pointer OGRDataPointer;
/** Method for creation through the object factory. */
itkNewMacro(Self);
typedef itk::DataObject::DataObjectPointerArraySizeType DataObjectPointerArraySizeType;
/** Method for creation through the object factory. */
itkNewMacro(Self);
/** Runtime information support. */
itkTypeMacro(PersistentImageSampleExtractorFilter, PersistentSamplingFilterBase);
/** Runtime information support. */
itkTypeMacro(PersistentImageSampleExtractorFilter, PersistentSamplingFilterBase);
/** Set the output samples OGR container
/** Set the output samples OGR container
* (shall be equal to the input container for an 'update' mode) */
void SetOutputSamples(ogr::DataSource* data);
......@@ -84,6 +97,27 @@ public:
/** Get the sample names */
const std::vector<std::string> & GetSampleFieldNames();
/** Get the functor object. The functor is returned by reference.
* (Functors do not have to derive from itk::LightObject, so they do
* not necessarily have a reference count. So we cannot return a
* SmartPointer.) */
FeatureExtractionFunctorType& GetFunctor()
{
return m_FeatureExtraction;
}
/** Set the functor object. This replaces the current Functor with a
* copy of the specified Functor. This allows the user to specify a
* functor that has ivars set differently than the default functor.
* This method requires an operator!=() be defined on the functor
* (or the compiler's default implementation of operator!=() being
* appropriate). */
void SetFunctor(const FeatureExtractionFunctorType& functor)
{
m_FeatureExtraction = functor;
this->Modified();
}
protected:
/** Constructor */
PersistentImageSampleExtractorFilter();
......@@ -110,7 +144,10 @@ private:
/** List of field names for each component */
std::vector<std::string> m_SampleFieldNames;
};
/** Feature extraction function */
FeatureExtractionFunctorType m_FeatureExtraction;
};
/**
* \class ImageSampleExtractorFilter
......@@ -121,9 +158,11 @@ private:
*
* \ingroup OTBSampling
*/
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor=
DummyFexFunctor<typename TInputImage::PixelType> >
class ITK_EXPORT ImageSampleExtractorFilter :
public PersistentFilterStreamingDecorator<PersistentImageSampleExtractorFilter<TInputImage> >
public PersistentFilterStreamingDecorator<PersistentImageSampleExtractorFilter<TInputImage,
TFeatureExtractionFunctor> >
{
public:
/** Standard Self typedef */
......@@ -135,6 +174,7 @@ public:
typedef itk::SmartPointer<const Self> ConstPointer;
typedef TInputImage InputImageType;
typedef TFeatureExtractionFunctor FeatureExtractionFunctorType;
typedef otb::ogr::DataSource OGRDataType;
typedef typename Superclass::FilterType FilterType;
......
......@@ -21,13 +21,14 @@
#include "otbImageSampleExtractorFilter.h"
#include "itkDefaultConvertPixelTraits.h"
#include "itkProgressReporter.h"
#include <type_traits>
namespace otb
{
// --------- otb::PersistentImageSampleExtractorFilter ---------------------
template<class TInputImage>
PersistentImageSampleExtractorFilter<TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
PersistentImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::PersistentImageSampleExtractorFilter() :
m_SampleFieldPrefix(std::string("band_"))
{
......@@ -35,17 +36,17 @@ PersistentImageSampleExtractorFilter<TInputImage>
this->SetNthOutput(0,TInputImage::New());
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
void
PersistentImageSampleExtractorFilter<TInputImage>
PersistentImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::SetOutputSamples(ogr::DataSource* data)
{
this->SetNthOutput(1,data);
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
ogr::DataSource*
PersistentImageSampleExtractorFilter<TInputImage>
PersistentImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::GetOutputSamples()
{
if (this->GetNumberOfOutputs() < 2)
......@@ -55,18 +56,18 @@ PersistentImageSampleExtractorFilter<TInputImage>
return static_cast<ogr::DataSource *>(this->itk::ProcessObject::GetOutput(1));
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
void
PersistentImageSampleExtractorFilter<TInputImage>
PersistentImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::Synthetize(void)
{
// clear temporary outputs
this->m_InMemoryOutputs.clear();
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
void
PersistentImageSampleExtractorFilter<TInputImage>
PersistentImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::Reset(void)
{
// Check output field names
......@@ -78,7 +79,7 @@ PersistentImageSampleExtractorFilter<TInputImage>
if ( m_SampleFieldNames.size() != nbBand)
{
itkExceptionMacro(<< "Wrong number of field names given, got "
<<m_SampleFieldNames.size() << ", expected "<< nbBand);
<<m_SampleFieldNames.size() << ", expected "<< nbBand);
}
}
else
......@@ -101,9 +102,9 @@ PersistentImageSampleExtractorFilter<TInputImage>
this->InitializeOutputDataSource(inputDS,output);
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
void
PersistentImageSampleExtractorFilter<TInputImage>
PersistentImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::SetSampleFieldNames(std::vector<std::string> &names)
{
m_SampleFieldNames.clear();
......@@ -113,17 +114,17 @@ PersistentImageSampleExtractorFilter<TInputImage>
}
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
const std::vector<std::string> &
PersistentImageSampleExtractorFilter<TInputImage>
PersistentImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::GetSampleFieldNames()
{
return m_SampleFieldNames;
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
void
PersistentImageSampleExtractorFilter<TInputImage>
PersistentImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::GenerateOutputInformation()
{
Superclass::GenerateOutputInformation();
......@@ -151,9 +152,9 @@ PersistentImageSampleExtractorFilter<TInputImage>
}
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
void
PersistentImageSampleExtractorFilter<TInputImage>
PersistentImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::GenerateInputRequestedRegion()
{
InputImageType *input = const_cast<InputImageType*>(this->GetInput());
......@@ -162,9 +163,9 @@ PersistentImageSampleExtractorFilter<TInputImage>
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
void
PersistentImageSampleExtractorFilter<TInputImage>
PersistentImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::ThreadedGenerateData(const RegionType&, itk::ThreadIdType threadid)
{
// Retrieve inputs
......@@ -194,44 +195,44 @@ PersistentImageSampleExtractorFilter<TInputImage>
{
case wkbPoint:
case wkbPoint25D:
{
OGRPoint* castPoint = dynamic_cast<OGRPoint*>(geom);
if (castPoint == NULL)
{
OGRPoint* castPoint = dynamic_cast<OGRPoint*>(geom);
if (castPoint == NULL)
{
// Wrong Type !
break;
}
imgPoint[0] = castPoint->getX();
imgPoint[1] = castPoint->getY();
inputImage->TransformPhysicalPointToIndex(imgPoint,imgIndex);
imgPixel = inputImage->GetPixel(imgIndex);
ogr::Feature dstFeature(outputLayer.GetLayerDefn());
dstFeature.SetFrom( *featIt, TRUE );
dstFeature.SetFID(featIt->GetFID());
for (unsigned int i=0 ; i<nbBand ; ++i)
{
imgComp = static_cast<double>(itk::DefaultConvertPixelTraits<PixelType>::GetNthComponent(i,imgPixel));
// Fill the output OGRDataSource
dstFeature[m_SampleFieldNames[i]].SetValue(imgComp);
}
outputLayer.CreateFeature( dstFeature );
// Wrong Type !
break;
}
default:
imgPoint[0] = castPoint->getX();
imgPoint[1] = castPoint->getY();
inputImage->TransformPhysicalPointToIndex(imgPoint,imgIndex);
imgPixel = m_FeatureExtraction(inputImage->GetPixel(imgIndex));
ogr::Feature dstFeature(outputLayer.GetLayerDefn());
dstFeature.SetFrom( *featIt, TRUE );
dstFeature.SetFID(featIt->GetFID());
for (unsigned int i=0 ; i<nbBand ; ++i)
{
otbWarningMacro("Geometry not handled: " << geom->getGeometryName());
break;
imgComp = static_cast<double>(itk::DefaultConvertPixelTraits<PixelType>::GetNthComponent(i,imgPixel));
// Fill the output OGRDataSource
dstFeature[m_SampleFieldNames[i]].SetValue(imgComp);
}
outputLayer.CreateFeature( dstFeature );
break;
}
default:
{
otbWarningMacro("Geometry not handled: " << geom->getGeometryName());
break;
}
}
progress.CompletedPixel();
}
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
void
PersistentImageSampleExtractorFilter<TInputImage>
PersistentImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::InitializeFields()
{
this->ClearAdditionalFields();
......@@ -243,114 +244,114 @@ PersistentImageSampleExtractorFilter<TInputImage>
// -------------- otb::ImageSampleExtractorFilter --------------------------
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
void
ImageSampleExtractorFilter<TInputImage>
ImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::SetInput(const TInputImage* image)
{
this->GetFilter()->SetInput(image);
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
const TInputImage*
ImageSampleExtractorFilter<TInputImage>
ImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::GetInput()
{
return this->GetFilter()->GetInput();
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
void
ImageSampleExtractorFilter<TInputImage>
ImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::SetSamplePositions(const otb::ogr::DataSource* data)
{
this->GetFilter()->SetOGRData(data);
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
const otb::ogr::DataSource*
ImageSampleExtractorFilter<TInputImage>
ImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::GetSamplePositions()
{
return this->GetFilter()->GetOGRData();
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
void
ImageSampleExtractorFilter<TInputImage>
ImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::SetOutputSamples(OGRDataType::Pointer data)
{
this->GetFilter()->SetOutputSamples(data);
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
const otb::ogr::DataSource*
ImageSampleExtractorFilter<TInputImage>
ImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::GetOutputSamples()
{
return this->GetFilter()->GetOutputSamples();
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
void
ImageSampleExtractorFilter<TInputImage>
ImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::SetOutputFieldPrefix(const std::string &key)
{
this->GetFilter()->SetSampleFieldPrefix(key);
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
std::string
ImageSampleExtractorFilter<TInputImage>
ImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::GetOutputFieldPrefix()
{
return this->GetFilter()->GetSampleFieldPrefix();
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
void
ImageSampleExtractorFilter<TInputImage>
ImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::SetOutputFieldNames(std::vector<std::string> &names)
{
this->GetFilter()->SetSampleFieldNames(names);
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
const std::vector<std::string> &
ImageSampleExtractorFilter<TInputImage>
ImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::GetOutputFieldNames()
{
return this->GetFilter()->GetSampleFieldNames();
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
void
ImageSampleExtractorFilter<TInputImage>
ImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::SetLayerIndex(int index)
{
this->GetFilter()->SetLayerIndex(index);
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
int
ImageSampleExtractorFilter<TInputImage>
ImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::GetLayerIndex()
{
return this->GetFilter()->GetLayerIndex();
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
void
ImageSampleExtractorFilter<TInputImage>
ImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::SetClassFieldName(const std::string &name)
{
this->GetFilter()->SetFieldName(name);
}
template<class TInputImage>
template<class TInputImage, class TFeatureExtractionFunctor>
std::string
ImageSampleExtractorFilter<TInputImage>
ImageSampleExtractorFilter<TInputImage,TFeatureExtractionFunctor>
::GetClassFieldName(void)
{
return this->GetFilter()->GetFieldName();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment