// Skip to content
// Snippets Groups Projects
// Forked from Main Repositories / otb
// 9233 commits behind the upstream repository.
// otbTrainVectorClassifier.cxx 9.94 KiB
/*=========================================================================
 Program:   ORFEO Toolbox
 Language:  C++
 Date:      $Date$
 Version:   $Revision$


 Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
 See OTBCopyright.txt for details.


 This software is distributed WITHOUT ANY WARRANTY; without even
 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
 PURPOSE.  See the above copyright notices for more information.

 =========================================================================*/
#include "otbTrainVectorBase.h"

// Validation
#include "otbConfusionMatrixCalculator.h"
#include "otbContingencyTableCalculator.h"

namespace otb
{
namespace Wrapper
{

class TrainVectorClassifier : public TrainVectorBase
{
public:
  typedef TrainVectorClassifier Self;
  typedef TrainVectorBase Superclass;
  typedef itk::SmartPointer<Self> Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;
  itkNewMacro( Self )

  itkTypeMacro( Self, Superclass )

  typedef Superclass::SampleType SampleType;
  typedef Superclass::ListSampleType ListSampleType;
  typedef Superclass::TargetListSampleType TargetListSampleType;

  // Estimate performance on validation sample
  typedef otb::ConfusionMatrixCalculator<TargetListSampleType, TargetListSampleType> ConfusionMatrixCalculatorType;
  typedef ConfusionMatrixCalculatorType::ConfusionMatrixType ConfusionMatrixType;
  typedef ConfusionMatrixCalculatorType::MapOfIndicesType MapOfIndicesType;
  typedef ConfusionMatrixCalculatorType::ClassLabelType ClassLabelType;

private:
  void DoTrainInit()
  {
    // Nothing to do here
  }

  void DoTrainUpdateParameters()
  {
    // Nothing to do here
  }

  void DoBeforeTrainExecute()
  {
    // Enforce the need of class field name in supervised mode
    if (GetClassifierCategory() == Supervised)
      {
      featuresInfo.SetClassFieldNames( GetChoiceNames( "cfield" ), GetSelectedItems( "cfield" ) );

      if( featuresInfo.m_SelectedCFieldIdx.empty() )
        {
        otbAppLogFATAL( << "No field has been selected for data labelling!" );
        }
      }
  }

  void DoAfterTrainExecute()
  {

    if (GetClassifierCategory() == Supervised)
      {
      ConfusionMatrixCalculatorType::Pointer confMatCalc = ComputeConfusionMatrix( predictedList,
                                                                                   classificationListSamples.labeledListSample );
      WriteConfusionMatrix( confMatCalc );
      }
    else
      {
      WriteContingencyTable();
      }
  }


  void WriteContingencyTable()
  {
    // Compute contingency table
    typedef ContingencyTableCalculator<ClassLabelType> ContigencyTableCalcutaltorType;
    ContigencyTableCalcutaltorType::Pointer contingencyTableCalculator = ContigencyTableCalcutaltorType::New();
    contingencyTableCalculator->Compute(predictedList->Begin(), predictedList->End(),
                                        classificationListSamples.labeledListSample->Begin(),
                                        classificationListSamples.labeledListSample->End());
    ContingencyTable<ClassLabelType> contingencyTable = contingencyTableCalculator->GetContingencyTable();

    // Write contingency table
    std::ofstream outFile;
    outFile.open( this->GetParameterString( "io.confmatout" ).c_str() );
    outFile << contingencyTable.to_csv();
  }


  ConfusionMatrixCalculatorType::Pointer
  ComputeConfusionMatrix(const TargetListSampleType::Pointer &predictedListSample,
                         const TargetListSampleType::Pointer &performanceLabeledListSample)
  {
    ConfusionMatrixCalculatorType::Pointer confMatCalc = ConfusionMatrixCalculatorType::New();

    otbAppLogINFO( "Predicted list size : " << predictedListSample->Size() );
    otbAppLogINFO( "ValidationLabeledListSample size : " << performanceLabeledListSample->Size() );
    confMatCalc->SetReferenceLabels( performanceLabeledListSample );
    confMatCalc->SetProducedLabels( predictedListSample );
    confMatCalc->Compute();

    otbAppLogINFO( "training performances" );
    LogConfusionMatrix( confMatCalc );

    for( unsigned int itClasses = 0; itClasses < confMatCalc->GetNumberOfClasses(); itClasses++ )
      {
      ConfusionMatrixCalculatorType::ClassLabelType classLabel = confMatCalc->GetMapOfIndices()[itClasses];

      otbAppLogINFO( "Precision of class [" << classLabel << "] vs all: " << confMatCalc->GetPrecisions()[itClasses] );
      otbAppLogINFO( "Recall of class    [" << classLabel << "] vs all: " << confMatCalc->GetRecalls()[itClasses] );
      otbAppLogINFO(
              "F-score of class   [" << classLabel << "] vs all: " << confMatCalc->GetFScores()[itClasses] << "\n" );
      }
    otbAppLogINFO( "Global performance, Kappa index: " << confMatCalc->GetKappaIndex() );
    return confMatCalc;
  }

  /**
   * Write the confidence matrix into a file if output is provided.
   * \param confMatCalc the input matrix to write.
   */
  void WriteConfusionMatrix(const ConfusionMatrixCalculatorType::Pointer &confMatCalc)
  {
    if( this->HasValue( "io.confmatout" ) )
      {
      // Writing the confusion matrix in the output .CSV file

      MapOfIndicesType::iterator itMapOfIndicesValid, itMapOfIndicesPred;
      ClassLabelType labelValid = 0;

      ConfusionMatrixType confusionMatrix = confMatCalc->GetConfusionMatrix();
      MapOfIndicesType mapOfIndicesValid = confMatCalc->GetMapOfIndices();

      unsigned long nbClassesPred = mapOfIndicesValid.size();

      /////////////////////////////////////////////
      // Filling the 2 headers for the output file
      const std::string commentValidStr = "#Reference labels (rows):";
      const std::string commentPredStr = "#Produced labels (columns):";
      const char separatorChar = ',';
      std::ostringstream ossHeaderValidLabels, ossHeaderPredLabels;

      // Filling ossHeaderValidLabels and ossHeaderPredLabels for the output file
      ossHeaderValidLabels << commentValidStr;
      ossHeaderPredLabels << commentPredStr;

      itMapOfIndicesValid = mapOfIndicesValid.begin();

      while( itMapOfIndicesValid != mapOfIndicesValid.end() )
        {
        // labels labelValid of mapOfIndicesValid are already sorted in otbConfusionMatrixCalculator
        labelValid = itMapOfIndicesValid->second;

        otbAppLogINFO( "mapOfIndicesValid[" << itMapOfIndicesValid->first << "] = " << labelValid );

        ossHeaderValidLabels << labelValid;
        ossHeaderPredLabels << labelValid;

        ++itMapOfIndicesValid;

        if( itMapOfIndicesValid != mapOfIndicesValid.end() )
          {
          ossHeaderValidLabels << separatorChar;
          ossHeaderPredLabels << separatorChar;
          }
        else
          {
          ossHeaderValidLabels << std::endl;
          ossHeaderPredLabels << std::endl;
          }
        }

      std::ofstream outFile;
      outFile.open( this->GetParameterString( "io.confmatout" ).c_str() );
      outFile << std::fixed;
      outFile.precision( 10 );

      /////////////////////////////////////
      // Writing the 2 headers
      outFile << ossHeaderValidLabels.str();
      outFile << ossHeaderPredLabels.str();
      /////////////////////////////////////

      unsigned int indexLabelValid = 0, indexLabelPred = 0;

      for( itMapOfIndicesValid = mapOfIndicesValid.begin();
           itMapOfIndicesValid != mapOfIndicesValid.end(); ++itMapOfIndicesValid )
        {
        indexLabelPred = 0;

        for( itMapOfIndicesPred = mapOfIndicesValid.begin();
             itMapOfIndicesPred != mapOfIndicesValid.end(); ++itMapOfIndicesPred )
          {
          // Writing the confusion matrix (sorted in otbConfusionMatrixCalculator) in the output file
          outFile << confusionMatrix( indexLabelValid, indexLabelPred );
          if( indexLabelPred < ( nbClassesPred - 1 ) )
            {
            outFile << separatorChar;
            }
          else
            {
            outFile << std::endl;
            }
          ++indexLabelPred;
          }

        ++indexLabelValid;
        }

      outFile.close();
      }
  }

  /**
   * Display the log of the confusion matrix computed with
   * \param confMatCalc the input confusion matrix to display
   */
  void LogConfusionMatrix(ConfusionMatrixCalculatorType *confMatCalc)
  {
    ConfusionMatrixCalculatorType::ConfusionMatrixType matrix = confMatCalc->GetConfusionMatrix();

    // Compute minimal width
    size_t minwidth = 0;

    for( unsigned int i = 0; i < matrix.Rows(); i++ )
      {
      for( unsigned int j = 0; j < matrix.Cols(); j++ )
        {
        std::ostringstream os;
        os << matrix( i, j );
        size_t size = os.str().size();

        if( size > minwidth )
          {
          minwidth = size;
          }
        }
      }

    MapOfIndicesType mapOfIndices = confMatCalc->GetMapOfIndices();

    MapOfIndicesType::const_iterator it = mapOfIndices.begin();
    MapOfIndicesType::const_iterator end = mapOfIndices.end();

    for( ; it != end; ++it )
      {
      std::ostringstream os;
      os << "[" << it->second << "]";

      size_t size = os.str().size();
      if( size > minwidth )
        {
        minwidth = size;
        }
      }

    // Generate matrix string, with 'minwidth' as size specifier
    std::ostringstream os;

    // Header line
    for( size_t i = 0; i < minwidth; ++i )
      os << " ";
    os << " ";

    it = mapOfIndices.begin();
    end = mapOfIndices.end();
    for( ; it != end; ++it )
      {
      os << "[" << it->second << "]" << " ";
      }

    os << std::endl;

    // Each line of confusion matrix
    for( unsigned int i = 0; i < matrix.Rows(); i++ )
      {
      ConfusionMatrixCalculatorType::ClassLabelType label = mapOfIndices[i];
      os << "[" << std::setw( minwidth - 2 ) << label << "]" << " ";
      for( unsigned int j = 0; j < matrix.Cols(); j++ )
        {
        os << std::setw( minwidth ) << matrix( i, j ) << " ";
        }
      os << std::endl;
      }

    otbAppLogINFO( "Confusion matrix (rows = reference labels, columns = produced labels):\n" << os.str() );
  }

};
}
}

// Invoke OTB's application export macro for TrainVectorClassifier
// (presumably generates the entry point used to load this application —
// confirm against the macro definition).
OTB_APPLICATION_EXPORT( otb::Wrapper::TrainVectorClassifier )