Skip to content
Snippets Groups Projects
Commit c61ea7fa authored by Grégoire Mercier's avatar Grégoire Mercier
Browse files

ENH: complete FastICA with threaded optimizer

parent 2a2b6397
Branches
Tags
No related merge requests found
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbFastICAInternalOptimizerVectorImageFilter_h
#define __otbFastICAInternalOptimizerVectorImageFilter_h
#include "otbMacro.h"
#include "itkImageToImageFilter.h"
#include "otbMatrixMultiplyImageFilter.h"
#include "otbStreamingStatisticsVectorImageFilter2.h"
namespace otb
{
/** \class FastICAInternalOptimizerVectorImageFilter
* \brief Internal optimisation of the FastICA unmixing filter
*
* This class implements the internal search for the unmixing matrix W
* in the FastICA technique.
*
* The class takes 2 inputs (initial image and its projection with the W matrix).
*
* \ingroup Multithreaded
* \sa FastICAImageFilter
*/
template <class TInputImage, class TOutputImage>
class ITK_EXPORT FastICAInternalOptimizerVectorImageFilter
: public itk::ImageToImageFilter<TInputImage, TOutputImage>
{
public:
/** Standard typedefs */
typedef FastICAInternalOptimizerVectorImageFilter Self;
typedef itk::ImageToImageFilter<TInputImage, TOutputImage> Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Type macro */
itkNewMacro(Self);
/** Creation through object factory macro */
itkTypeMacro(FastICAInternalOptimizerVectorImageFilter, ImageToImageFilter);
/** Dimension */
itkStaticConstMacro(InputImageDimension, unsigned int, TInputImage::ImageDimension);
itkStaticConstMacro(OutputImageDimension, unsigned int, TOutputImage::ImageDimension);
/** Template parameters typedefs */
typedef TInputImage InputImageType;
typedef typename InputImageType::RegionType InputRegionType;
typedef TOutputImage OutputImageType;
typedef typename OutputImageType::RegionType OutputRegionType;
/** Filter types and related */
typedef StreamingStatisticsVectorImageFilter2< InputImageType > CovarianceEstimatorFilterType;
typedef typename CovarianceEstimatorFilterType::Pointer CovarianceEstimatorFilterPointerType;
typedef typename CovarianceEstimatorFilterType::RealType RealType;
typedef typename CovarianceEstimatorFilterType::RealPixelType VectorType;
typedef typename CovarianceEstimatorFilterType::MatrixObjectType MatrixObjectType;
typedef typename MatrixObjectType::ComponentType MatrixType;
typedef typename MatrixType::InternalMatrixType InternalMatrixType;
typedef typename InternalMatrixType::element_type MatrixElementType;
typedef MatrixMultiplyImageFilter< TInputImage, TOutputImage, RealType > TransformFilterType;
typedef typename TransformFilterType::Pointer TransformFilterPointerType;
typedef double (*ContrastFunctionType) ( double );
itkSetMacro(CurrentBandForLoop,unsigned int);
itkGetMacro(CurrentBandForLoop,unsigned int);
itkGetMacro(W,InternalMatrixType);
itkSetMacro(W,InternalMatrixType);
itkSetMacro(ContrastFunction,ContrastFunctionType);
itkGetMacro(Beta,double);
itkGetMacro(Den,double);
protected:
FastICAInternalOptimizerVectorImageFilter();
virtual ~FastICAInternalOptimizerVectorImageFilter() { }
virtual void GenerateOutputInformation();
virtual void BeforeThreadedGenerateData ();
virtual void ThreadedGenerateData ( const OutputRegionType &, int );
virtual void AfterThreadedGenerateData();
unsigned int m_CurrentBandForLoop;
std::vector<double> m_BetaVector;
std::vector<double> m_DenVector;
std::vector<double> m_NbSamples;
double m_Beta;
double m_Den;
InternalMatrixType m_W;
ContrastFunctionType m_ContrastFunction;
TransformFilterPointerType m_TransformFilter;
private:
FastICAInternalOptimizerVectorImageFilter( const Self & ); // not implemented
void operator= ( const Self & ); // not implemented
}; // end of class
} // end of namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbFastICAInternalOptimizerVectorImageFilter.txx"
#endif
#endif
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#ifndef __otbFastICAInternalOptimizerVectorImageFilter_txx
#define __otbFastICAInternalOptimizerVectorImageFilter_txx
#include "otbFastICAInternalOptimizerVectorImageFilter.h"
#include <itkExceptionObject.h>
#include <itkImageRegionConstIterator.h>
#include <itkImageRegionIterator.h>
#include <vnl/vnl_math.h>
#include <vnl/vnl_matrix.h>
namespace otb
{
template <class TInputImage, class TOutputImage>
FastICAInternalOptimizerVectorImageFilter< TInputImage, TOutputImage >
::FastICAInternalOptimizerVectorImageFilter()
{
this->SetNumberOfRequiredInputs(2);
m_CurrentBandForLoop = 0;
m_Beta = 0.;
m_Den = 0.;
m_ContrastFunction = &vcl_tanh;
m_TransformFilter = TransformFilterType::New();
}
template <class TInputImage, class TOutputImage>
void
FastICAInternalOptimizerVectorImageFilter< TInputImage, TOutputImage >
::GenerateOutputInformation()
{
Superclass::GenerateOutputInformation();
this->GetOutput()->SetNumberOfComponentsPerPixel(
this->GetInput(0)->GetNumberOfComponentsPerPixel() );
}
template <class TInputImage, class TOutputImage>
void
FastICAInternalOptimizerVectorImageFilter< TInputImage, TOutputImage >
::BeforeThreadedGenerateData ()
{
if ( m_W.empty() )
{
throw itk::ExceptionObject( __FILE__, __LINE__,
"Give the initial W matrix", ITK_LOCATION );
}
m_BetaVector.resize( this->GetNumberOfThreads() );
m_DenVector.resize( this->GetNumberOfThreads() );
m_NbSamples.resize( this->GetNumberOfThreads() );
}
template <class TInputImage, class TOutputImage>
void
FastICAInternalOptimizerVectorImageFilter< TInputImage, TOutputImage >
::ThreadedGenerateData ( const OutputRegionType & outputRegionForThread, int threadId )
{
InputRegionType inputRegion;
this->CallCopyOutputRegionToInputRegion( inputRegion, outputRegionForThread );
itk::ImageRegionConstIterator< InputImageType > input0It
( this->GetInput(0), inputRegion );
itk::ImageRegionConstIterator< InputImageType > input1It
( this->GetInput(1), inputRegion );
itk::ImageRegionIterator< OutputImageType > outputIt
( this->GetOutput(), outputRegionForThread );
unsigned int nbBands = this->GetInput(0)->GetNumberOfComponentsPerPixel();
input0It.GoToBegin();
input1It.GoToBegin();
outputIt.GoToBegin();
double beta = 0.;
double den = 0.;
double nbSample = 0.;
while ( !input0It.IsAtEnd() && !input1It.IsAtEnd() && !outputIt.IsAtEnd() )
{
double x = static_cast<double>( input1It.Get()[GetCurrentBandForLoop()] );
double g_x = (*m_ContrastFunction)(x);
double x_g_x = x * g_x;
beta += x_g_x;
double gp = 1. - vcl_pow( g_x, 2. );
den += gp;
nbSample += 1.;
typename OutputImageType::PixelType z ( nbBands );
for ( unsigned int bd = 0; bd < nbBands; bd++ )
z[bd] = x * input0It.Get()[bd];
outputIt.Set(z);
++input0It;
++input1It;
++outputIt;
} // end while loop
m_BetaVector[threadId] = beta;
m_DenVector[threadId] = den;
m_NbSamples[threadId] = nbSample;
}
template <class TInputImage, class TOutputImage>
void
FastICAInternalOptimizerVectorImageFilter< TInputImage, TOutputImage >
::AfterThreadedGenerateData ()
{
double beta = 0;
double den = 0.;
double nbSample = 0;
for ( int i = 0; i < this->GetNumberOfThreads(); i++ )
{
beta += m_BetaVector[i];
den += m_DenVector[i];
nbSample += m_NbSamples[i];
}
m_Beta = beta / nbSample;
m_Den = den / nbSample - m_Beta;
}
} // end of namespace otb
#endif
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
#include "itkImageToImageFilter.h" #include "itkImageToImageFilter.h"
#include "otbPCAImageFilter.h" #include "otbPCAImageFilter.h"
#include "otbMatrixMultiplyImageFilter.h" #include "otbMatrixMultiplyImageFilter.h"
#include "otbStreamingStatisticsVectorImageFilter2.h"
#include "otbFastICAInternalOptimizerVectorImageFilter.h"
namespace otb namespace otb
{ {
...@@ -74,8 +76,14 @@ public: ...@@ -74,8 +76,14 @@ public:
typedef MatrixMultiplyImageFilter< TInputImage, TOutputImage, RealType > TransformFilterType; typedef MatrixMultiplyImageFilter< TInputImage, TOutputImage, RealType > TransformFilterType;
typedef typename TransformFilterType::Pointer TransformFilterPointerType; typedef typename TransformFilterType::Pointer TransformFilterPointerType;
typedef double (*ContrastFunctionType) ( double ); typedef FastICAInternalOptimizerVectorImageFilter< InputImageType, InputImageType >
InternalOptimizerType;
typedef typename InternalOptimizerType::Pointer InternalOptimizerPointerType;
typedef StreamingStatisticsVectorImageFilter2< InputImageType > MeanEstimatorFilterType;
typedef typename MeanEstimatorFilterType::Pointer MeanEstimatorFilterPointerType;
typedef double (*ContrastFunctionType) ( double );
/** /**
* Set/Get the number of required largest principal components. * Set/Get the number of required largest principal components.
...@@ -85,11 +93,9 @@ public: ...@@ -85,11 +93,9 @@ public:
itkGetConstMacro(PCAFilter,PCAFilterType *); itkGetConstMacro(PCAFilter,PCAFilterType *);
itkGetMacro(PCAFilter,PCAFilterType *); itkGetMacro(PCAFilter,PCAFilterType *);
itkSetMacro(PCAFilter,PCAFilterType *);
itkGetConstMacro(TransformFilter,TransformFilterType *); itkGetConstMacro(TransformFilter,TransformFilterType *);
itkGetMacro(TransformFilter,TransformFilterType *); itkGetMacro(TransformFilter,TransformFilterType *);
itkSetMacro(TransformFilter,TransformFilterType *);
VectorType GetMeanValues () const VectorType GetMeanValues () const
{ {
...@@ -113,7 +119,7 @@ public: ...@@ -113,7 +119,7 @@ public:
{ {
return this->GetPCAFilter()->GetTransformationMatrix(); return this->GetPCAFilter()->GetTransformationMatrix();
} }
void SetPCACovarianceMatrix ( const MatrixType & mat, bool isForward = true ) void SetPCATransformationMatrix ( const MatrixType & mat, bool isForward = true )
{ {
m_PCAFilter->SetTransformationMatrix(mat,isForward); m_PCAFilter->SetTransformationMatrix(mat,isForward);
} }
...@@ -128,14 +134,17 @@ public: ...@@ -128,14 +134,17 @@ public:
this->Modified(); this->Modified();
} }
itkGetMacro(MaximumOfIterations,unsigned int); itkGetMacro(NumberOfIterations,unsigned int);
itkSetMacro(MaximumOfIterations,unsigned int); itkSetMacro(NumberOfIterations,unsigned int);
itkGetMacro(ConvergenceThreshold,double); itkGetMacro(ConvergenceThreshold,double);
itkSetMacro(ConvergenceThreshold,double); itkSetMacro(ConvergenceThreshold,double);
itkGetMacro(ContrastFunction,ContrastFunctionType); itkGetMacro(ContrastFunction,ContrastFunctionType);
itkGetMacro(Mu,double);
itkSetMacro(Mu,double);
protected: protected:
FastICAImageFilter (); FastICAImageFilter ();
virtual ~FastICAImageFilter() { } virtual ~FastICAImageFilter() { }
...@@ -155,8 +164,6 @@ protected: ...@@ -155,8 +164,6 @@ protected:
*/ */
virtual void GenerateData (); virtual void GenerateData ();
void PrintSelf(std::ostream& os, itk::Indent indent) const;
/** Internal methods */ /** Internal methods */
virtual void ForwardGenerateData(); virtual void ForwardGenerateData();
virtual void ReverseGenerateData(); virtual void ReverseGenerateData();
...@@ -172,9 +179,10 @@ protected: ...@@ -172,9 +179,10 @@ protected:
MatrixType m_TransformationMatrix; MatrixType m_TransformationMatrix;
/** FastICA parameters */ /** FastICA parameters */
unsigned int m_MaximumOfIterations; // def is 50 unsigned int m_NumberOfIterations; // def is 50
double m_ConvergenceThreshold; // def is 1e-4 double m_ConvergenceThreshold; // def is 1e-4
ContrastFunctionType m_ConstrastFunction; // see g() function in the biblio. Def is tanh ContrastFunctionType m_ContrastFunction; // see g() function in the biblio. Def is tanh
double m_Mu; // def is 1. in [0,1]
PCAFilterPointerType m_PCAFilter; PCAFilterPointerType m_PCAFilter;
TransformFilterPointerType m_TransformFilter; TransformFilterPointerType m_TransformFilter;
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "itkExceptionObject.h" #include "itkExceptionObject.h"
#include "itkNumericTraits.h" #include "itkNumericTraits.h"
#include "itkProgressReporter.h"
#include <vnl/vnl_matrix.h> #include <vnl/vnl_matrix.h>
#include <vnl/algo/vnl_matrix_inverse.h> #include <vnl/algo/vnl_matrix_inverse.h>
...@@ -43,12 +44,13 @@ FastICAImageFilter< TInputImage, TOutputImage, TDirectionOfTransformation > ...@@ -43,12 +44,13 @@ FastICAImageFilter< TInputImage, TOutputImage, TDirectionOfTransformation >
m_GivenTransformationMatrix = false; m_GivenTransformationMatrix = false;
m_IsTransformationForward = true; m_IsTransformationForward = true;
m_MaximumOfIterations = 50; m_NumberOfIterations = 50;
m_ConvergenceThreshold = 1E-4; m_ConvergenceThreshold = 1E-4;
m_ConstrastFunction = &vcl_tanh; m_ContrastFunction = &vcl_tanh;
m_Mu = 1.;
m_PCAFilter = PCAFilterType::New(); m_PCAFilter = PCAFilterType::New();
m_PCAFilter->UseNormalization(); m_PCAFilter->SetUseNormalization(true);
m_TransformFilter = TransformFilterType::New(); m_TransformFilter = TransformFilterType::New();
} }
...@@ -86,11 +88,6 @@ FastICAImageFilter< TInputImage, TOutputImage, TDirectionOfTransformation > ...@@ -86,11 +88,6 @@ FastICAImageFilter< TInputImage, TOutputImage, TDirectionOfTransformation >
theOutputDimension = m_TransformationMatrix.Rows() >= m_TransformationMatrix.Cols() ? theOutputDimension = m_TransformationMatrix.Rows() >= m_TransformationMatrix.Cols() ?
m_TransformationMatrix.Rows() : m_TransformationMatrix.Cols(); m_TransformationMatrix.Rows() : m_TransformationMatrix.Cols();
} }
else if ( m_GivenCovarianceMatrix )
{
theOutputDimension = m_CovarianceMatrix.Rows() >= m_CovarianceMatrix.Cols() ?
m_CovarianceMatrix.Rows() : m_CovarianceMatrix.Cols();
}
else else
{ {
throw itk::ExceptionObject(__FILE__, __LINE__, throw itk::ExceptionObject(__FILE__, __LINE__,
...@@ -216,31 +213,81 @@ void ...@@ -216,31 +213,81 @@ void
FastICAImageFilter< TInputImage, TOutputImage, TDirectionOfTransformation > FastICAImageFilter< TInputImage, TOutputImage, TDirectionOfTransformation >
::GenerateTransformationMatrix () ::GenerateTransformationMatrix ()
{ {
double convergence = itk::NumericTraits<double>::Max(); itk::ProgressReporter reporter ( this, 0, GetNumberOfIterations(), GetNumberOfIterations() );
double convergence = itk::NumericTraits<double>::max();
unsigned int iteration = 0; unsigned int iteration = 0;
unsigned int size = this->GetInput()->GetNumberOfComponentsPerPixel(); const unsigned int size = this->GetInput()->GetNumberOfComponentsPerPixel();
// transformation matrix // transformation matrix
InternalMatrixType W ( size, size, vnl_matrix_identity ); InternalMatrixType W ( size, size, vnl_matrix_identity );
while ( iteration++ < GetMaximumOfIterations() while ( iteration++ < GetNumberOfIterations()
&& convergence > GetConvergenceThreshold() ) && convergence > GetConvergenceThreshold() )
{ {
otbMsgDebugMacro( "Iteration " << iteration << " / " << GetMaximumOfIterations()
<< ", MSE = " << convergence );
InternalMatrixType W_old ( W ); InternalMatrixType W_old ( W );
// TODO le premier coup ne sert a rien typename InputImageType::Pointer img = const_cast<InputImageType*>( this->GetInput() );
TransformFilterPointerType transformer = TransformFilterType::New(); TransformFilterPointerType transformer = TransformFilterType::New();
transformer->SetIntput( GetPCAFilter()->GetOutput() ); if ( !W.is_identity() )
transformer->SetMatrix( W ); {
transformer->Update(); transformer->SetInput( GetPCAFilter()->GetOutput() );
transformer->SetMatrix( W );
transformer->Update();
img = const_cast<InputImageType*>( transformer->GetOutput() );
}
// Faire un image to image filter... for ( unsigned int band = 0; band < size; band++ )
{
InternalOptimizerPointerType optimizer = InternalOptimizerType::New();
optimizer->SetInput( 0, m_PCAFilter->GetOutput() );
optimizer->SetInput( 1, img );
optimizer->SetW( W );
optimizer->SetContrastFunction( this->GetContrastFunction() );
optimizer->SetCurrentBandForLoop( band );
MeanEstimatorFilterPointerType estimator = MeanEstimatorFilterType::New();
estimator->SetInput( optimizer->GetOutput() );
estimator->Update();
double norm = 0.;
for ( unsigned int bd = 0; bd < size; bd++ )
{
W(bd,band) -= m_Mu * ( estimator->GetMean()[bd]
- optimizer->GetBeta() * W(bd,band) / optimizer->GetDen() );
norm += vcl_pow( W(bd,band), 2. );
}
for ( unsigned int bd = 0; bd < size; bd++ )
W(bd,band) /= norm;
}
// Decorrelation of the W vectors
InternalMatrixType W_tmp = W * W.transpose();
InternalMatrixType Id ( W.rows(), W.cols(), vnl_matrix_identity );
vnl_generalized_eigensystem solver ( W_tmp, Id );
InternalMatrixType valP = solver.D;
for ( unsigned int i = 0; i < valP.size(); i++ )
valP(i,i) = 1. / vcl_sqrt( static_cast<double>( valP(i,i) ) ); // Watch for 0 or neg
W_tmp = solver.V * valP * solver.V.transpose();
W = W.transpose() * W;
// Convergence evaluation
convergence = 0.;
for ( unsigned int i = 0; i < W.rows(); i++ )
for ( unsigned int j = 0; j < W.cols(); j++ )
convergence += vcl_abs( W(i,j) - W_old(i,j) );
reporter.CompletedPixel();
} // end of while loop } // end of while loop
if ( size != this->GetNumberOfPrincipalComponentsRequired() )
this->m_TransformationMatrix = W.get_n_rows( 0, this->GetNumberOfPrincipalComponentsRequired() );
else
this->m_TransformationMatrix = W;
otbMsgDebugMacro( << "Final convergence " << convergence
<< " after " << iteration << " iterations" );
} }
} // end of namespace otb } // end of namespace otb
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment