Skip to content
Snippets Groups Projects
Commit 2a2b6397 authored by Grégoire Mercier's avatar Grégoire Mercier
Browse files

ADD: FastICA

parent e2b57423
No related branches found
No related tags found
No related merge requests found
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbFastICAImageFilter_h
#define __otbFastICAImageFilter_h
#include "otbMacro.h"
#include "itkImageToImageFilter.h"
#include "otbPCAImageFilter.h"
#include "otbMatrixMultiplyImageFilter.h"
namespace otb
{
/** \class FastICAImageFilter
* \brief Performs a Independent Component Analysis
*
* The internal structure of this filter is a filter-to-filter like structure.
* The estimation of the covariance matrix has persistent capabilities...
*
* \sa PCAImageFilter
*/
template <class TInputImage, class TOutputImage, Transform::TransformDirection TDirectionOfTransformation >
class ITK_EXPORT FastICAImageFilter
: public itk::ImageToImageFilter<TInputImage, TOutputImage>
{
public:
/** Standard typedefs */
typedef FastICAImageFilter Self;
typedef itk::ImageToImageFilter<TInputImage, TOutputImage> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Type macro */
itkNewMacro(Self);
/** Creation through object factory macro */
itkTypeMacro(FastICAImageFilter, ImageToImageFilter);
/** Dimension */
itkStaticConstMacro(InputImageDimension, unsigned int, TInputImage::ImageDimension);
itkStaticConstMacro(OutputImageDimension, unsigned int, TOutputImage::ImageDimension);
typedef Transform::TransformDirection TransformDirectionEnumType;
itkStaticConstMacro(DirectionOfTransformation,TransformDirectionEnumType,TDirectionOfTransformation);
/** typedefs */
typedef TInputImage InputImageType;
typedef TOutputImage OutputImageType;
typedef PCAImageFilter< InputImageType, OutputImageType, TDirectionOfTransformation >
PCAFilterType;
typedef typename PCAFilterType::Pointer PCAFilterPointerType;
typedef typename PCAFilterType::RealType RealType;
typedef typename PCAFilterType::VectorType VectorType;
typedef typename PCAFilterType::MatrixType MatrixType;
typedef typename MatrixType::InternalMatrixType InternalMatrixType;
typedef typename InternalMatrixType::element_type MatrixElementType;
typedef MatrixMultiplyImageFilter< TInputImage, TOutputImage, RealType > TransformFilterType;
typedef typename TransformFilterType::Pointer TransformFilterPointerType;
typedef double (*ContrastFunctionType) ( double );
/**
* Set/Get the number of required largest principal components.
*/
itkGetMacro(NumberOfPrincipalComponentsRequired,unsigned int);
itkSetMacro(NumberOfPrincipalComponentsRequired,unsigned int);
itkGetConstMacro(PCAFilter,PCAFilterType *);
itkGetMacro(PCAFilter,PCAFilterType *);
itkSetMacro(PCAFilter,PCAFilterType *);
itkGetConstMacro(TransformFilter,TransformFilterType *);
itkGetMacro(TransformFilter,TransformFilterType *);
itkSetMacro(TransformFilter,TransformFilterType *);
VectorType GetMeanValues () const
{
return this->GetPCAFilter()->GetMeanValues();
}
void SetMeanValues ( const VectorType & vec )
{
m_PCAFilter->SetMeanValues(vec);
}
VectorType GetStdDevValues ( ) const
{
return this->GetPCAFilter()->GetStdDevValues();
}
void SetStdDevValues ( const VectorType & vec )
{
m_PCAFilter->SetStdDevValues(vec);
}
MatrixType GetPCATransformationMatrix () const
{
return this->GetPCAFilter()->GetTransformationMatrix();
}
void SetPCACovarianceMatrix ( const MatrixType & mat, bool isForward = true )
{
m_PCAFilter->SetTransformationMatrix(mat,isForward);
}
itkGetConstMacro(TransformationMatrix,MatrixType);
itkGetMacro(TransformationMatrix,MatrixType);
void SetTransformationMatrix ( const MatrixType & mat, bool isForward = true )
{
m_IsTransformationForward = isForward;
m_GivenTransformationMatrix = true;
m_TransformationMatrix = mat;
this->Modified();
}
itkGetMacro(MaximumOfIterations,unsigned int);
itkSetMacro(MaximumOfIterations,unsigned int);
itkGetMacro(ConvergenceThreshold,double);
itkSetMacro(ConvergenceThreshold,double);
itkGetMacro(ContrastFunction,ContrastFunctionType);
protected:
FastICAImageFilter ();
virtual ~FastICAImageFilter() { }
/** GenerateOutputInformation
* Propagate vector length info and modify if needed
* NumberOfPrincipalComponentsRequired
*
* In REVERSE mode, the covariance matrix or the transformation matrix
* (which may not be square) has to be given,
* otherwize, GenerateOutputInformation throws an itk::ExceptionObject
*/
virtual void GenerateOutputInformation();
/** GenerateData
* Through a filter of filter structure
*/
virtual void GenerateData ();
void PrintSelf(std::ostream& os, itk::Indent indent) const;
/** Internal methods */
virtual void ForwardGenerateData();
virtual void ReverseGenerateData();
/** this is the specifical part of FastICA */
virtual void GenerateTransformationMatrix();
unsigned int m_NumberOfPrincipalComponentsRequired;
/** Transformation matrix refers to the ICA step (not PCA) */
bool m_GivenTransformationMatrix;
bool m_IsTransformationForward;
MatrixType m_TransformationMatrix;
/** FastICA parameters */
unsigned int m_MaximumOfIterations; // def is 50
double m_ConvergenceThreshold; // def is 1e-4
ContrastFunctionType m_ConstrastFunction; // see g() function in the biblio. Def is tanh
PCAFilterPointerType m_PCAFilter;
TransformFilterPointerType m_TransformFilter;
private:
FastICAImageFilter ( const Self & );
void operator= ( const Self & );
}; // end of class
} // end of namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbFastICAImageFilter.txx"
#endif
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbFastICAImageFilter_txx
#define __otbFastICAImageFilter_txx
#include "otbFastICAImageFilter.h"
#include "otbMacro.h"
#include "itkExceptionObject.h"
#include "itkNumericTraits.h"
#include <vnl/vnl_matrix.h>
#include <vnl/algo/vnl_matrix_inverse.h>
#include <vnl/algo/vnl_generalized_eigensystem.h>
namespace otb
{
template < class TInputImage, class TOutputImage,
Transform::TransformDirection TDirectionOfTransformation >
FastICAImageFilter< TInputImage, TOutputImage, TDirectionOfTransformation >
::FastICAImageFilter ()
{
this->SetNumberOfRequiredInputs(1);
m_NumberOfPrincipalComponentsRequired = 0;
m_GivenTransformationMatrix = false;
m_IsTransformationForward = true;
m_MaximumOfIterations = 50;
m_ConvergenceThreshold = 1E-4;
m_ConstrastFunction = &vcl_tanh;
m_PCAFilter = PCAFilterType::New();
m_PCAFilter->UseNormalization();
m_TransformFilter = TransformFilterType::New();
}
template < class TInputImage, class TOutputImage,
Transform::TransformDirection TDirectionOfTransformation >
void
FastICAImageFilter< TInputImage, TOutputImage, TDirectionOfTransformation >
::GenerateOutputInformation()
// throw itk::ExceptionObject
{
Superclass::GenerateOutputInformation();
switch ( DirectionOfTransformation )
{
case Transform::FORWARD:
{
if ( m_NumberOfPrincipalComponentsRequired == 0
|| m_NumberOfPrincipalComponentsRequired
> this->GetInput()->GetNumberOfComponentsPerPixel() )
{
m_NumberOfPrincipalComponentsRequired =
this->GetInput()->GetNumberOfComponentsPerPixel();
}
this->GetOutput()->SetNumberOfComponentsPerPixel(
m_NumberOfPrincipalComponentsRequired );
break;
}
case Transform::INVERSE:
{
unsigned int theOutputDimension = 0;
if ( m_GivenTransformationMatrix )
{
theOutputDimension = m_TransformationMatrix.Rows() >= m_TransformationMatrix.Cols() ?
m_TransformationMatrix.Rows() : m_TransformationMatrix.Cols();
}
else if ( m_GivenCovarianceMatrix )
{
theOutputDimension = m_CovarianceMatrix.Rows() >= m_CovarianceMatrix.Cols() ?
m_CovarianceMatrix.Rows() : m_CovarianceMatrix.Cols();
}
else
{
throw itk::ExceptionObject(__FILE__, __LINE__,
"Mixture matrix is required to know the output size",
ITK_LOCATION);
}
this->GetOutput()->SetNumberOfComponentsPerPixel( theOutputDimension );
break;
}
default:
throw itk::ExceptionObject(__FILE__, __LINE__,
"Class should be templeted with FORWARD or INVERSE only...",
ITK_LOCATION );
}
}
template < class TInputImage, class TOutputImage,
Transform::TransformDirection TDirectionOfTransformation >
void
FastICAImageFilter< TInputImage, TOutputImage, TDirectionOfTransformation >
::GenerateData ()
{
switch ( DirectionOfTransformation )
{
case Transform::FORWARD:
return ForwardGenerateData();
case Transform::INVERSE:
return ReverseGenerateData();
default:
throw itk::ExceptionObject(__FILE__, __LINE__,
"Class should be templated with FORWARD or INVERSE only...",
ITK_LOCATION );
}
}
template < class TInputImage, class TOutputImage,
Transform::TransformDirection TDirectionOfTransformation >
void
FastICAImageFilter< TInputImage, TOutputImage, TDirectionOfTransformation >
::ForwardGenerateData ()
{
typename InputImageType::Pointer inputImgPtr
= const_cast<InputImageType*>( this->GetInput() );
m_PCAFilter->SetInput( inputImgPtr );
m_PCAFilter->Update();
if ( !m_GivenTransformationMatrix )
{
GenerateTransformationMatrix();
}
else if ( !m_IsTransformationForward )
{
// prevent from multiple inversion in the pipelines
m_IsTransformationForward = true;
vnl_svd< MatrixElementType > invertor ( m_TransformationMatrix.GetVnlMatrix() );
m_TransformationMatrix = invertor.pinverse();
}
if ( m_TransformationMatrix.GetVnlMatrix().empty() )
{
throw itk::ExceptionObject( __FILE__, __LINE__,
"Empty transformation matrix",
ITK_LOCATION);
}
m_TransformFilter->SetInput( m_PCAFilter->GetOutput() );
m_TransformFilter->SetMatrix( m_TransformationMatrix.GetVnlMatrix() );
m_TransformFilter->GraftOutput( this->GetOutput() );
m_TransformFilter->Update();
this->GraftOutput( m_TransformFilter->GetOutput() );
}
template < class TInputImage, class TOutputImage,
Transform::TransformDirection TDirectionOfTransformation >
void
FastICAImageFilter< TInputImage, TOutputImage, TDirectionOfTransformation >
::ReverseGenerateData ()
{
if ( !m_GivenTransformationMatrix )
{
throw itk::ExceptionObject( __FILE__, __LINE__,
"No Transformation matrix given",
ITK_LOCATION );
}
if ( m_TransformationMatrix.GetVnlMatrix().empty() )
{
throw itk::ExceptionObject( __FILE__, __LINE__,
"Empty transformation matrix",
ITK_LOCATION);
}
if ( m_IsTransformationForward )
{
// prevent from multiple inversion in the pipelines
m_IsTransformationForward = false;
vnl_svd< MatrixElementType > invertor ( m_TransformationMatrix.GetVnlMatrix() );
m_TransformationMatrix = invertor.pinverse();
}
m_TransformFilter->SetInput( this->GetInput() );
m_TransformFilter->SetMatrix( m_TransformationMatrix.GetVnlMatrix() );
/*
* PCA filter may throw exception if
* the mean, stdDev and transformation matrix
* have not been given at this point
*/
m_PCAFilter->SetInput( m_TransformFilter->GetOutput() );
m_PCAFilter->GraftOutput( this->GetOutput() );
m_PCAFilter->Update();
this->GraftOutput( m_PCAFilter->GetOutput() );
}
template < class TInputImage, class TOutputImage,
Transform::TransformDirection TDirectionOfTransformation >
void
FastICAImageFilter< TInputImage, TOutputImage, TDirectionOfTransformation >
::GenerateTransformationMatrix ()
{
double convergence = itk::NumericTraits<double>::Max();
unsigned int iteration = 0;
unsigned int size = this->GetInput()->GetNumberOfComponentsPerPixel();
// transformation matrix
InternalMatrixType W ( size, size, vnl_matrix_identity );
while ( iteration++ < GetMaximumOfIterations()
&& convergence > GetConvergenceThreshold() )
{
otbMsgDebugMacro( "Iteration " << iteration << " / " << GetMaximumOfIterations()
<< ", MSE = " << convergence );
InternalMatrixType W_old ( W );
// TODO le premier coup ne sert a rien
TransformFilterPointerType transformer = TransformFilterType::New();
transformer->SetIntput( GetPCAFilter()->GetOutput() );
transformer->SetMatrix( W );
transformer->Update();
// Faire un image to image filter...
} // end of while loop
}
} // end of namespace otb
#endif
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment