diff --git a/Code/Learning/otbConfusionMatrixCalculator.h b/Code/Learning/otbConfusionMatrixCalculator.h index 872cc291e9eb2a95e8b1461c9b85610f65f01a5d..76108653d1fe3c742bdfe8bfad2d7eb6ce917b1b 100644 --- a/Code/Learning/otbConfusionMatrixCalculator.h +++ b/Code/Learning/otbConfusionMatrixCalculator.h @@ -32,7 +32,7 @@ namespace otb * \brief TODO * */ -template < class TListLabel > +template < class TRefListLabel, class TProdListLabel > class ITK_EXPORT ConfusionMatrixCalculator : public itk::ProcessObject { @@ -50,8 +50,11 @@ public: itkNewMacro(Self); /** List to store the corresponding labels */ - typedef TListLabel ListLabelType; - typedef typename ListLabelType::Pointer ListLabelPointerType; + typedef TRefListLabel RefListLabelType; + typedef typename RefListLabelType::Pointer RefListLabelPointerType; + + typedef TProdListLabel ProdListLabelType; + typedef typename ProdListLabelType::Pointer ProdListLabelPointerType; /** Type for the confusion matrix */ typedef itk::VariableSizeMatrix<double> ConfusionMatrixType; @@ -60,10 +63,10 @@ public: /** Accessors */ - itkSetObjectMacro(ReferenceLabels, ListLabelType); - itkGetConstObjectMacro(ReferenceLabels, ListLabelType); - itkSetObjectMacro(ProducedLabels, ListLabelType); - itkGetConstObjectMacro(ProducedLabels, ListLabelType); + itkSetObjectMacro(ReferenceLabels, RefListLabelType); + itkGetConstObjectMacro(ReferenceLabels, RefListLabelType); + itkSetObjectMacro(ProducedLabels, ProdListLabelType); + itkGetConstObjectMacro(ProducedLabels, ProdListLabelType); itkGetConstMacro(KappaIndex, double); itkGetConstMacro(OverallAccuracy, double); itkGetConstMacro(NumberOfClasses, unsigned short); @@ -81,17 +84,17 @@ private: ConfusionMatrixCalculator(const Self&); //purposely not implemented void operator=(const Self&); //purposely not implemented - double m_KappaIndex; - double m_OverallAccuracy; + double m_KappaIndex; + double m_OverallAccuracy; std::map<int,int> m_MapOfClasses; - unsigned short m_NumberOfClasses; + unsigned short m_NumberOfClasses; - ConfusionMatrixType m_ConfusionMatrix; + ConfusionMatrixType m_ConfusionMatrix; - ListLabelPointerType m_ReferenceLabels; - ListLabelPointerType m_ProducedLabels; + RefListLabelPointerType m_ReferenceLabels; + ProdListLabelPointerType m_ProducedLabels; }; }// end of namespace otb diff --git a/Code/Learning/otbConfusionMatrixCalculator.txx b/Code/Learning/otbConfusionMatrixCalculator.txx index 1200dfbc4fc60d932fdd48a769da428af0a387a2..0f8cb8d21949161c049830a7ea3bea50dae28a0e 100644 --- a/Code/Learning/otbConfusionMatrixCalculator.txx +++ b/Code/Learning/otbConfusionMatrixCalculator.txx @@ -23,8 +23,8 @@ namespace otb { -template<class TListLabel> -ConfusionMatrixCalculator<TListLabel> +template<class TRefListLabel, class TProdListLabel> +ConfusionMatrixCalculator<TRefListLabel,TProdListLabel> ::ConfusionMatrixCalculator() : m_KappaIndex(0.0), m_OverallAccuracy(0.0), m_NumberOfClasses(0) { @@ -32,28 +32,28 @@ ConfusionMatrixCalculator<TListLabel> this->SetNumberOfRequiredOutputs(1); m_ConfusionMatrix = ConfusionMatrixType(m_NumberOfClasses,m_NumberOfClasses); m_ConfusionMatrix.Fill(0); - m_ReferenceLabels = ListLabelType::New(); - m_ProducedLabels = ListLabelType::New(); + m_ReferenceLabels = RefListLabelType::New(); + m_ProducedLabels = ProdListLabelType::New(); } -template < class TListLabel > +template < class TRefListLabel, class TProdListLabel > void -ConfusionMatrixCalculator<TListLabel> +ConfusionMatrixCalculator<TRefListLabel,TProdListLabel> ::Update() { this->GenerateData(); } -template < class TListLabel > +template < class TRefListLabel, class TProdListLabel > void -ConfusionMatrixCalculator<TListLabel> +ConfusionMatrixCalculator<TRefListLabel,TProdListLabel> ::GenerateData() { - typename ListLabelType::ConstIterator refIterator = m_ReferenceLabels->Begin(); - typename ListLabelType::ConstIterator prodIterator = m_ProducedLabels->Begin(); + typename RefListLabelType::ConstIterator refIterator = m_ReferenceLabels->Begin(); + typename ProdListLabelType::ConstIterator prodIterator = m_ProducedLabels->Begin(); //check that both lists have the same number of samples @@ -83,6 +83,7 @@ ConfusionMatrixCalculator<TListLabel> m_NumberOfClasses = countClasses; m_ConfusionMatrix = ConfusionMatrixType(m_NumberOfClasses, m_NumberOfClasses); + m_ConfusionMatrix.Fill(0); refIterator = m_ReferenceLabels->Begin(); prodIterator = m_ProducedLabels->Begin(); @@ -105,9 +106,9 @@ ConfusionMatrixCalculator<TListLabel> } -template < class TListLabel > +template < class TRefListLabel, class TProdListLabel > void -ConfusionMatrixCalculator<TListLabel> +ConfusionMatrixCalculator<TRefListLabel,TProdListLabel> ::PrintSelf(std::ostream& os, itk::Indent indent) const { os << indent << "TODO"; diff --git a/Testing/Code/Learning/otbConfusionMatrixCalculatorTest.cxx b/Testing/Code/Learning/otbConfusionMatrixCalculatorTest.cxx index fb7555a7855a93b53c955d168d2763599b8d85fc..b61df7d81c92c79e21fb5148485059249f67b239 100644 --- a/Testing/Code/Learning/otbConfusionMatrixCalculatorTest.cxx +++ b/Testing/Code/Learning/otbConfusionMatrixCalculatorTest.cxx @@ -22,9 +22,12 @@ int otbConfusionMatrixCalculatorNew(int argc, char* argv[]) { - typedef itk::FixedArray<int, 1> LabelType; - typedef itk::Statistics::ListSample<LabelType> ListLabelType; - typedef otb::ConfusionMatrixCalculator< ListLabelType > CalculatorType; + typedef itk::VariableLengthVector<int> PLabelType; + typedef itk::Statistics::ListSample<PLabelType> PListLabelType; + typedef itk::FixedArray<int, 1> RLabelType; + typedef itk::Statistics::ListSample<RLabelType> RListLabelType; + typedef otb::ConfusionMatrixCalculator< RListLabelType, + PListLabelType > CalculatorType; CalculatorType::Pointer calculator = CalculatorType::New(); @@ -39,14 +42,18 @@ int otbConfusionMatrixCalculatorSetListSamples(int argc, char* argv[]) std::cerr << "Usage: " << argv[0] << " nbSamples nbClasses " << std::endl; return EXIT_FAILURE; } - typedef itk::FixedArray<int, 1> LabelType; - typedef itk::Statistics::ListSample<LabelType> ListLabelType; - typedef otb::ConfusionMatrixCalculator< ListLabelType > CalculatorType; + typedef itk::VariableLengthVector<int> PLabelType; + typedef itk::Statistics::ListSample<PLabelType> PListLabelType; + typedef itk::FixedArray<int, 1> RLabelType; + typedef itk::Statistics::ListSample<RLabelType> RListLabelType; + typedef otb::ConfusionMatrixCalculator< RListLabelType, + PListLabelType > CalculatorType; + CalculatorType::Pointer calculator = CalculatorType::New(); - ListLabelType::Pointer refLabels = ListLabelType::New(); - ListLabelType::Pointer prodLabels = ListLabelType::New(); + RListLabelType::Pointer refLabels = RListLabelType::New(); + PListLabelType::Pointer prodLabels = PListLabelType::New(); int nbSamples = atoi(argv[1]); int nbClasses = atoi(argv[2]); @@ -55,8 +62,11 @@ int otbConfusionMatrixCalculatorSetListSamples(int argc, char* argv[]) for(int i=0; i<nbSamples; i++) { int label = (i%nbClasses)+1; + PLabelType plab; + plab.SetSize(1); + plab[0] = label; refLabels->PushBack( label ); - prodLabels->PushBack( label ); + prodLabels->PushBack( plab ); } calculator->SetReferenceLabels( refLabels ); @@ -68,31 +78,42 @@ int otbConfusionMatrixCalculatorSetListSamples(int argc, char* argv[]) int otbConfusionMatrixCalculatorWrongSize(int argc, char* argv[]) { - if( argc!= 3) + if( argc!= 3) { std::cerr << "Usage: " << argv[0] << " nbSamples nbClasses " << std::endl; return EXIT_FAILURE; } - typedef itk::FixedArray<int, 1> LabelType; - typedef itk::Statistics::ListSample<LabelType> ListLabelType; - typedef otb::ConfusionMatrixCalculator< ListLabelType > CalculatorType; + + typedef itk::VariableLengthVector<int> PLabelType; + typedef itk::Statistics::ListSample<PLabelType> PListLabelType; + typedef itk::FixedArray<int, 1> RLabelType; + typedef itk::Statistics::ListSample<RLabelType> RListLabelType; + typedef otb::ConfusionMatrixCalculator< RListLabelType, + PListLabelType > CalculatorType; CalculatorType::Pointer calculator = CalculatorType::New(); - ListLabelType::Pointer refLabels = ListLabelType::New(); - ListLabelType::Pointer prodLabels = ListLabelType::New(); + RListLabelType::Pointer refLabels = RListLabelType::New(); + PListLabelType::Pointer prodLabels = PListLabelType::New(); int nbSamples = atoi(argv[1]); int nbClasses = atoi(argv[2]); + for(int i=0; i<nbSamples; i++) { int label = (i%nbClasses)+1; + PLabelType plab; + plab.SetSize(1); + plab[0] = label; refLabels->PushBack( label ); - prodLabels->PushBack( label ); + prodLabels->PushBack( plab ); } - prodLabels->PushBack( 0 ); + PLabelType plab; + plab.SetSize(1); + plab[0] = 0; + prodLabels->PushBack( plab ); calculator->SetReferenceLabels( refLabels ); calculator->SetProducedLabels( prodLabels ); @@ -118,14 +139,18 @@ int otbConfusionMatrixCalculatorUpdate(int argc, char* argv[]) std::cerr << "Usage: " << argv[0] << " nbSamples nbClasses " << std::endl; return EXIT_FAILURE; } - typedef itk::FixedArray<int, 1> LabelType; - typedef itk::Statistics::ListSample<LabelType> ListLabelType; - typedef otb::ConfusionMatrixCalculator< ListLabelType > CalculatorType; + + typedef itk::VariableLengthVector<int> PLabelType; + typedef itk::Statistics::ListSample<PLabelType> PListLabelType; + typedef itk::FixedArray<int, 1> RLabelType; + typedef itk::Statistics::ListSample<RLabelType> RListLabelType; + typedef otb::ConfusionMatrixCalculator< RListLabelType, + PListLabelType > CalculatorType; CalculatorType::Pointer calculator = CalculatorType::New(); - ListLabelType::Pointer refLabels = ListLabelType::New(); - ListLabelType::Pointer prodLabels = ListLabelType::New(); + RListLabelType::Pointer refLabels = RListLabelType::New(); + PListLabelType::Pointer prodLabels = PListLabelType::New(); int nbSamples = atoi(argv[1]); int nbClasses = atoi(argv[2]); @@ -134,11 +159,14 @@ int otbConfusionMatrixCalculatorUpdate(int argc, char* argv[]) for(int i=0; i<nbSamples; i++) { int label = (i%nbClasses)+1; + PLabelType plab; + plab.SetSize(1); + plab[0] = label; refLabels->PushBack( label ); - prodLabels->PushBack( label ); + prodLabels->PushBack( plab ); } - + calculator->SetReferenceLabels( refLabels ); calculator->SetProducedLabels( prodLabels );