/*=========================================================================

  Program:   ORFEO Toolbox
  Language:  C++
  Date:      $Date$
  Version:   $Revision$


  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
  See OTBCopyright.txt for details.


     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"

#include "itkUnaryFunctorImageFilter.h"
#include "otbChangeLabelImageFilter.h"
#include "otbStandardWriterWatcher.h"
#include "otbStatisticsXMLFileReader.h"
#include "otbShiftScaleVectorImageFilter.h"
#include "ImageDimensionalityReductionFilter.h"
#include "otbMultiToMonoChannelExtractROI.h"
#include "otbImageToVectorImageCastFilter.h"
#include "DimensionalityReductionModelFactory.h"

namespace otb
{
namespace Functor
{
/**
 * \class AffineFunctor
 * \brief Pixel-wise affine function: y = a*x + b.
 *
 * The input is converted to \c double (InternalType), transformed with the
 * coefficients A (default 1.0) and B (default 0.0), then cast to TOutput.
 *
 * Equality operators are provided because itk::UnaryFunctorImageFilter
 * compares functors (operator!=) in SetFunctor() to decide whether the
 * filter must be marked as modified.
 */
template<class TInput, class TOutput>
class AffineFunctor
{
public:
  typedef double InternalType;

  /** Build the identity transform (A = 1, B = 0). */
  AffineFunctor() : m_A(1.0),m_B(0.0) {}

  virtual ~AffineFunctor() {}

  /** Set the multiplicative coefficient A. */
  void SetA(InternalType a)
    {
    m_A = a;
    }

  /** Set the additive offset B. */
  void SetB(InternalType b)
    {
    m_B = b;
    }

  /** Two functors are equal when they hold exactly the same coefficients.
   *  Exact floating-point comparison is intended here: it only detects
   *  whether parameters were changed, not numerical closeness. */
  bool operator==(const AffineFunctor & other) const
    {
    return m_A == other.m_A && m_B == other.m_B;
    }

  bool operator!=(const AffineFunctor & other) const
    {
    return !(*this == other);
    }

  /** Apply y = A*x + B, computed in double precision. */
  inline TOutput operator()(const TInput & x) const
    {
    return static_cast<TOutput>( static_cast<InternalType>(x)*m_A + m_B);
    }
private:
  InternalType m_A;
  InternalType m_B;
};
  
}

namespace Wrapper
{

/** \class ImageDimensionalityReduction
 *  \brief Application applying a trained dimensionality reduction model to
 *  every pixel of an input image.
 *
 *  Pipeline assembled in DoExecute():
 *    "in" -> [optional ShiftScaleVectorImageFilter driven by "imstat"]
 *         -> ImageDimensionalityReductionFilter (model read from "model",
 *            optional "mask") -> "out"
 */
class ImageDimensionalityReduction : public Application
{
public:
  /** Standard class typedefs. */
  typedef ImageDimensionalityReduction             Self;
  typedef Application                   Superclass;
  typedef itk::SmartPointer<Self>       Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;

  /** Standard macro */
  itkNewMacro(Self);

  itkTypeMacro(ImageDimensionalityReduction, otb::Application);

  /** Filters typedef */
  typedef UInt8ImageType                                                                       MaskImageType;
  typedef itk::VariableLengthVector<FloatVectorImageType::InternalPixelType>                   MeasurementType;
  typedef otb::StatisticsXMLFileReader<MeasurementType>                                        StatisticsReader;
  typedef otb::ShiftScaleVectorImageFilter<FloatVectorImageType, FloatVectorImageType>         RescalerType;
  typedef itk::UnaryFunctorImageFilter<
      FloatImageType,
      FloatImageType,
      otb::Functor::AffineFunctor<float,float> >                                               OutputRescalerType;
  typedef otb::ImageDimensionalityReductionFilter<FloatVectorImageType, FloatVectorImageType, MaskImageType>  DimensionalityReductionFilterType;
  typedef DimensionalityReductionFilterType::Pointer                                                    DimensionalityReductionFilterPointerType;
  typedef DimensionalityReductionFilterType::ModelType                                                  ModelType;
  typedef ModelType::Pointer                                                                   ModelPointerType;
  typedef DimensionalityReductionFilterType::ValueType                                                  ValueType;
  typedef DimensionalityReductionFilterType::LabelType                                                  LabelType;
  typedef otb::DimensionalityReductionModelFactory<ValueType, LabelType>                               DimensionalityReductionModelFactoryType;

protected:

