From e935fdc721d1665bb1c9e0b2ff20c5b9d35c2609 Mon Sep 17 00:00:00 2001
From: Arnaud Jaen <arnaud.jaen@c-s.fr>
Date: Fri, 22 Mar 2013 16:59:22 +0100
Subject: [PATCH] ENH: Add an ImageClassificationFilter class which classifies
 an image using a model (derived from otbMachineLearningModel)

---
 Code/Learning/otbImageClassificationFilter.h  | 118 +++++++++++++++
 .../Learning/otbImageClassificationFilter.txx | 137 ++++++++++++++++++
 .../Learning/otbImageClassificationFilter.cxx |  73 ++++++++++
 3 files changed, 328 insertions(+)
 create mode 100644 Code/Learning/otbImageClassificationFilter.h
 create mode 100644 Code/Learning/otbImageClassificationFilter.txx
 create mode 100644 Testing/Code/Learning/otbImageClassificationFilter.cxx

diff --git a/Code/Learning/otbImageClassificationFilter.h b/Code/Learning/otbImageClassificationFilter.h
new file mode 100644
index 0000000000..4ac0f63f16
--- /dev/null
+++ b/Code/Learning/otbImageClassificationFilter.h
@@ -0,0 +1,118 @@
+/*=========================================================================
+
+  Program:   ORFEO Toolbox
+  Language:  C++
+  Date:      $Date$
+  Version:   $Revision$
+
+
+  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
+  See OTBCopyright.txt for details.
+
+
+     This software is distributed WITHOUT ANY WARRANTY; without even
+     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+     PURPOSE.  See the above copyright notices for more information.
+
+=========================================================================*/
+#ifndef __otbImageClassificationFilter_h
+#define __otbImageClassificationFilter_h
+
+#include "itkImageToImageFilter.h"
+#include "otbMachineLearningModel.h"
+
+namespace otb
+{
+/** \class ImageClassificationFilter
+ *  \brief This filter performs the classification of a VectorImage using a Model.
+ *
+ *  This filter is streamed and threaded, allowing to classify huge images
+ *  while fully using several core.
+ *
+ * \sa Classifier
+ * \ingroup Streamed
+ * \ingroup Threaded
+ */
+template <class TInputImage, class TOutputImage, class TMaskImage = TOutputImage>
+class ITK_EXPORT ImageClassificationFilter
+  : public itk::ImageToImageFilter<TInputImage, TOutputImage>
+{
+public:
+  /** Standard typedefs */
+  typedef ImageClassificationFilter                       Self;
+  typedef itk::ImageToImageFilter<TInputImage, TOutputImage> Superclass;
+  typedef itk::SmartPointer<Self>                            Pointer;
+  typedef itk::SmartPointer<const Self>                      ConstPointer;
+
+  /** Type macro */
+  itkNewMacro(Self);
+
+  /** Creation through object factory macro */
+  itkTypeMacro(ImageClassificationFilter, ImageToImageFilter);
+
+  typedef TInputImage                                InputImageType;
+  typedef typename InputImageType::ConstPointer      InputImageConstPointerType;
+  typedef typename InputImageType::InternalPixelType ValueType;
+
+  typedef TMaskImage                           MaskImageType;
+  typedef typename MaskImageType::ConstPointer MaskImageConstPointerType;
+  typedef typename MaskImageType::Pointer      MaskImagePointerType;
+
+  typedef TOutputImage                         OutputImageType;
+  typedef typename OutputImageType::Pointer    OutputImagePointerType;
+  typedef typename OutputImageType::RegionType OutputImageRegionType;
+  typedef typename OutputImageType::PixelType  LabelType;
+
+  typedef MachineLearningModel<ValueType, LabelType> ModelType;
+  typedef typename ModelType::Pointer    ModelPointerType;
+
+  /** Set/Get the model */
+  itkSetObjectMacro(Model, ModelType);
+  itkGetObjectMacro(Model, ModelType);
+
+  /** Set/Get the default label */
+  itkSetMacro(DefaultLabel, LabelType);
+  itkGetMacro(DefaultLabel, LabelType);
+
+  /**
+   * If set, only pixels within the mask will be classified.
+   * All pixels with a value greater than 0 in the mask, will be classified.
+   * \param mask The input mask.
+   */
+  void SetInputMask(const MaskImageType * mask);
+
+  /**
+   * Get the input mask.
+   * \return The mask.
+   */
+  const MaskImageType * GetInputMask(void);
+
+protected:
+  /** Constructor */
+  ImageClassificationFilter();
+  /** Destructor */
+  virtual ~ImageClassificationFilter() {}
+
+  /** Threaded generate data */
+  virtual void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId);
+  /** Before threaded generate data */
+  virtual void BeforeThreadedGenerateData();
+  /**PrintSelf method */
+  virtual void PrintSelf(std::ostream& os, itk::Indent indent) const;
+
+private:
+  ImageClassificationFilter(const Self &); //purposely not implemented
+  void operator =(const Self&); //purposely not implemented
+
+  /** The model used for classification */
+  ModelPointerType m_Model;
+  /** Default label for invalid pixels (when using a mask) */
+  LabelType m_DefaultLabel;
+
+};
+} // End namespace otb
+#ifndef OTB_MANUAL_INSTANTIATION
+#include "otbImageClassificationFilter.txx"
+#endif
+
+#endif
diff --git a/Code/Learning/otbImageClassificationFilter.txx b/Code/Learning/otbImageClassificationFilter.txx
new file mode 100644
index 0000000000..26364cf0c3
--- /dev/null
+++ b/Code/Learning/otbImageClassificationFilter.txx
@@ -0,0 +1,137 @@
+/*=========================================================================
+
+  Program:   ORFEO Toolbox
+  Language:  C++
+  Date:      $Date$
+  Version:   $Revision$
+
+
+  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
+  See OTBCopyright.txt for details.
+
+
+     This software is distributed WITHOUT ANY WARRANTY; without even
+     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+     PURPOSE.  See the above copyright notices for more information.
+
+=========================================================================*/
+#ifndef __otbImageClassificationFilter_txx
+#define __otbImageClassificationFilter_txx
+
+#include "otbImageClassificationFilter.h"
+#include "itkImageRegionIterator.h"
+#include "itkProgressReporter.h"
+
+namespace otb
+{
+/**
+ * Constructor
+ */
+template <class TInputImage, class TOutputImage, class TMaskImage>
+ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
+::ImageClassificationFilter()
+{
+  this->SetNumberOfInputs(2);
+  this->SetNumberOfRequiredInputs(1);
+  m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue();
+}
+
+template <class TInputImage, class TOutputImage, class TMaskImage>
+void
+ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
+::SetInputMask(const MaskImageType * mask)
+{
+  this->itk::ProcessObject::SetNthInput(1, const_cast<MaskImageType *>(mask));
+}
+
+template <class TInputImage, class TOutputImage, class TMaskImage>
+const typename ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
+::MaskImageType *
+ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
+::GetInputMask()
+{
+  if (this->GetNumberOfInputs() < 2)
+    {
+    return 0;
+    }
+  return static_cast<const MaskImageType *>(this->itk::ProcessObject::GetInput(1));
+}
+
+template <class TInputImage, class TOutputImage, class TMaskImage>
+void
+ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
+::BeforeThreadedGenerateData()
+{
+  if (!m_Model)
+    {
+    itkGenericExceptionMacro(<< "No model for classification");
+    }
+}
+
+template <class TInputImage, class TOutputImage, class TMaskImage>
+void
+ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
+::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId)
+{
+  // Get the input pointers
+  InputImageConstPointerType inputPtr     = this->GetInput();
+  MaskImageConstPointerType  inputMaskPtr  = this->GetInputMask();
+  OutputImagePointerType     outputPtr    = this->GetOutput();
+
+  // Progress reporting
+  itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels());
+
+  // Define iterators
+  typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType;
+  typedef itk::ImageRegionConstIterator<MaskImageType>  MaskIteratorType;
+  typedef itk::ImageRegionIterator<OutputImageType>     OutputIteratorType;
+
+  InputIteratorType inIt(inputPtr, outputRegionForThread);
+  OutputIteratorType outIt(outputPtr, outputRegionForThread);
+
+  // Eventually iterate on masks
+  MaskIteratorType maskIt;
+  if (inputMaskPtr)
+    {
+    maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread);
+    maskIt.GoToBegin();
+    }
+
+  bool validPoint = true;
+
+  // Walk the part of the image
+  for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd(); ++inIt, ++outIt)
+    {
+    // Check pixel validity
+    if (inputMaskPtr)
+      {
+      validPoint = maskIt.Get() > 0;
+      ++maskIt;
+      }
+    // If point is valid
+    if (validPoint)
+      {
+      // Classifify
+      outIt.Set(m_Model->Predict(inIt.Get())[0]);
+      }
+    else
+      {
+      // else, set default value
+      outIt.Set(m_DefaultLabel);
+      }
+    progress.CompletedPixel();
+    }
+
+}
+/**
+ * PrintSelf Method
+ */
+template <class TInputImage, class TOutputImage, class TMaskImage>
+void
+ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage>
+::PrintSelf(std::ostream& os, itk::Indent indent) const
+{
+  Superclass::PrintSelf(os, indent);
+}
+} // End namespace otb
+#endif
diff --git a/Testing/Code/Learning/otbImageClassificationFilter.cxx b/Testing/Code/Learning/otbImageClassificationFilter.cxx
new file mode 100644
index 0000000000..0558b9f694
--- /dev/null
+++ b/Testing/Code/Learning/otbImageClassificationFilter.cxx
@@ -0,0 +1,73 @@
+/*=========================================================================
+
+ Program:   ORFEO Toolbox
+ Language:  C++
+ Date:      $Date$
+ Version:   $Revision$
+
+
+ Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
+ See OTBCopyright.txt for details.
+
+
+ This software is distributed WITHOUT ANY WARRANTY; without even
+ the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+ PURPOSE.  See the above copyright notices for more information.
+
+ =========================================================================*/
+#include "otbImageClassificationFilter.h"
+#include "otbVectorImage.h"
+#include "otbImage.h"
+#include "otbImageFileReader.h"
+#include "otbImageFileWriter.h"
+#include "otbMachineLearningModelFactory.h"
+
+const unsigned int Dimension = 2;
+typedef double PixelType;
+typedef unsigned short LabeledPixelType;
+
+typedef otb::VectorImage<PixelType, Dimension> ImageType;
+typedef otb::Image<LabeledPixelType, Dimension> LabeledImageType;
+typedef otb::ImageClassificationFilter<ImageType, LabeledImageType> ClassificationFilterType;
+typedef ClassificationFilterType::ModelType ModelType;
+typedef ClassificationFilterType::ValueType ValueType;
+typedef ClassificationFilterType::LabelType LabelType;
+typedef otb::MachineLearningModelFactory<ValueType, LabelType> MachineLearningModelFactoryType;
+typedef otb::ImageFileReader<ImageType> ReaderType;
+typedef otb::ImageFileWriter<LabeledImageType> WriterType;
+
+int otbImageClassificationFilterNew(int argc, char * argv[])
+{
+  ClassificationFilterType::Pointer filter = ClassificationFilterType::New();
+  return EXIT_SUCCESS;
+}
+
+int otbImageClassificationFilter(int argc, char * argv[])
+{
+  const char * infname = argv[1];
+  const char * modelfname = argv[2];
+  const char * outfname = argv[3];
+
+  // Instantiating object
+  ClassificationFilterType::Pointer filter = ClassificationFilterType::New();
+
+  ReaderType::Pointer reader = ReaderType::New();
+  reader->SetFileName(infname);
+
+  ModelType::Pointer model;
+
+  model = MachineLearningModelFactoryType::CreateMachineLearningModel(modelfname,
+                                                                      MachineLearningModelFactoryType::ReadMode);
+
+  model->Load(modelfname);
+
+  filter->SetModel(model);
+  filter->SetInput(reader->GetOutput());
+
+  WriterType::Pointer writer = WriterType::New();
+  writer->SetInput(filter->GetOutput());
+  writer->SetFileName(outfname);
+  writer->Update();
+
+  return EXIT_SUCCESS;
+}
-- 
GitLab