diff --git a/Code/Hyperspectral/otbMNFImageFilter.h b/Code/Hyperspectral/otbMNFImageFilter.h index 93b1a56c8e12b3b23c606000f72360d9bc85da1e..7146f074a73a770cd46b0e9bd0e65679942d65be 100644 --- a/Code/Hyperspectral/otbMNFImageFilter.h +++ b/Code/Hyperspectral/otbMNFImageFilter.h @@ -23,6 +23,7 @@ #include "otbStreamingStatisticsVectorImageFilter2.h" #include "otbMatrixMultiplyImageFilter.h" #include "otbNormalizeVectorImageFilter.h" +#include "otbPCAImageFilter.h" namespace otb { @@ -79,14 +80,16 @@ public: typedef typename CovarianceEstimatorFilterType::RealPixelType VectorType; typedef typename CovarianceEstimatorFilterType::MatrixObjectType MatrixObjectType; typedef typename MatrixObjectType::ComponentType MatrixType; + typedef typename MatrixType::InternalMatrixType InternalMatrixType; + typedef typename InternalMatrixType::element_type MatrixElementType; - typedef MatrixMultiplyImageFilter< TInputImage, TOutputImage, RealType > TransformFilterType; + typedef MatrixMultiplyImageFilter< InputImageType, OutputImageType, RealType > TransformFilterType; typedef typename TransformFilterType::Pointer TransformFilterPointerType; typedef TNoiseImageFilter NoiseImageFilterType; typedef typename NoiseImageFilterType::Pointer NoiseImageFilterPointerType; - typedef NormalizeVectorImageFilter< TInputImage, TOutputImage > NormalizeFilterType; + typedef NormalizeVectorImageFilter< InputImageType, OutputImageType > NormalizeFilterType; typedef typename NormalizeFilterType::Pointer NormalizeFilterPointerType; /** @@ -101,6 +104,7 @@ public: itkGetMacro(Transformer, TransformFilterType *); itkGetMacro(NoiseImageFilter, NoiseImageFilterType *); + /** Normalization only impact the use of variance. The data is always centered */ itkGetMacro(UseNormalization,bool); itkSetMacro(UseNormalization,bool); @@ -108,7 +112,6 @@ public: void SetMeanValues ( const VectorType & vec ) { m_MeanValues = vec; - m_UseNormalization = true; m_GivenMeanValues = true; } @@ -169,7 +172,8 @@ protected: virtual void ForwardGenerateData(); virtual void ReverseGenerateData(); - void GetTransformationMatrixFromCovarianceMatrix(); + /** Specific functionality of MNF */ + virtual void GetTransformationMatrixFromCovarianceMatrix(); /** Internal attributes */ unsigned int m_NumberOfPrincipalComponentsRequired; diff --git a/Code/Hyperspectral/otbMNFImageFilter.txx b/Code/Hyperspectral/otbMNFImageFilter.txx index c20d8b052f5c7ddbc5e33b4cf280b07a7ca10151..019d2279fa2bf2a0a42042ae36ee5ab4c9f9cafa 100644 --- a/Code/Hyperspectral/otbMNFImageFilter.txx +++ b/Code/Hyperspectral/otbMNFImageFilter.txx @@ -140,28 +140,24 @@ MNFImageFilter< TInputImage, TOutputImage, TNoiseImageFilter, TDirectionOfTransf typename InputImageType::Pointer inputImgPtr = const_cast<InputImageType*>( this->GetInput() ); + if ( m_GivenMeanValues ) + m_Normalizer->SetMean( this->GetMeanValues() ); + if ( m_UseNormalization ) { - if ( m_GivenMeanValues ) - m_Normalizer->SetMean( this->GetMeanValues() ); - if ( m_GivenStdDevValues ) m_Normalizer->SetStdDev( this->GetStdDevValues() ); - - m_Normalizer->SetInput( inputImgPtr ); - - std::cerr << m_Normalizer << "\n"; } + else + m_Normalizer->SetUseStdDev( false ); + + m_Normalizer->SetInput( inputImgPtr ); if ( !m_GivenTransformationMatrix ) { if ( !m_GivenNoiseCovarianceMatrix ) { - if ( m_UseNormalization ) - m_NoiseImageFilter->SetInput( m_Normalizer->GetOutput() ); - else - m_NoiseImageFilter->SetInput( inputImgPtr ); - + m_NoiseImageFilter->SetInput( m_Normalizer->GetOutput() ); m_NoiseCovarianceEstimator->SetInput( m_NoiseImageFilter->GetOutput() ); m_NoiseCovarianceEstimator->Update(); @@ -170,10 +166,7 @@ MNFImageFilter< TInputImage, TOutputImage, TNoiseImageFilter, TDirectionOfTransf if ( !m_GivenCovarianceMatrix ) { - if ( m_UseNormalization ) - m_CovarianceEstimator->SetInput( m_Normalizer->GetOutput() ); - else - m_CovarianceEstimator->SetInput( inputImgPtr ); + m_CovarianceEstimator->SetInput( m_Normalizer->GetOutput() ); m_CovarianceEstimator->Update(); m_CovarianceMatrix = m_CovarianceEstimator->GetCovariance(); @@ -185,7 +178,15 @@ MNFImageFilter< TInputImage, TOutputImage, TNoiseImageFilter, TDirectionOfTransf { // Prevents from multiple transpose in pipeline m_IsTransformationMatrixForward = true; - m_TransformationMatrix = m_TransformationMatrix.GetTranspose(); + if ( m_TransformationMatrix.Rows() == m_TransformationMatrix.Cols() ) + { + m_TransformationMatrix = m_TransformationMatrix.GetInverse(); + } + else + { + vnl_svd< MatrixElementType > invertor ( m_TransformationMatrix.GetVnlMatrix() ); + m_TransformationMatrix = invertor.pinverse(); + } } if ( m_TransformationMatrix.GetVnlMatrix().empty() ) @@ -195,24 +196,22 @@ MNFImageFilter< TInputImage, TOutputImage, TNoiseImageFilter, TDirectionOfTransf ITK_LOCATION); } - if ( m_UseNormalization ) - { - m_Transformer->SetInput( m_Normalizer->GetOutput() ); - - if ( !m_GivenMeanValues ) - m_MeanValues = m_Normalizer->GetFunctor().GetMean(); - - if ( !m_GivenStdDevValues ) - m_StdDevValues = m_Normalizer->GetFunctor().GetStdDev(); - } - else - m_Transformer->SetInput( inputImgPtr ); - + m_Transformer->SetInput( m_Normalizer->GetOutput() ); m_Transformer->SetMatrix( m_TransformationMatrix.GetVnlMatrix() ); m_Transformer->GraftOutput( this->GetOutput() ); m_Transformer->Update(); this->GraftOutput( m_Transformer->GetOutput() ); + + /** Once the Normalizer has been updated */ + if ( !m_GivenMeanValues ) + m_MeanValues = m_Normalizer->GetFunctor().GetMean(); + + if ( m_UseNormalization ) + { + if ( !m_GivenStdDevValues ) + m_StdDevValues = m_Normalizer->GetFunctor().GetStdDev(); + } } template <class TInputImage, class TOutputImage, @@ -241,13 +240,29 @@ MNFImageFilter< TInputImage, TOutputImage, TNoiseImageFilter, TDirectionOfTransf GetTransformationMatrixFromCovarianceMatrix(); m_IsTransformationMatrixForward = false; - m_TransformationMatrix = m_TransformationMatrix.GetTranspose(); + if ( m_TransformationMatrix.Rows() == m_TransformationMatrix.Cols() ) + m_TransformationMatrix = vnl_matrix_inverse< MatrixElementType > + ( m_TransformationMatrix.GetTranspose() ); + else + { + vnl_svd< MatrixElementType > invertor ( m_TransformationMatrix.GetVnlMatrix() ); + m_TransformationMatrix = invertor.inverse(); + } } else if ( m_IsTransformationMatrixForward ) { // Prevent from multiple transpose in pipeline m_IsTransformationMatrixForward = false; - m_TransformationMatrix = m_TransformationMatrix.GetTranspose(); + if ( m_TransformationMatrix.Rows() == m_TransformationMatrix.Cols() ) + { + m_TransformationMatrix = vnl_matrix_inverse< MatrixElementType > + ( m_TransformationMatrix.GetTranspose() ); + } + else + { + vnl_svd< MatrixElementType > invertor ( m_TransformationMatrix.GetVnlMatrix() ); + m_TransformationMatrix = invertor.pinverse(); + } } if ( m_TransformationMatrix.GetVnlMatrix().empty() ) @@ -260,12 +275,18 @@ MNFImageFilter< TInputImage, TOutputImage, TNoiseImageFilter, TDirectionOfTransf m_Transformer->SetInput( this->GetInput() ); m_Transformer->SetMatrix( m_TransformationMatrix.GetVnlMatrix() ); + if ( !m_GivenMeanValues ) + { + throw itk::ExceptionObject( __FILE__, __LINE__, + "Initial means required for correct data centering", + ITK_LOCATION ); + } if ( m_UseNormalization ) { - if ( !m_GivenMeanValues || !m_GivenStdDevValues ) + if ( !m_GivenStdDevValues ) { throw itk::ExceptionObject( __FILE__, __LINE__, - "Initial means and StdDevs required for de-normalization", + "Initial StdDevs required for de-normalization", ITK_LOCATION ); } @@ -278,22 +299,21 @@ MNFImageFilter< TInputImage, TOutputImage, TNoiseImageFilter, TDirectionOfTransf for ( unsigned int i = 0; i < m_MeanValues.Size(); i++ ) revMean[i] = - m_MeanValues[i] / m_StdDevValues[i]; m_Normalizer->SetMean( revMean ); - - m_Normalizer->SetInput( m_Transformer->GetOutput() ); - m_Normalizer->GraftOutput( this->GetOutput() ); - m_Normalizer->Update(); - - std::cerr << m_Normalizer << "\n"; - - this->GraftOutput( m_Normalizer->GetOutput() ); } else { - m_Transformer->GraftOutput( this->GetOutput() ); - m_Transformer->Update(); - - this->GraftOutput( m_Transformer->GetOutput() ); + VectorType revMean ( m_MeanValues.Size() ); + for ( unsigned int i = 0; i < m_MeanValues.Size(); i++ ) + revMean[i] = - m_MeanValues[i] ; + m_Normalizer->SetMean( revMean ); + m_Normalizer->SetUseStdDev( false ); } + + m_Normalizer->SetInput( m_Transformer->GetOutput() ); + m_Normalizer->GraftOutput( this->GetOutput() ); + m_Normalizer->Update(); + + this->GraftOutput( m_Normalizer->GetOutput() ); } template <class TInputImage, class TOutputImage, @@ -306,18 +326,16 @@ MNFImageFilter< TInputImage, TOutputImage, TNoiseImageFilter, TDirectionOfTransf MatrixType Id ( m_NoiseCovarianceMatrix ); Id.SetIdentity(); - typename MatrixType::InternalMatrixType Ax_inv = - vnl_matrix_inverse< typename MatrixType::InternalMatrixType::element_type> - ( m_CovarianceMatrix.GetVnlMatrix() ); - typename MatrixType::InternalMatrixType An = m_NoiseCovarianceMatrix.GetVnlMatrix(); - typename MatrixType::InternalMatrixType W = An * Ax_inv; - typename MatrixType::InternalMatrixType I = Id.GetVnlMatrix(); + InternalMatrixType Ax_inv = vnl_matrix_inverse< MatrixElementType > ( m_CovarianceMatrix.GetVnlMatrix() ); + InternalMatrixType An = m_NoiseCovarianceMatrix.GetVnlMatrix(); + InternalMatrixType W = An * Ax_inv; + InternalMatrixType I = Id.GetVnlMatrix(); vnl_generalized_eigensystem solver ( W, I ); - typename MatrixType::InternalMatrixType transf = solver.V; - typename MatrixType::InternalMatrixType normMat - = transf.transpose() * m_CovarianceMatrix.GetVnlMatrix() * transf; + InternalMatrixType transf = solver.V; + InternalMatrixType normMat //= transf.transpose() * An * transf; + = transf.transpose() * Ax_inv * transf; for ( unsigned int i = 0; i < transf.rows(); i++ ) {