Skip to content
Snippets Groups Projects
Commit c90dfee7 authored by Grégoire Mercier's avatar Grégoire Mercier
Browse files

ENH: Center data on PCA

parent 6c68fbe4
Branches
Tags
No related merge requests found
......@@ -64,6 +64,11 @@ public:
m_Mean[i] = static_cast< RealType >( m[i] );
}
RealVectorType GetMean() const
{
return this->m_Mean;
}
template < class T>
void SetStdDev ( const itk::VariableLengthVector<T> & sigma )
{
......@@ -94,6 +99,11 @@ public:
}
}
RealVectorType GetStdDev() const
{
return this->m_StdDev;
}
protected:
RealVectorType m_Mean;
......
......@@ -22,6 +22,7 @@
#include "itkImageToImageFilter.h"
#include "otbStreamingStatisticsVectorImageFilter2.h"
#include "otbMatrixMultiplyImageFilter.h"
#include "otbNormalizeVectorImageFilter.h"
namespace otb
......@@ -81,6 +82,8 @@ public:
typedef MatrixMultiplyImageFilter< TInputImage, TOutputImage, RealType > TransformFilterType;
typedef typename TransformFilterType::Pointer TransformFilterPointerType;
typedef NormalizeVetorImageFilter< TInputImage, TOutputImage > NormalizeFilterType;
typedef typename NormalizeFilterType::Pointer NormalizeFilterPointerType;
/**
* Set/Get the number of required largest principal components.
......@@ -120,7 +123,14 @@ public:
}
itkGetConstMacro(EigenValues,VectorType);
itkGetConstMacro(MeanValues,VectorType);
void SetMeanValues ( const VectorType & data )
{
m_GivenMeanValues = true;
m_MeanValues = data;
this->Modified();
}
protected:
PCAImageFilter();
......@@ -151,16 +161,19 @@ protected:
/** Internal attributes */
unsigned int m_NumberOfPrincipalComponentsRequired;
bool m_GivenMeanValues;
bool m_GivenCovarianceMatrix;
bool m_GivenTransformationMatrix;
bool m_IsTransformationMatrixForward;
VectorType m_MeanValues;
MatrixType m_CovarianceMatrix;
VectorType m_EigenValues;
MatrixType m_TransformationMatrix;
CovarianceEstimatorFilterPointerType m_CovarianceEstimator;
TransformFilterPointerType m_Transformer;
NormalizeFilterPointerType m_Normalizer;
};
} // end of namespace otb
......
......@@ -35,12 +35,14 @@ PCAImageFilter< TInputImage, TOutputImage, TDirectionOfTransformation >
this->SetNumberOfRequiredInputs(1);
m_NumberOfPrincipalComponentsRequired = 0;
m_GivenMeanValues = false;
m_GivenCovarianceMatrix = false;
m_GivenTransformationMatrix = false;
m_IsTransformationMatrixForward = true;
m_CovarianceEstimator = CovarianceEstimatorFilterType::New();
m_Transformer = TransformFilterType::New();
m_Normalizer = NormalizeFilterType::New();
}
template < class TInputImage, class TOutputImage,
......@@ -131,13 +133,24 @@ PCAImageFilter< TInputImage, TOutputImage, TDirectionOfTransformation >
{
if ( !m_GivenCovarianceMatrix )
{
otbGenericMsgDebugMacro(<< "Covariance estimation");
m_Normalizer->SetInput( inputImgPtr );
m_Normalizer->SetUseStdDev( false );
m_CovarianceEstimator->SetInput( inputImgPtr );
m_CovarianceEstimator->Update();
if ( m_GivenMeanValues )
m_Normalizer->SetMean( m_MeanValues );
m_Normalizer->Update();
if ( !m_GivenMeanValues )
m_MeanValues = m_Normalizer->GetFunctor().GetMean();
m_CovarianceMatrix = m_CovarianceEstimator->GetCovariance();
//m_CovarianceMatrix = m_CovarianceEstimator->GetCorrelation();
m_Transformer->SetInput( m_Normalizer->GetOutput() );
}
else
{
m_Transformer->SetInput( inputImgPtr );
}
GetTransformationMatrixFromCovarianceMatrix();
......@@ -145,6 +158,8 @@ PCAImageFilter< TInputImage, TOutputImage, TDirectionOfTransformation >
else if ( !m_IsTransformationMatrixForward )
{
m_TransformationMatrix = m_TransformationMatrix.GetTranspose();
m_Transformer->SetInput( inputImgPtr );
}
if ( m_TransformationMatrix.GetVnlMatrix().empty() )
......@@ -154,7 +169,6 @@ PCAImageFilter< TInputImage, TOutputImage, TDirectionOfTransformation >
ITK_LOCATION);
}
m_Transformer->SetInput( inputImgPtr );
m_Transformer->SetMatrix( m_TransformationMatrix.GetVnlMatrix() );
m_Transformer->GraftOutput( this->GetOutput() );
m_Transformer->Update();
......@@ -194,10 +208,27 @@ PCAImageFilter< TInputImage, TOutputImage, TDirectionOfTransformation >
m_Transformer->SetInput( this->GetInput() );
m_Transformer->SetMatrix( m_TransformationMatrix.GetVnlMatrix() );
m_Transformer->GraftOutput( this->GetOutput() );
m_Transformer->Update();
this->GraftOutput( m_Transformer->GetOutput() );
if ( m_GivenMeanValues )
{
VectorType revMean ( m_MeanValues.Size() );
for ( unsigned int i = 0; i < m_MeanValues.Size(); i++ )
revMean[i] = -m_MeanValues[i];
m_Normalizer->SetIntput( m_Transformer->GetOutput() );
m_Normalizer->SetMean( revMean );
m_Normalizer->SetUseStdDev( false );
m_Normalizer->GraftOutput( this->GetOutput() );
m_Normalizer->Update();
this->GraftOutput( m_Normalizer->GetOutput() );
}
else
{
m_Transformer->GraftOutput( this->GetOutput() );
m_Transformer->Update();
this->GraftOutput( m_Transformer->GetOutput() );
}
}
template < class TInputImage, class TOutputImage,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment