diff --git a/Modules/Applications/AppClassification/app/otbImageRegression.cxx b/Modules/Applications/AppClassification/app/otbImageRegression.cxx index 56e07622ad80ede4a6ba83c7da6f413eaf000338..f1892417da6d1e853d3cb6412f890d8c7ec7ade6 100644 --- a/Modules/Applications/AppClassification/app/otbImageRegression.cxx +++ b/Modules/Applications/AppClassification/app/otbImageRegression.cxx @@ -38,37 +38,41 @@ namespace Functor /** * simple affine function : y = ax+b */ -template<class TInput, class TOutput> +template <class TInput, class TOutput> class AffineFunctor { public: typedef double InternalType; - + // constructor - AffineFunctor() : m_A(1.0),m_B(0.0) {} - + AffineFunctor() : m_A(1.0), m_B(0.0) + { + } + // destructor - virtual ~AffineFunctor() {} - + virtual ~AffineFunctor() + { + } + void SetA(InternalType a) - { + { m_A = a; - } - + } + void SetB(InternalType b) - { + { m_B = b; - } - - inline TOutput operator()(const TInput & x) const - { - return static_cast<TOutput>( static_cast<InternalType>(x)*m_A + m_B); - } + } + + inline TOutput operator()(const TInput& x) const + { + return static_cast<TOutput>(static_cast<InternalType>(x) * m_A + m_B); + } + private: InternalType m_A; InternalType m_B; }; - } namespace Wrapper @@ -89,28 +93,24 @@ public: itkTypeMacro(ImageRegression, otb::Application); /** Filters typedef */ - typedef UInt8ImageType MaskImageType; - typedef itk::VariableLengthVector<FloatVectorImageType::InternalPixelType> MeasurementType; - typedef otb::StatisticsXMLFileReader<MeasurementType> StatisticsReader; - typedef otb::ShiftScaleVectorImageFilter<FloatVectorImageType, FloatVectorImageType> RescalerType; - typedef itk::UnaryFunctorImageFilter< - FloatImageType, - FloatImageType, - otb::Functor::AffineFunctor<float,float> > OutputRescalerType; - typedef otb::ImageClassificationFilter<FloatVectorImageType, FloatImageType, MaskImageType> ClassificationFilterType; - typedef ClassificationFilterType::Pointer ClassificationFilterPointerType; - typedef ClassificationFilterType::ModelType ModelType; - typedef ModelType::Pointer ModelPointerType; - typedef ClassificationFilterType::ValueType ValueType; - typedef ClassificationFilterType::LabelType LabelType; - typedef otb::MachineLearningModelFactory<ValueType, LabelType> MachineLearningModelFactoryType; + typedef UInt8ImageType MaskImageType; + typedef itk::VariableLengthVector<FloatVectorImageType::InternalPixelType> MeasurementType; + typedef otb::StatisticsXMLFileReader<MeasurementType> StatisticsReader; + typedef otb::ShiftScaleVectorImageFilter<FloatVectorImageType, FloatVectorImageType> RescalerType; + typedef itk::UnaryFunctorImageFilter<FloatImageType, FloatImageType, otb::Functor::AffineFunctor<float, float>> OutputRescalerType; + typedef otb::ImageClassificationFilter<FloatVectorImageType, FloatImageType, MaskImageType> ClassificationFilterType; + typedef ClassificationFilterType::Pointer ClassificationFilterPointerType; + typedef ClassificationFilterType::ModelType ModelType; + typedef ModelType::Pointer ModelPointerType; + typedef ClassificationFilterType::ValueType ValueType; + typedef ClassificationFilterType::LabelType LabelType; + typedef otb::MachineLearningModelFactory<ValueType, LabelType> MachineLearningModelFactoryType; protected: - ~ImageRegression() override - { + { MachineLearningModelFactoryType::CleanFactories(); - } + } private: void DoInit() override @@ -119,58 +119,63 @@ private: SetDescription("Performs a prediction of the input image according to a regression model file."); // Documentation - SetDocLongDescription("This application predict output values from an input " - "image, based on a regression model file produced either by " - "TrainVectorRegression or TrainImagesRegression. " - "Pixels of the output image will contain the predicted values from " - "the regression model (single band). The input pixels " - "can be optionally centered and reduced according " - "to the statistics file produced by the " - "ComputeImagesStatistics application. An optional " - "input mask can be provided, in which case only " - "input image pixels whose corresponding mask value " - "is greater than zero will be processed. The remaining " - "of pixels will be given the value zero in the output " - "image."); - - SetDocLimitations("The input image must contain the feature bands used for " - "the model training. " - "If a statistics file was used during training by the " - "TrainRegression, it is mandatory to use the same " - "statistics file for prediction. If an input mask is " - "used, its size must match the input image size."); + SetDocLongDescription( + "This application predict output values from an input " + "image, based on a regression model file produced either by " + "TrainVectorRegression or TrainImagesRegression. " + "Pixels of the output image will contain the predicted values from " + "the regression model (single band). The input pixels " + "can be optionally centered and reduced according " + "to the statistics file produced by the " + "ComputeImagesStatistics application. An optional " + "input mask can be provided, in which case only " + "input image pixels whose corresponding mask value " + "is greater than zero will be processed. The remaining " + "of pixels will be given the value zero in the output " + "image."); + + SetDocLimitations( + "The input image must contain the feature bands used for " + "the model training. " + "If a statistics file was used during training by the " + "TrainRegression, it is mandatory to use the same " + "statistics file for prediction. If an input mask is " + "used, its size must match the input image size."); SetDocAuthors("OTB-Team"); SetDocSeeAlso("TrainImagesRegression, TrainVectorRegression, VectorRegression, ComputeImagesStatistics"); AddDocTag(Tags::Learning); AddParameter(ParameterType_InputImage, "in", "Input Image"); - SetParameterDescription( "in", "The input image to predict."); + SetParameterDescription("in", "The input image to predict."); AddParameter(ParameterType_InputImage, "mask", "Input Mask"); - SetParameterDescription( "mask", "The mask restrict the " - "classification of the input image to the area where mask pixel values " - "are greater than zero."); + SetParameterDescription("mask", + "The mask restrict the " + "classification of the input image to the area where mask pixel values " + "are greater than zero."); MandatoryOff("mask"); AddParameter(ParameterType_InputFilename, "model", "Model file"); - SetParameterDescription("model", "A regression model file (produced either by " - "TrainVectorRegression application or the TrainImagesRegression application)."); + SetParameterDescription("model", + "A regression model file (produced either by " + "TrainVectorRegression application or the TrainImagesRegression application)."); AddParameter(ParameterType_InputFilename, "imstat", "Statistics file"); - SetParameterDescription("imstat", "An XML file containing mean and standard" - " deviation to center and reduce samples before prediction " - "(produced by the ComputeImagesStatistics application). If this file contains " - "one more band than the sample size, the last stat of the last band will be" - "applied to expand the output predicted value."); + SetParameterDescription("imstat", + "An XML file containing mean and standard" + " deviation to center and reduce samples before prediction " + "(produced by the ComputeImagesStatistics application). If this file contains " + "one more band than the sample size, the last stat of the last band will be" + "applied to expand the output predicted value."); MandatoryOff("imstat"); AddParameter(ParameterType_OutputImage, "out", "Output Image"); - SetParameterDescription( "out", "Output image containing predicted values"); + SetParameterDescription("out", "Output image containing predicted values"); AddRAMParameter(); - // Doc example parameter settings + // Doc example parameter settings SetDocExampleParameterValue("in", "QB_1_ortho.tif"); SetDocExampleParameterValue("imstat", "EstimateImageStatisticsQB1.xml"); SetDocExampleParameterValue("model", "clsvmModelQB1.svm"); @@ -187,97 +192,88 @@ private: void DoExecute() override { // Load input image - FloatVectorImageType::Pointer inImage = GetParameterImage("in"); + auto inImage = GetParameterImage("in"); inImage->UpdateOutputInformation(); unsigned int nbFeatures = inImage->GetNumberOfComponentsPerPixel(); // Load svm model otbAppLogINFO("Loading model"); - m_Model = MachineLearningModelFactoryType::CreateMachineLearningModel(GetParameterString("model"), - MachineLearningModelFactoryType::ReadMode); + auto model = MachineLearningModelFactoryType::CreateMachineLearningModel(GetParameterString("model"), MachineLearningModelFactoryType::ReadMode); - if (m_Model.IsNull()) - { + if (model.IsNull()) + { otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type"); - } + } - m_Model->Load(GetParameterString("model")); - m_Model->SetRegressionMode(true); + model->Load(GetParameterString("model")); + model->SetRegressionMode(true); otbAppLogINFO("Model loaded"); // Classify - m_ClassificationFilter = ClassificationFilterType::New(); - m_ClassificationFilter->SetModel(m_Model); - - FloatImageType::Pointer outputImage = m_ClassificationFilter->GetOutput(); + auto classificationFilter = ClassificationFilterType::New(); + classificationFilter->SetModel(model); + + auto outputImage = classificationFilter->GetOutput(); + RescalerType::Pointer rescaler; + OutputRescalerType::Pointer outRescaler; // Normalize input image if asked - if(IsParameterEnabled("imstat") ) - { + if (IsParameterEnabled("imstat")) + { otbAppLogINFO("Input image normalization activated."); // Normalize input image (optional) - StatisticsReader::Pointer statisticsReader = StatisticsReader::New(); - MeasurementType meanMeasurementVector; - MeasurementType stddevMeasurementVector; - m_Rescaler = RescalerType::New(); + auto statisticsReader = StatisticsReader::New(); + MeasurementType meanMeasurementVector; + MeasurementType stddevMeasurementVector; + rescaler = RescalerType::New(); // Load input image statistics statisticsReader->SetFileName(GetParameterString("imstat")); meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean"); stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev"); - otbAppLogINFO( "mean used: " << meanMeasurementVector ); - otbAppLogINFO( "standard deviation used: " << stddevMeasurementVector ); + otbAppLogINFO("mean used: " << meanMeasurementVector); + otbAppLogINFO("standard deviation used: " << stddevMeasurementVector); if (meanMeasurementVector.Size() == nbFeatures + 1) - { - double outMean = meanMeasurementVector[nbFeatures]; + { + double outMean = meanMeasurementVector[nbFeatures]; double outStdDev = stddevMeasurementVector[nbFeatures]; - meanMeasurementVector.SetSize(nbFeatures,false); - stddevMeasurementVector.SetSize(nbFeatures,false); - m_OutRescaler = OutputRescalerType::New(); - m_OutRescaler->SetInput(m_ClassificationFilter->GetOutput()); - m_OutRescaler->GetFunctor().SetA(outStdDev); - m_OutRescaler->GetFunctor().SetB(outMean); - outputImage = m_OutRescaler->GetOutput(); - } + meanMeasurementVector.SetSize(nbFeatures, false); + stddevMeasurementVector.SetSize(nbFeatures, false); + outRescaler = OutputRescalerType::New(); + outRescaler->SetInput(classificationFilter->GetOutput()); + outRescaler->GetFunctor().SetA(outStdDev); + outRescaler->GetFunctor().SetB(outMean); + outputImage = outRescaler->GetOutput(); + } else if (meanMeasurementVector.Size() != nbFeatures) - { - otbAppLogFATAL("Wrong number of components in statistics file : "<<meanMeasurementVector.Size()); - } - + { + otbAppLogFATAL("Wrong number of components in statistics file : " << meanMeasurementVector.Size()); + } + // Rescale vector image - m_Rescaler->SetScale(stddevMeasurementVector); - m_Rescaler->SetShift(meanMeasurementVector); - m_Rescaler->SetInput(inImage); + rescaler->SetScale(stddevMeasurementVector); + rescaler->SetShift(meanMeasurementVector); + rescaler->SetInput(inImage); - m_ClassificationFilter->SetInput(m_Rescaler->GetOutput()); - } + classificationFilter->SetInput(rescaler->GetOutput()); + } else - { + { otbAppLogINFO("Input image normalization deactivated."); - m_ClassificationFilter->SetInput(inImage); - } - + classificationFilter->SetInput(inImage); + } - if(IsParameterEnabled("mask")) - { + if (IsParameterEnabled("mask")) + { otbAppLogINFO("Using input mask"); // Load mask image and cast into LabeledImageType - MaskImageType::Pointer inMask = GetParameterUInt8Image("mask"); - - m_ClassificationFilter->SetInputMask(inMask); - } + auto inMask = GetParameterUInt8Image("mask"); + classificationFilter->SetInputMask(inMask); + } SetParameterOutputImage<FloatImageType>("out", outputImage); - + RegisterPipeline(); } - - ClassificationFilterType::Pointer m_ClassificationFilter; - ModelPointerType m_Model; - RescalerType::Pointer m_Rescaler; - OutputRescalerType::Pointer m_OutRescaler; - }; - - } }