Skip to content
Snippets Groups Projects
Commit 712cc8d9 authored by Grégoire Mercier's avatar Grégoire Mercier
Browse files

BUG: MNF with dimension reduction

parent ccbdb0c9
Branches
Tags
No related merge requests found
......@@ -23,6 +23,7 @@
#include "otbStreamingStatisticsVectorImageFilter2.h"
#include "otbMatrixMultiplyImageFilter.h"
#include "otbNormalizeVectorImageFilter.h"
#include "otbPCAImageFilter.h"
namespace otb {
......@@ -79,14 +80,16 @@ public:
typedef typename CovarianceEstimatorFilterType::RealPixelType VectorType;
typedef typename CovarianceEstimatorFilterType::MatrixObjectType MatrixObjectType;
typedef typename MatrixObjectType::ComponentType MatrixType;
typedef typename MatrixType::InternalMatrixType InternalMatrixType;
typedef typename InternalMatrixType::element_type MatrixElementType;
typedef MatrixMultiplyImageFilter< TInputImage, TOutputImage, RealType > TransformFilterType;
typedef MatrixMultiplyImageFilter< InputImageType, OutputImageType, RealType > TransformFilterType;
typedef typename TransformFilterType::Pointer TransformFilterPointerType;
typedef TNoiseImageFilter NoiseImageFilterType;
typedef typename NoiseImageFilterType::Pointer NoiseImageFilterPointerType;
typedef NormalizeVectorImageFilter< TInputImage, TOutputImage > NormalizeFilterType;
typedef NormalizeVectorImageFilter< InputImageType, OutputImageType > NormalizeFilterType;
typedef typename NormalizeFilterType::Pointer NormalizeFilterPointerType;
/**
......@@ -101,6 +104,7 @@ public:
itkGetMacro(Transformer, TransformFilterType *);
itkGetMacro(NoiseImageFilter, NoiseImageFilterType *);
/** Normalization only impact the use of variance. The data is always centered */
itkGetMacro(UseNormalization,bool);
itkSetMacro(UseNormalization,bool);
......@@ -108,7 +112,6 @@ public:
void SetMeanValues ( const VectorType & vec )
{
m_MeanValues = vec;
m_UseNormalization = true;
m_GivenMeanValues = true;
}
......@@ -169,7 +172,8 @@ protected:
virtual void ForwardGenerateData();
virtual void ReverseGenerateData();
void GetTransformationMatrixFromCovarianceMatrix();
/** Specific functionality of MNF */
virtual void GetTransformationMatrixFromCovarianceMatrix();
/** Internal attributes */
unsigned int m_NumberOfPrincipalComponentsRequired;
......
......@@ -140,28 +140,24 @@ MNFImageFilter< TInputImage, TOutputImage, TNoiseImageFilter, TDirectionOfTransf
typename InputImageType::Pointer inputImgPtr
= const_cast<InputImageType*>( this->GetInput() );
if ( m_GivenMeanValues )
m_Normalizer->SetMean( this->GetMeanValues() );
if ( m_UseNormalization )
{
if ( m_GivenMeanValues )
m_Normalizer->SetMean( this->GetMeanValues() );
if ( m_GivenStdDevValues )
m_Normalizer->SetStdDev( this->GetStdDevValues() );
m_Normalizer->SetInput( inputImgPtr );
std::cerr << m_Normalizer << "\n";
}
else
m_Normalizer->SetUseStdDev( false );
m_Normalizer->SetInput( inputImgPtr );
if ( !m_GivenTransformationMatrix )
{
if ( !m_GivenNoiseCovarianceMatrix )
{
if ( m_UseNormalization )
m_NoiseImageFilter->SetInput( m_Normalizer->GetOutput() );
else
m_NoiseImageFilter->SetInput( inputImgPtr );
m_NoiseImageFilter->SetInput( m_Normalizer->GetOutput() );
m_NoiseCovarianceEstimator->SetInput( m_NoiseImageFilter->GetOutput() );
m_NoiseCovarianceEstimator->Update();
......@@ -170,10 +166,7 @@ MNFImageFilter< TInputImage, TOutputImage, TNoiseImageFilter, TDirectionOfTransf
if ( !m_GivenCovarianceMatrix )
{
if ( m_UseNormalization )
m_CovarianceEstimator->SetInput( m_Normalizer->GetOutput() );
else
m_CovarianceEstimator->SetInput( inputImgPtr );
m_CovarianceEstimator->SetInput( m_Normalizer->GetOutput() );
m_CovarianceEstimator->Update();
m_CovarianceMatrix = m_CovarianceEstimator->GetCovariance();
......@@ -185,7 +178,15 @@ MNFImageFilter< TInputImage, TOutputImage, TNoiseImageFilter, TDirectionOfTransf
{
// Prevents from multiple transpose in pipeline
m_IsTransformationMatrixForward = true;
m_TransformationMatrix = m_TransformationMatrix.GetTranspose();
if ( m_TransformationMatrix.Rows() == m_TransformationMatrix.Cols() )
{
m_TransformationMatrix = m_TransformationMatrix.GetInverse();
}
else
{
vnl_svd< MatrixElementType > invertor ( m_TransformationMatrix.GetVnlMatrix() );
m_TransformationMatrix = invertor.pinverse();
}
}
if ( m_TransformationMatrix.GetVnlMatrix().empty() )
......@@ -195,24 +196,22 @@ MNFImageFilter< TInputImage, TOutputImage, TNoiseImageFilter, TDirectionOfTransf
ITK_LOCATION);
}
if ( m_UseNormalization )
{
m_Transformer->SetInput( m_Normalizer->GetOutput() );
if ( !m_GivenMeanValues )
m_MeanValues = m_Normalizer->GetFunctor().GetMean();
if ( !m_GivenStdDevValues )
m_StdDevValues = m_Normalizer->GetFunctor().GetStdDev();
}
else
m_Transformer->SetInput( inputImgPtr );
m_Transformer->SetInput( m_Normalizer->GetOutput() );
m_Transformer->SetMatrix( m_TransformationMatrix.GetVnlMatrix() );
m_Transformer->GraftOutput( this->GetOutput() );
m_Transformer->Update();
this->GraftOutput( m_Transformer->GetOutput() );
/** Once the Normalizer has been updated */
if ( !m_GivenMeanValues )
m_MeanValues = m_Normalizer->GetFunctor().GetMean();
if ( m_UseNormalization )
{
if ( !m_GivenStdDevValues )
m_StdDevValues = m_Normalizer->GetFunctor().GetStdDev();
}
}
template <class TInputImage, class TOutputImage,
......@@ -241,13 +240,29 @@ MNFImageFilter< TInputImage, TOutputImage, TNoiseImageFilter, TDirectionOfTransf
GetTransformationMatrixFromCovarianceMatrix();
m_IsTransformationMatrixForward = false;
m_TransformationMatrix = m_TransformationMatrix.GetTranspose();
if ( m_TransformationMatrix.Rows() == m_TransformationMatrix.Cols() )
m_TransformationMatrix = vnl_matrix_inverse< MatrixElementType >
( m_TransformationMatrix.GetTranspose() );
else
{
vnl_svd< MatrixElementType > invertor ( m_TransformationMatrix.GetVnlMatrix() );
m_TransformationMatrix = invertor.inverse();
}
}
else if ( m_IsTransformationMatrixForward )
{
// Prevent from multiple transpose in pipeline
m_IsTransformationMatrixForward = false;
m_TransformationMatrix = m_TransformationMatrix.GetTranspose();
if ( m_TransformationMatrix.Rows() == m_TransformationMatrix.Cols() )
{
m_TransformationMatrix = vnl_matrix_inverse< MatrixElementType >
( m_TransformationMatrix.GetTranspose() );
}
else
{
vnl_svd< MatrixElementType > invertor ( m_TransformationMatrix.GetVnlMatrix() );
m_TransformationMatrix = invertor.pinverse();
}
}
if ( m_TransformationMatrix.GetVnlMatrix().empty() )
......@@ -260,12 +275,18 @@ MNFImageFilter< TInputImage, TOutputImage, TNoiseImageFilter, TDirectionOfTransf
m_Transformer->SetInput( this->GetInput() );
m_Transformer->SetMatrix( m_TransformationMatrix.GetVnlMatrix() );
if ( !m_GivenMeanValues )
{
throw itk::ExceptionObject( __FILE__, __LINE__,
"Initial means required for correct data centering",
ITK_LOCATION );
}
if ( m_UseNormalization )
{
if ( !m_GivenMeanValues || !m_GivenStdDevValues )
if ( !m_GivenStdDevValues )
{
throw itk::ExceptionObject( __FILE__, __LINE__,
"Initial means and StdDevs required for de-normalization",
"Initial StdDevs required for de-normalization",
ITK_LOCATION );
}
......@@ -278,22 +299,21 @@ MNFImageFilter< TInputImage, TOutputImage, TNoiseImageFilter, TDirectionOfTransf
for ( unsigned int i = 0; i < m_MeanValues.Size(); i++ )
revMean[i] = - m_MeanValues[i] / m_StdDevValues[i];
m_Normalizer->SetMean( revMean );
m_Normalizer->SetInput( m_Transformer->GetOutput() );
m_Normalizer->GraftOutput( this->GetOutput() );
m_Normalizer->Update();
std::cerr << m_Normalizer << "\n";
this->GraftOutput( m_Normalizer->GetOutput() );
}
else
{
m_Transformer->GraftOutput( this->GetOutput() );
m_Transformer->Update();
this->GraftOutput( m_Transformer->GetOutput() );
VectorType revMean ( m_MeanValues.Size() );
for ( unsigned int i = 0; i < m_MeanValues.Size(); i++ )
revMean[i] = - m_MeanValues[i] ;
m_Normalizer->SetMean( revMean );
m_Normalizer->SetUseStdDev( false );
}
m_Normalizer->SetInput( m_Transformer->GetOutput() );
m_Normalizer->GraftOutput( this->GetOutput() );
m_Normalizer->Update();
this->GraftOutput( m_Normalizer->GetOutput() );
}
template <class TInputImage, class TOutputImage,
......@@ -306,18 +326,16 @@ MNFImageFilter< TInputImage, TOutputImage, TNoiseImageFilter, TDirectionOfTransf
MatrixType Id ( m_NoiseCovarianceMatrix );
Id.SetIdentity();
typename MatrixType::InternalMatrixType Ax_inv =
vnl_matrix_inverse< typename MatrixType::InternalMatrixType::element_type>
( m_CovarianceMatrix.GetVnlMatrix() );
typename MatrixType::InternalMatrixType An = m_NoiseCovarianceMatrix.GetVnlMatrix();
typename MatrixType::InternalMatrixType W = An * Ax_inv;
typename MatrixType::InternalMatrixType I = Id.GetVnlMatrix();
InternalMatrixType Ax_inv = vnl_matrix_inverse< MatrixElementType > ( m_CovarianceMatrix.GetVnlMatrix() );
InternalMatrixType An = m_NoiseCovarianceMatrix.GetVnlMatrix();
InternalMatrixType W = An * Ax_inv;
InternalMatrixType I = Id.GetVnlMatrix();
vnl_generalized_eigensystem solver ( W, I );
typename MatrixType::InternalMatrixType transf = solver.V;
typename MatrixType::InternalMatrixType normMat
= transf.transpose() * m_CovarianceMatrix.GetVnlMatrix() * transf;
InternalMatrixType transf = solver.V;
InternalMatrixType normMat //= transf.transpose() * An * transf;
= transf.transpose() * Ax_inv * transf;
for ( unsigned int i = 0; i < transf.rows(); i++ )
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment