/*=========================================================================

  Program:   ORFEO Toolbox
  Language:  C++
  Date:      $Date$
  Version:   $Revision$


  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
  See OTBCopyright.txt for details.


     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notices for more information.

=========================================================================*/
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"

#include "otbOGRDataSourceToLabelImageFilter.h"
#include "itkImageRegionConstIterator.h"
#include "otbOGRDataSourceWrapper.h"
#include "itkImageRegionSplitter.h"
#include "otbStreamingTraits.h"

#include <fstream>
#include <string>
#include <vector>

namespace otb
{
namespace Wrapper
{

/** \class ComputeConfusionMatrix
 *  OTB application that accumulates a confusion matrix between a produced
 *  classification ("in") and a ground truth given either as a raster
 *  ("ref.raster.in") or as a vector file rasterized on the fly
 *  ("ref.vector.in" / "ref.vector.field"). The input is processed in
 *  streamed pieces sized from the "ram" parameter, and the resulting matrix
 *  is written as tab-separated text to "out".
 */
class ComputeConfusionMatrix : public Application
{
public:
  /** Standard class typedefs. */
  typedef ComputeConfusionMatrix        Self;
  typedef Application                   Superclass;
  typedef itk::SmartPointer<Self>       Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;

  /** Standard macro */
  itkNewMacro(Self);

  itkTypeMacro(ComputeConfusionMatrix, otb::Application);

  typedef itk::ImageRegionConstIterator<Int32ImageType> ImageIteratorType;

  typedef otb::OGRDataSourceToLabelImageFilter<Int32ImageType> RasterizeFilterType;

  typedef otb::StreamingTraits<Int32ImageType> StreamingTraitsType;

  typedef itk::ImageRegionSplitter<2> SplitterType;

  typedef Int32ImageType::RegionType RegionType;

private:
  /** Declares the application parameters and their documentation. */
  void DoInit()
  {
    SetName("ComputeConfusionMatrix");
    SetDescription("Computes the confusion matrix of a classification");

    // Documentation
    SetDocName("Compute Confusion Matrix Application");
    SetDocLongDescription("This application computes the confusion matrix of a classification relatively to a ground truth. "
                          "This ground truth can be given as a raster or a vector data.");
    SetDocLimitations("None");
    SetDocAuthors("OTB-Team");
    SetDocSeeAlso(" ");

    //AddDocTag(Tags::Classification);

    AddParameter(ParameterType_InputImage, "in", "Input Image");
    SetParameterDescription( "in", "The input classification image." );

    // NOTE: despite the "(csv format)" wording, WriteMatrix() separates
    // values with tabulations.
    AddParameter(ParameterType_OutputFilename, "out", "Matrix output");
    SetParameterDescription("out", "Filename to store the output matrix (csv format)");

    AddParameter(ParameterType_Choice,"ref","Ground truth");
    SetParameterDescription("ref","Choice of ground truth format");

    AddChoice("ref.raster","Ground truth as a raster image");
    AddChoice("ref.vector","Ground truth as a vector data file");

    AddParameter(ParameterType_InputImage,"ref.raster.in","Input reference image");
    SetParameterDescription("ref.raster.in","Input image containing the ground truth labels");

    AddParameter(ParameterType_InputFilename,"ref.vector.in","Input reference vector data");
    SetParameterDescription("ref.vector.in", "Input vector data of the ground truth");

    AddParameter(ParameterType_String,"ref.vector.field","Field name");
    SetParameterDescription("ref.vector.field","Field name containing the label values");
    SetParameterString("ref.vector.field","dn");
    MandatoryOff("ref.vector.field");
    DisableParameter("ref.vector.field");

    AddParameter(ParameterType_Int,"labels","Number of labels");
    SetParameterDescription("labels","Number of labels in the classification. "
                                     "The label values shall be contiguous and start from 1.");
    SetDefaultParameterInt("labels",2);
    MandatoryOff("labels");
    DisableParameter("labels");

    AddParameter(ParameterType_Int,"nodata","Value for nodata pixels");
    SetParameterDescription("nodata","This value will be used to discard pixels from the ground truth");
    SetDefaultParameterInt("nodata",0);
    MandatoryOff("nodata");
    DisableParameter("nodata");

    AddRAMParameter();

    // Doc example parameter settings
    SetDocExampleParameterValue("in", "clLabeledImageQB1.tif");
    SetDocExampleParameterValue("out", "confusion.txt");
    SetDocExampleParameterValue("ref", "vector");
    SetDocExampleParameterValue("ref.vector.in","VectorData_QB1_bis.shp");
    SetDocExampleParameterValue("ref.vector.field","Class");
    SetDocExampleParameterValue("labels","4");
  }

  void DoUpdateParameters()
  {
    // Nothing to do here : all parameters are independent
  }

  /** Returns true when label is a valid class value, i.e. lies in [1, nbClasses]. */
  static bool IsLabelInRange(Int32ImageType::PixelType label, unsigned int nbClasses)
  {
    // The positivity test comes first so that the unsigned conversion below
    // never sees a negative value.
    return label > 0 && static_cast<unsigned int>(label) <= nbClasses;
  }

  /** Writes m_Matrix to fileName as tab-separated text.
   *  Line j holds the counts for reference label j+1; column i holds the
   *  counts for produced label i+1.
   *  Throws itk::ExceptionObject if the file cannot be opened or written. */
  void WriteMatrix(const std::string& fileName) const
  {
    std::ofstream outFile(fileName.c_str());
    if (!outFile)
      {
      itkExceptionMacro(<< "Unable to open output file '" << fileName << "' for writing.");
      }

    const unsigned int nbClasses = static_cast<unsigned int>(m_Matrix.size());
    for (unsigned int refIdx = 0; refIdx < nbClasses; ++refIdx)
      {
      for (unsigned int prodIdx = 0; prodIdx < nbClasses; ++prodIdx)
        {
        outFile << m_Matrix[prodIdx][refIdx];
        // Tab between columns, newline after the last one of each line.
        outFile << ((prodIdx + 1 < nbClasses) ? '\t' : '\n');
        }
      }

    outFile.close();
    if (!outFile)
      {
      itkExceptionMacro(<< "Error while writing output file '" << fileName << "'.");
      }
  }

  /** Streams through the input and the reference, accumulates the confusion
   *  matrix and writes it to the "out" file.
   *  Throws itk::ExceptionObject on invalid "labels", on a raster reference
   *  that does not cover the input, or on output I/O failure. */
  void DoExecute()
  {
    Int32ImageType* input = this->GetParameterImage<Int32ImageType>("in");
    input->UpdateOutputInformation();

    const int nodata = this->GetParameterInt("nodata");

    // Validate before allocating: a non-positive value would otherwise be
    // converted to unsigned and yield an empty or gigantic matrix.
    const int nbLabelsParam = this->GetParameterInt("labels");
    if (nbLabelsParam < 1)
      {
      itkExceptionMacro(<< "Parameter 'labels' must be at least 1 (got " << nbLabelsParam << ").");
      }
    const unsigned int nbClasses = static_cast<unsigned int>(nbLabelsParam);

    // Init Conf Matrix: m_Matrix[produced label - 1][reference label - 1]
    m_Matrix.assign(nbClasses, std::vector<unsigned long>(nbClasses, 0UL));

    Int32ImageType::Pointer reference;
    // Declared at function scope: "reference" may be the output of this
    // rasterizer, which is driven by UpdateOutputData() in the loop below.
    otb::ogr::DataSource::Pointer ogrRef;
    RasterizeFilterType::Pointer rasterizeReference = RasterizeFilterType::New();

    if (GetParameterString("ref") == "raster")
      {
      reference = this->GetParameterImage<Int32ImageType>("ref.raster.in");
      reference->UpdateOutputInformation();

      // Pixels are paired by index, so every streamed input region must also
      // be a valid region of the reference image.
      if (!reference->GetLargestPossibleRegion().IsInside(input->GetLargestPossibleRegion()))
        {
        itkExceptionMacro(<< "The reference image extent does not cover the input image extent.");
        }
      }
    else
      {
      ogrRef = otb::ogr::DataSource::New(GetParameterString("ref.vector.in"), otb::ogr::DataSource::Modes::Read);
      const std::string field = this->GetParameterString("ref.vector.field");

      // Burn the vector labels on the input image grid; uncovered pixels get
      // the nodata value so they are discarded below.
      rasterizeReference->AddOGRDataSource(ogrRef);
      rasterizeReference->SetOutputParametersFromImage(input);
      rasterizeReference->SetBackgroundValue(nodata);
      rasterizeReference->SetBurnAttribute(field.c_str());

      reference = rasterizeReference->GetOutput();
      reference->UpdateOutputInformation();
      }

    // Prepare local streaming
    SplitterType::Pointer splitter = SplitterType::New();
    const RegionType largestRegion = input->GetLargestPossibleRegion();
    const unsigned int numberOfStreamDivisions =
      StreamingTraitsType::CalculateNumberOfStreamDivisions(input,
                                                           largestRegion,
                                                           splitter,
                                                           otb::SET_BUFFER_MEMORY_SIZE,
                                                           0,
                                                           1048576*GetParameterInt("ram"),
                                                           0);

    otbAppLogINFO("Number of stream divisions : "<<numberOfStreamDivisions);

    for (unsigned int index = 0; index < numberOfStreamDivisions; ++index)
      {
      const RegionType streamRegion = splitter->GetSplit(index, numberOfStreamDivisions, largestRegion);

      input->SetRequestedRegion(streamRegion);
      input->PropagateRequestedRegion();
      input->UpdateOutputData();

      reference->SetRequestedRegion(streamRegion);
      reference->PropagateRequestedRegion();
      reference->UpdateOutputData();

      ImageIteratorType itInput(input, streamRegion);
      ImageIteratorType itRef(reference, streamRegion);

      for (itInput.GoToBegin(), itRef.GoToBegin(); !itInput.IsAtEnd(); ++itInput, ++itRef)
        {
        const Int32ImageType::PixelType refLabel = itRef.Get();
        if (refLabel == nodata)
          {
          continue;
          }

        // Pixels whose labels fall outside [1, nbClasses] are silently ignored.
        const Int32ImageType::PixelType prodLabel = itInput.Get();
        if (IsLabelInRange(refLabel, nbClasses) && IsLabelInRange(prodLabel, nbClasses))
          {
          ++m_Matrix[prodLabel - 1][refLabel - 1];
          }
        }
      }

    WriteMatrix(this->GetParameterString("out"));
  }

  /** Confusion counts indexed as [produced label - 1][reference label - 1]. */
  std::vector<std::vector<unsigned long> > m_Matrix;
};

}
}

OTB_APPLICATION_EXPORT(otb::Wrapper::ComputeConfusionMatrix)