  /** Calls DimensionalityReductionModelFactory::CleanFactories()
   *  (presumably unregisters the model factories used to load the model). */
  ~ImageDimensionalityReduction() ITK_OVERRIDE
  {
    DimensionalityReductionModelFactoryType::CleanFactories();
  }

private:
  /** Declare the application metadata, documentation and parameters. */
  void DoInit() ITK_OVERRIDE
  {
    SetName("DimensionalityReduction");
    SetDescription("Performs dimensionality reduction of the input image according to a dimensionality reduction model file.");

    // Documentation
    SetDocName("DimensionalityReduction");
    SetDocLongDescription("This application reduces the dimension of an input"
                          " image, based on a machine learning model file produced by"
                          " the TrainDimensionalityReduction application. Pixels of the "
                          "output image will contain the reduced values from "
                          "the model. The input pixels"
                          " can be optionally centered and reduced according "
                          "to the statistics file produced by the "
                          "ComputeImagesStatistics application. ");

    SetDocLimitations("The input image must contain the feature bands used for"
                      " the model training. "
                      "If a statistics file was used during training by the "
                      "TrainDimensionalityReduction application, it is mandatory to use the same "
                      "statistics file for reduction.");
    SetDocAuthors("OTB-Team");
    SetDocSeeAlso("TrainDimensionalityReduction, ComputeImagesStatistics");

    AddDocTag(Tags::Learning);

    AddParameter(ParameterType_InputImage, "in",  "Input Image");
    SetParameterDescription( "in", "The input image whose dimension is to be reduced.");

    AddParameter(ParameterType_InputImage,  "mask",   "Input Mask");
    SetParameterDescription( "mask", "The mask allows restricting "
      "dimensionality reduction of the input image to the area where mask pixel values "
      "are greater than 0.");
    MandatoryOff("mask");

    AddParameter(ParameterType_InputFilename, "model", "Model file");
    SetParameterDescription("model", "A dimensionality reduction model file (produced by "
                            "TrainDimensionalityReduction application).");

    // The statistics file must describe exactly the bands of "in": any size
    // mismatch is rejected in DoExecute().
    AddParameter(ParameterType_InputFilename, "imstat", "Statistics file");
    SetParameterDescription("imstat", "An XML file containing mean and standard"
      " deviation to center and reduce samples before dimensionality reduction "
      "(produced by ComputeImagesStatistics application). The number of "
      "components in this file must match the number of bands of the input image.");
    MandatoryOff("imstat");

    AddParameter(ParameterType_OutputImage, "out",  "Output Image");
    SetParameterDescription( "out", "Output image containing reduced values");

    AddRAMParameter();

    // Doc example parameter settings
    SetDocExampleParameterValue("in", "QB_1_ortho.tif");
    SetDocExampleParameterValue("imstat", "EstimateImageStatisticsQB1.xml");
    SetDocExampleParameterValue("model", "clsvmModelQB1.model");
    SetDocExampleParameterValue("out", "ReducedImageQB1.tif");
  }

  void DoUpdateParameters() ITK_OVERRIDE
  {
    // Nothing to do here : all parameters are independent
  }

  /** Load the model, optionally normalize the input, and connect the
   *  reduction filter output to "out".
   *
   *  The filters are kept as members, presumably so the pipeline objects stay
   *  alive after DoExecute() returns and the framework writes "out". */
  void DoExecute() ITK_OVERRIDE
  {
    // Load input image
    FloatVectorImageType::Pointer inImage = GetParameterImage("in");
    inImage->UpdateOutputInformation();
    unsigned int nbFeatures = inImage->GetNumberOfComponentsPerPixel();

    // Load DR model using a factory
    otbAppLogINFO("Loading model");
    m_Model = DimensionalityReductionModelFactoryType::CreateDimensionalityReductionModel(GetParameterString("model"),
                                                                          DimensionalityReductionModelFactoryType::ReadMode);

    if (m_Model.IsNull())
      {
      otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type");
      }

    m_Model->Load(GetParameterString("model"));
    otbAppLogINFO("Model loaded");

    // Set up the dimensionality reduction filter
    m_ClassificationFilter = DimensionalityReductionFilterType::New();
    m_ClassificationFilter->SetModel(m_Model);

    FloatVectorImageType::Pointer outputImage = m_ClassificationFilter->GetOutput();

    // Normalize input image if asked
    if(IsParameterEnabled("imstat"))
      {
      otbAppLogINFO("Input image normalization activated.");
      StatisticsReader::Pointer  statisticsReader = StatisticsReader::New();
      MeasurementType  meanMeasurementVector;
      MeasurementType  stddevMeasurementVector;
      m_Rescaler = RescalerType::New();

      // Load input image statistics
      statisticsReader->SetFileName(GetParameterString("imstat"));
      meanMeasurementVector   = statisticsReader->GetStatisticVectorByName("mean");
      stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
      otbAppLogINFO( "mean used: " << meanMeasurementVector );
      otbAppLogINFO( "standard deviation used: " << stddevMeasurementVector );

      // Both vectors feed a per-band shift/scale: each must have one entry
      // per input band, otherwise the rescaler would be misconfigured.
      if (meanMeasurementVector.Size() != nbFeatures)
        {
        otbAppLogFATAL("Wrong number of components in statistics file : "<<meanMeasurementVector.Size()
                       << " (expected " << nbFeatures << ")");
        }
      if (stddevMeasurementVector.Size() != nbFeatures)
        {
        otbAppLogFATAL("Wrong number of standard deviation components in statistics file : "
                       << stddevMeasurementVector.Size() << " (expected " << nbFeatures << ")");
        }

      // Center by the mean and reduce by the standard deviation
      m_Rescaler->SetScale(stddevMeasurementVector);
      m_Rescaler->SetShift(meanMeasurementVector);
      m_Rescaler->SetInput(inImage);

      m_ClassificationFilter->SetInput(m_Rescaler->GetOutput());
      }
    else
      {
      otbAppLogINFO("Input image normalization deactivated.");
      m_ClassificationFilter->SetInput(inImage);
      }


    if(IsParameterEnabled("mask"))
      {
      otbAppLogINFO("Using input mask");
      // Load mask image as MaskImageType (UInt8)
      MaskImageType::Pointer inMask = GetParameterUInt8Image("mask");

      m_ClassificationFilter->SetInputMask(inMask);
      }

    SetParameterOutputImage<FloatVectorImageType>("out", outputImage);

  }

  DimensionalityReductionFilterType::Pointer m_ClassificationFilter;
  ModelPointerType m_Model;
  RescalerType::Pointer m_Rescaler;
  // NOTE(review): declared but never used in this application.
  OutputRescalerType::Pointer m_OutRescaler;
};


}
}

// Export the application class through the OTB application factory macro
// (declared in otbWrapperApplicationFactory.h); presumably this is what lets
// the OTB application loader find and instantiate it as a plugin.
OTB_APPLICATION_EXPORT(otb::Wrapper::ImageDimensionalityReduction)