From e935fdc721d1665bb1c9e0b2ff20c5b9d35c2609 Mon Sep 17 00:00:00 2001 From: Arnaud Jaen <arnaud.jaen@c-s.fr> Date: Fri, 22 Mar 2013 16:59:22 +0100 Subject: [PATCH] ENH: Add an ImageClassificationFilter class which classifies an image using a model (derived from otbMachineLearningModel) --- Code/Learning/otbImageClassificationFilter.h | 118 +++++++++++++++ .../Learning/otbImageClassificationFilter.txx | 137 ++++++++++++++++++ .../Learning/otbImageClassificationFilter.cxx | 73 ++++++++++ 3 files changed, 328 insertions(+) create mode 100644 Code/Learning/otbImageClassificationFilter.h create mode 100644 Code/Learning/otbImageClassificationFilter.txx create mode 100644 Testing/Code/Learning/otbImageClassificationFilter.cxx diff --git a/Code/Learning/otbImageClassificationFilter.h b/Code/Learning/otbImageClassificationFilter.h new file mode 100644 index 0000000000..4ac0f63f16 --- /dev/null +++ b/Code/Learning/otbImageClassificationFilter.h @@ -0,0 +1,118 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef __otbImageClassificationFilter_h +#define __otbImageClassificationFilter_h + +#include "itkImageToImageFilter.h" +#include "otbMachineLearningModel.h" + +namespace otb +{ +/** \class ImageClassificationFilter + * \brief This filter performs the classification of a VectorImage using a Model. + * + * This filter is streamed and threaded, allowing to classify huge images + * while fully using several core. + * + * \sa Classifier + * \ingroup Streamed + * \ingroup Threaded + */ +template <class TInputImage, class TOutputImage, class TMaskImage = TOutputImage> +class ITK_EXPORT ImageClassificationFilter + : public itk::ImageToImageFilter<TInputImage, TOutputImage> +{ +public: + /** Standard typedefs */ + typedef ImageClassificationFilter Self; + typedef itk::ImageToImageFilter<TInputImage, TOutputImage> Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + + /** Type macro */ + itkNewMacro(Self); + + /** Creation through object factory macro */ + itkTypeMacro(ImageClassificationFilter, ImageToImageFilter); + + typedef TInputImage InputImageType; + typedef typename InputImageType::ConstPointer InputImageConstPointerType; + typedef typename InputImageType::InternalPixelType ValueType; + + typedef TMaskImage MaskImageType; + typedef typename MaskImageType::ConstPointer MaskImageConstPointerType; + typedef typename MaskImageType::Pointer MaskImagePointerType; + + typedef TOutputImage OutputImageType; + typedef typename OutputImageType::Pointer OutputImagePointerType; + typedef typename OutputImageType::RegionType OutputImageRegionType; + typedef typename OutputImageType::PixelType LabelType; + + typedef MachineLearningModel<ValueType, LabelType> ModelType; + typedef typename ModelType::Pointer ModelPointerType; + + /** Set/Get the model */ + itkSetObjectMacro(Model, ModelType); + itkGetObjectMacro(Model, ModelType); + + /** Set/Get the default label */ + itkSetMacro(DefaultLabel, LabelType); + itkGetMacro(DefaultLabel, LabelType); + + /** + * If set, only pixels within the mask will be classified. + * All pixels with a value greater than 0 in the mask, will be classified. + * \param mask The input mask. + */ + void SetInputMask(const MaskImageType * mask); + + /** + * Get the input mask. + * \return The mask. + */ + const MaskImageType * GetInputMask(void); + +protected: + /** Constructor */ + ImageClassificationFilter(); + /** Destructor */ + virtual ~ImageClassificationFilter() {} + + /** Threaded generate data */ + virtual void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId); + /** Before threaded generate data */ + virtual void BeforeThreadedGenerateData(); + /**PrintSelf method */ + virtual void PrintSelf(std::ostream& os, itk::Indent indent) const; + +private: + ImageClassificationFilter(const Self &); //purposely not implemented + void operator =(const Self&); //purposely not implemented + + /** The model used for classification */ + ModelPointerType m_Model; + /** Default label for invalid pixels (when using a mask) */ + LabelType m_DefaultLabel; + +}; +} // End namespace otb +#ifndef OTB_MANUAL_INSTANTIATION +#include "otbImageClassificationFilter.txx" +#endif + +#endif diff --git a/Code/Learning/otbImageClassificationFilter.txx b/Code/Learning/otbImageClassificationFilter.txx new file mode 100644 index 0000000000..26364cf0c3 --- /dev/null +++ b/Code/Learning/otbImageClassificationFilter.txx @@ -0,0 +1,137 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef __otbImageClassificationFilter_txx +#define __otbImageClassificationFilter_txx + +#include "otbImageClassificationFilter.h" +#include "itkImageRegionIterator.h" +#include "itkProgressReporter.h" + +namespace otb +{ +/** + * Constructor + */ +template <class TInputImage, class TOutputImage, class TMaskImage> +ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> +::ImageClassificationFilter() +{ + this->SetNumberOfInputs(2); + this->SetNumberOfRequiredInputs(1); + m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue(); +} + +template <class TInputImage, class TOutputImage, class TMaskImage> +void +ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> +::SetInputMask(const MaskImageType * mask) +{ + this->itk::ProcessObject::SetNthInput(1, const_cast<MaskImageType *>(mask)); +} + +template <class TInputImage, class TOutputImage, class TMaskImage> +const typename ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> +::MaskImageType * +ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> +::GetInputMask() +{ + if (this->GetNumberOfInputs() < 2) + { + return 0; + } + return static_cast<const MaskImageType *>(this->itk::ProcessObject::GetInput(1)); +} + +template <class TInputImage, class TOutputImage, class TMaskImage> +void +ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> +::BeforeThreadedGenerateData() +{ + if (!m_Model) + { + itkGenericExceptionMacro(<< "No model for classification"); + } +} + +template <class TInputImage, class TOutputImage, class TMaskImage> +void +ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> +::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId) +{ + // Get the input pointers + InputImageConstPointerType inputPtr = this->GetInput(); + MaskImageConstPointerType inputMaskPtr = this->GetInputMask(); + OutputImagePointerType outputPtr = this->GetOutput(); + + // Progress reporting + itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels()); + + // Define iterators + typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType; + typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType; + typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType; + + InputIteratorType inIt(inputPtr, outputRegionForThread); + OutputIteratorType outIt(outputPtr, outputRegionForThread); + + // Eventually iterate on masks + MaskIteratorType maskIt; + if (inputMaskPtr) + { + maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread); + maskIt.GoToBegin(); + } + + bool validPoint = true; + + // Walk the part of the image + for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd(); ++inIt, ++outIt) + { + // Check pixel validity + if (inputMaskPtr) + { + validPoint = maskIt.Get() > 0; + ++maskIt; + } + // If point is valid + if (validPoint) + { + // Classifify + outIt.Set(m_Model->Predict(inIt.Get())[0]); + } + else + { + // else, set default value + outIt.Set(m_DefaultLabel); + } + progress.CompletedPixel(); + } + +} +/** + * PrintSelf Method + */ +template <class TInputImage, class TOutputImage, class TMaskImage> +void +ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> +::PrintSelf(std::ostream& os, itk::Indent indent) const +{ + Superclass::PrintSelf(os, indent); +} +} // End namespace otb +#endif diff --git a/Testing/Code/Learning/otbImageClassificationFilter.cxx b/Testing/Code/Learning/otbImageClassificationFilter.cxx new file mode 100644 index 0000000000..0558b9f694 --- /dev/null +++ b/Testing/Code/Learning/otbImageClassificationFilter.cxx @@ -0,0 +1,73 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + + =========================================================================*/ +#include "otbImageClassificationFilter.h" +#include "otbVectorImage.h" +#include "otbImage.h" +#include "otbImageFileReader.h" +#include "otbImageFileWriter.h" +#include "otbMachineLearningModelFactory.h" + +const unsigned int Dimension = 2; +typedef double PixelType; +typedef unsigned short LabeledPixelType; + +typedef otb::VectorImage<PixelType, Dimension> ImageType; +typedef otb::Image<LabeledPixelType, Dimension> LabeledImageType; +typedef otb::ImageClassificationFilter<ImageType, LabeledImageType> ClassificationFilterType; +typedef ClassificationFilterType::ModelType ModelType; +typedef ClassificationFilterType::ValueType ValueType; +typedef ClassificationFilterType::LabelType LabelType; +typedef otb::MachineLearningModelFactory<ValueType, LabelType> MachineLearningModelFactoryType; +typedef otb::ImageFileReader<ImageType> ReaderType; +typedef otb::ImageFileWriter<LabeledImageType> WriterType; + +int otbImageClassificationFilterNew(int argc, char * argv[]) +{ + ClassificationFilterType::Pointer filter = ClassificationFilterType::New(); + return EXIT_SUCCESS; +} + +int otbImageClassificationFilter(int argc, char * argv[]) +{ + const char * infname = argv[1]; + const char * modelfname = argv[2]; + const char * outfname = argv[3]; + + // Instantiating object + ClassificationFilterType::Pointer filter = ClassificationFilterType::New(); + + ReaderType::Pointer reader = ReaderType::New(); + reader->SetFileName(infname); + + ModelType::Pointer model; + + model = MachineLearningModelFactoryType::CreateMachineLearningModel(modelfname, + MachineLearningModelFactoryType::ReadMode); + + model->Load(modelfname); + + filter->SetModel(model); + filter->SetInput(reader->GetOutput()); + + WriterType::Pointer writer = WriterType::New(); + writer->SetInput(filter->GetOutput()); + writer->SetFileName(outfname); + writer->Update(); + + return EXIT_SUCCESS; +} -- GitLab