From 22da3813c30f97a44a6ad3f18feb4693a3df5663 Mon Sep 17 00:00:00 2001 From: Jordi Inglada <jordi.inglada@cesbio.cnes.fr> Date: Tue, 8 Jun 2010 14:17:08 +0200 Subject: [PATCH] ENH: modify ConfusionMatrixCalculator interface to be compatible with SVM classifier (list sample types) --- Code/Learning/otbConfusionMatrixCalculator.h | 29 +++---- .../Learning/otbConfusionMatrixCalculator.txx | 24 +++--- .../otbConfusionMatrixCalculatorTest.cxx | 76 +++++++++++++------ 3 files changed, 80 insertions(+), 49 deletions(-) diff --git a/Code/Learning/otbConfusionMatrixCalculator.h b/Code/Learning/otbConfusionMatrixCalculator.h index 872cc291e9..76108653d1 100644 --- a/Code/Learning/otbConfusionMatrixCalculator.h +++ b/Code/Learning/otbConfusionMatrixCalculator.h @@ -32,7 +32,7 @@ namespace otb * \brief TODO * */ -template < class TListLabel > +template < class TRefListLabel, class TProdListLabel > class ITK_EXPORT ConfusionMatrixCalculator : public itk::ProcessObject { @@ -50,8 +50,11 @@ public: itkNewMacro(Self); /** List to store the corresponding labels */ - typedef TListLabel ListLabelType; - typedef typename ListLabelType::Pointer ListLabelPointerType; + typedef TRefListLabel RefListLabelType; + typedef typename RefListLabelType::Pointer RefListLabelPointerType; + + typedef TProdListLabel ProdListLabelType; + typedef typename ProdListLabelType::Pointer ProdListLabelPointerType; /** Type for the confusion matrix */ typedef itk::VariableSizeMatrix<double> ConfusionMatrixType; @@ -60,10 +63,10 @@ public: /** Accessors */ - itkSetObjectMacro(ReferenceLabels, ListLabelType); - itkGetConstObjectMacro(ReferenceLabels, ListLabelType); - itkSetObjectMacro(ProducedLabels, ListLabelType); - itkGetConstObjectMacro(ProducedLabels, ListLabelType); + itkSetObjectMacro(ReferenceLabels, RefListLabelType); + itkGetConstObjectMacro(ReferenceLabels, RefListLabelType); + itkSetObjectMacro(ProducedLabels, ProdListLabelType); + itkGetConstObjectMacro(ProducedLabels, ProdListLabelType); itkGetConstMacro(KappaIndex, double); itkGetConstMacro(OverallAccuracy, double); itkGetConstMacro(NumberOfClasses, unsigned short); @@ -81,17 +84,17 @@ private: ConfusionMatrixCalculator(const Self&); //purposely not implemented void operator=(const Self&); //purposely not implemented - double m_KappaIndex; - double m_OverallAccuracy; + double m_KappaIndex; + double m_OverallAccuracy; std::map<int,int> m_MapOfClasses; - unsigned short m_NumberOfClasses; + unsigned short m_NumberOfClasses; - ConfusionMatrixType m_ConfusionMatrix; + ConfusionMatrixType m_ConfusionMatrix; - ListLabelPointerType m_ReferenceLabels; - ListLabelPointerType m_ProducedLabels; + RefListLabelPointerType m_ReferenceLabels; + ProdListLabelPointerType m_ProducedLabels; }; }// end of namespace otb diff --git a/Code/Learning/otbConfusionMatrixCalculator.txx b/Code/Learning/otbConfusionMatrixCalculator.txx index 1200dfbc4f..648b9d84a1 100644 --- a/Code/Learning/otbConfusionMatrixCalculator.txx +++ b/Code/Learning/otbConfusionMatrixCalculator.txx @@ -23,8 +23,8 @@ namespace otb { -template<class TListLabel> -ConfusionMatrixCalculator<TListLabel> +template<class TRefListLabel, class TProdListLabel> +ConfusionMatrixCalculator<TRefListLabel,TProdListLabel> ::ConfusionMatrixCalculator() : m_KappaIndex(0.0), m_OverallAccuracy(0.0), m_NumberOfClasses(0) { @@ -32,28 +32,28 @@ ConfusionMatrixCalculator<TListLabel> this->SetNumberOfRequiredOutputs(1); m_ConfusionMatrix = ConfusionMatrixType(m_NumberOfClasses,m_NumberOfClasses); m_ConfusionMatrix.Fill(0); - m_ReferenceLabels = ListLabelType::New(); - m_ProducedLabels = ListLabelType::New(); + m_ReferenceLabels = RefListLabelType::New(); + m_ProducedLabels = ProdListLabelType::New(); } -template < class TListLabel > +template < class TRefListLabel, class TProdListLabel > void -ConfusionMatrixCalculator<TListLabel> +ConfusionMatrixCalculator<TRefListLabel,TProdListLabel> ::Update() { this->GenerateData(); } -template < class TListLabel > +template < class TRefListLabel, class TProdListLabel > void -ConfusionMatrixCalculator<TListLabel> +ConfusionMatrixCalculator<TRefListLabel,TProdListLabel> ::GenerateData() { - typename ListLabelType::ConstIterator refIterator = m_ReferenceLabels->Begin(); - typename ListLabelType::ConstIterator prodIterator = m_ProducedLabels->Begin(); + typename RefListLabelType::ConstIterator refIterator = m_ReferenceLabels->Begin(); + typename ProdListLabelType::ConstIterator prodIterator = m_ProducedLabels->Begin(); //check that both lists have the same number of samples @@ -105,9 +105,9 @@ ConfusionMatrixCalculator<TListLabel> } -template < class TListLabel > +template < class TRefListLabel, class TProdListLabel > void -ConfusionMatrixCalculator<TListLabel> +ConfusionMatrixCalculator<TRefListLabel,TProdListLabel> ::PrintSelf(std::ostream& os, itk::Indent indent) const { os << indent << "TODO"; diff --git a/Testing/Code/Learning/otbConfusionMatrixCalculatorTest.cxx b/Testing/Code/Learning/otbConfusionMatrixCalculatorTest.cxx index fb7555a785..b61df7d81c 100644 --- a/Testing/Code/Learning/otbConfusionMatrixCalculatorTest.cxx +++ b/Testing/Code/Learning/otbConfusionMatrixCalculatorTest.cxx @@ -22,9 +22,12 @@ int otbConfusionMatrixCalculatorNew(int argc, char* argv[]) { - typedef itk::FixedArray<int, 1> LabelType; - typedef itk::Statistics::ListSample<LabelType> ListLabelType; - typedef otb::ConfusionMatrixCalculator< ListLabelType > CalculatorType; + typedef itk::VariableLengthVector<int> PLabelType; + typedef itk::Statistics::ListSample<PLabelType> PListLabelType; + typedef itk::FixedArray<int, 1> RLabelType; + typedef itk::Statistics::ListSample<RLabelType> RListLabelType; + typedef otb::ConfusionMatrixCalculator< RListLabelType, + PListLabelType > CalculatorType; CalculatorType::Pointer calculator = CalculatorType::New(); @@ -39,14 +42,18 @@ int otbConfusionMatrixCalculatorSetListSamples(int argc, char* argv[]) std::cerr << "Usage: " << argv[0] << " nbSamples nbClasses " << std::endl; return EXIT_FAILURE; } - typedef itk::FixedArray<int, 1> LabelType; - typedef itk::Statistics::ListSample<LabelType> ListLabelType; - typedef otb::ConfusionMatrixCalculator< ListLabelType > CalculatorType; + typedef itk::VariableLengthVector<int> PLabelType; + typedef itk::Statistics::ListSample<PLabelType> PListLabelType; + typedef itk::FixedArray<int, 1> RLabelType; + typedef itk::Statistics::ListSample<RLabelType> RListLabelType; + typedef otb::ConfusionMatrixCalculator< RListLabelType, + PListLabelType > CalculatorType; + CalculatorType::Pointer calculator = CalculatorType::New(); - ListLabelType::Pointer refLabels = ListLabelType::New(); - ListLabelType::Pointer prodLabels = ListLabelType::New(); + RListLabelType::Pointer refLabels = RListLabelType::New(); + PListLabelType::Pointer prodLabels = PListLabelType::New(); int nbSamples = atoi(argv[1]); int nbClasses = atoi(argv[2]); @@ -55,8 +62,11 @@ int otbConfusionMatrixCalculatorSetListSamples(int argc, char* argv[]) for(int i=0; i<nbSamples; i++) { int label = (i%nbClasses)+1; + PLabelType plab; + plab.SetSize(1); + plab[0] = label; refLabels->PushBack( label ); - prodLabels->PushBack( label ); + prodLabels->PushBack( plab ); } calculator->SetReferenceLabels( refLabels ); @@ -68,31 +78,42 @@ int otbConfusionMatrixCalculatorSetListSamples(int argc, char* argv[]) int otbConfusionMatrixCalculatorWrongSize(int argc, char* argv[]) { - if( argc!= 3) + if( argc!= 3) { std::cerr << "Usage: " << argv[0] << " nbSamples nbClasses " << std::endl; return EXIT_FAILURE; } - typedef itk::FixedArray<int, 1> LabelType; - typedef itk::Statistics::ListSample<LabelType> ListLabelType; - typedef otb::ConfusionMatrixCalculator< ListLabelType > CalculatorType; + + typedef itk::VariableLengthVector<int> PLabelType; + typedef itk::Statistics::ListSample<PLabelType> PListLabelType; + typedef itk::FixedArray<int, 1> RLabelType; + typedef itk::Statistics::ListSample<RLabelType> RListLabelType; + typedef otb::ConfusionMatrixCalculator< RListLabelType, + PListLabelType > CalculatorType; CalculatorType::Pointer calculator = CalculatorType::New(); - ListLabelType::Pointer refLabels = ListLabelType::New(); - ListLabelType::Pointer prodLabels = ListLabelType::New(); + RListLabelType::Pointer refLabels = RListLabelType::New(); + PListLabelType::Pointer prodLabels = PListLabelType::New(); int nbSamples = atoi(argv[1]); int nbClasses = atoi(argv[2]); + for(int i=0; i<nbSamples; i++) { int label = (i%nbClasses)+1; + PLabelType plab; + plab.SetSize(1); + plab[0] = label; refLabels->PushBack( label ); - prodLabels->PushBack( label ); + prodLabels->PushBack( plab ); } - prodLabels->PushBack( 0 ); + PLabelType plab; + plab.SetSize(1); + plab[0] = 0; + prodLabels->PushBack( plab ); calculator->SetReferenceLabels( refLabels ); calculator->SetProducedLabels( prodLabels ); @@ -118,14 +139,18 @@ int otbConfusionMatrixCalculatorUpdate(int argc, char* argv[]) std::cerr << "Usage: " << argv[0] << " nbSamples nbClasses " << std::endl; return EXIT_FAILURE; } - typedef itk::FixedArray<int, 1> LabelType; - typedef itk::Statistics::ListSample<LabelType> ListLabelType; - typedef otb::ConfusionMatrixCalculator< ListLabelType > CalculatorType; + + typedef itk::VariableLengthVector<int> PLabelType; + typedef itk::Statistics::ListSample<PLabelType> PListLabelType; + typedef itk::FixedArray<int, 1> RLabelType; + typedef itk::Statistics::ListSample<RLabelType> RListLabelType; + typedef otb::ConfusionMatrixCalculator< RListLabelType, + PListLabelType > CalculatorType; CalculatorType::Pointer calculator = CalculatorType::New(); - ListLabelType::Pointer refLabels = ListLabelType::New(); - ListLabelType::Pointer prodLabels = ListLabelType::New(); + RListLabelType::Pointer refLabels = RListLabelType::New(); + PListLabelType::Pointer prodLabels = PListLabelType::New(); int nbSamples = atoi(argv[1]); int nbClasses = atoi(argv[2]); @@ -134,11 +159,14 @@ int otbConfusionMatrixCalculatorUpdate(int argc, char* argv[]) for(int i=0; i<nbSamples; i++) { int label = (i%nbClasses)+1; + PLabelType plab; + plab.SetSize(1); + plab[0] = label; refLabels->PushBack( label ); - prodLabels->PushBack( label ); + prodLabels->PushBack( plab ); } - + calculator->SetReferenceLabels( refLabels ); calculator->SetProducedLabels( prodLabels ); -- GitLab