Skip to content
Snippets Groups Projects
Commit b30c6050 authored by Jonathan Guinet's avatar Jonathan Guinet
Browse files

MRG

parents c9e79ef6 f1570e6f
No related branches found
No related tags found
No related merge requests found
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"
#include <fstream>
//Image
#include "otbListSampleGenerator.h"
// ListSample
#include "itkListSample.h"
#include "itkVariableLengthVector.h"
#include "itkFixedArray.h"
// SVM estimator
#include "otbSVMSampleListModelEstimator.h"
// Statistic XML Reader
#include "otbStatisticsXMLFileReader.h"
// Validation
#include "otbSVMClassifier.h"
#include "otbConfusionMatrixCalculator.h"
// Normalize the samples
#include "otbShiftScaleSampleListFilter.h"
// List sample concatenation
#include "otbConcatenateSampleListFilter.h"
// Classification filter
#include "otbSVMImageClassificationFilter.h"
// Extract a ROI of the vectordata
#include "otbVectorDataIntoImageProjectionFilter.h"
namespace otb
{
namespace Wrapper
{
class ValidateSVMImagesClassifier: public Application
{
public:
/** Standard class typedefs. */
typedef ValidateSVMImagesClassifier Self;
typedef Application Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Standard macro */
itkNewMacro(Self)
;
itkTypeMacro(ValidateSVMImagesClassifier, otb::Application)
;
typedef otb::Image<FloatVectorImageType::InternalPixelType, 2> ImageReaderType;
typedef FloatVectorImageType::PixelType PixelType;
typedef FloatVectorImageType VectorImageType;
typedef FloatImageType ImageType;
typedef Int32ImageType LabeledImageType;
// Training vectordata
typedef itk::VariableLengthVector<ImageType::PixelType> MeasurementType;
// SampleList manipulation
typedef otb::ListSampleGenerator<VectorImageType, VectorDataType> ListSampleGeneratorType;
typedef ListSampleGeneratorType::ListSampleType ListSampleType;
typedef ListSampleGeneratorType::LabelType LabelType;
typedef ListSampleGeneratorType::ListLabelType LabelListSampleType;
typedef otb::Statistics::ConcatenateSampleListFilter<ListSampleType> ConcatenateListSampleFilterType;
typedef otb::Statistics::ConcatenateSampleListFilter<LabelListSampleType> ConcatenateLabelListSampleFilterType;
// Statistic XML file Reader
typedef otb::StatisticsXMLFileReader<MeasurementType> StatisticsReader;
// Enhance List Sample
typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType;
/// Classification typedefs
typedef otb::SVMImageClassificationFilter<VectorImageType, LabeledImageType> ClassificationFilterType;
typedef ClassificationFilterType::Pointer ClassificationFilterPointerType;
typedef ClassificationFilterType::ModelType ModelType;
typedef ModelType::Pointer ModelPointerType;
typedef otb::Functor::VariableLengthVectorToMeasurementVectorFunctor<MeasurementType> MeasurementVectorFunctorType;
typedef otb::SVMSampleListModelEstimator<ListSampleType, LabelListSampleType, MeasurementVectorFunctorType>
SVMEstimatorType;
// Estimate performance on validation sample
typedef otb::SVMClassifier<ListSampleType, LabelType::ValueType> ClassifierType;
typedef otb::ConfusionMatrixCalculator<LabelListSampleType, LabelListSampleType> ConfusionMatrixCalculatorType;
typedef ClassifierType::OutputType ClassifierOutputType;
// Extract ROI and Project vectorData
typedef otb::VectorDataIntoImageProjectionFilter<VectorDataType, VectorImageType> VectorDataReprojectionType;
private:
ValidateSVMImagesClassifier()
{
SetName("ValidateSVMImagesClassifier");
SetDescription("Perform SVM validation from multiple input images and multiple vector data.");
}
virtual ~ValidateSVMImagesClassifier()
{
}
void DoCreateParameters()
{
AddParameter(ParameterType_InputImageList, "il", "Input Image List");
AddParameter(ParameterType_InputVectorDataList, "vd", "Vector Data of sample used to validate the estimator");
AddParameter(ParameterType_Filename, "dem", "A DEM repository");
MandatoryOff("dem");
AddParameter(ParameterType_Filename, "imstat", "XML file containing mean and standard deviation of input images.");
MandatoryOff("imstat");
AddParameter(ParameterType_Filename, "svm", "SVM model to validate its performances.");
AddParameter(ParameterType_Filename, "out", "File which will contain the performance of the SVM model.");
}
void DoUpdateParameters()
{
// Nothing to do here : all parameters are independent
}
void DoExecute()
{
GetLogger()->Debug("Entering DoExecute\n");
//Create training and validation for list samples and label list samples
ConcatenateLabelListSampleFilterType::Pointer
concatenateTrainingLabels = ConcatenateLabelListSampleFilterType::New();
ConcatenateListSampleFilterType::Pointer concatenateTrainingSamples = ConcatenateListSampleFilterType::New();
ConcatenateLabelListSampleFilterType::Pointer
concatenateValidationLabels = ConcatenateLabelListSampleFilterType::New();
ConcatenateListSampleFilterType::Pointer concatenateValidationSamples = ConcatenateListSampleFilterType::New();
MeasurementType meanMeasurementVector;
MeasurementType stddevMeasurementVector;
//--------------------------
// Load measurements from images
unsigned int nbBands = 0;
//Iterate over all input images
FloatVectorImageListType* imageList = GetParameterImageList("il");
VectorDataListType* vectorDataList = GetParameterVectorDataList("vd");
//Iterate over all input images
for (unsigned int imgIndex = 0; imgIndex < imageList->Size(); ++imgIndex)
{
FloatVectorImageType::Pointer image = imageList->GetNthElement(imgIndex);
image->UpdateOutputInformation();
if (imgIndex == 0)
{
nbBands = image->GetNumberOfComponentsPerPixel();
}
// read the Vectordata
VectorDataType::Pointer vectorData = vectorDataList->GetNthElement(imgIndex);
vectorData->Update();
VectorDataReprojectionType::Pointer vdreproj = VectorDataReprojectionType::New();
vdreproj->SetInputImage(image);
vdreproj->SetInput(vectorData);
vdreproj->SetUseOutputSpacingAndOriginFromImage(false);
// Configure DEM directory
if (HasUserValue("dem"))
{
vdreproj->SetDEMDirectory(GetParameterString("dem"));
}
else
{
if (otb::ConfigurationFile::GetInstance()->IsValid())
{
vdreproj->SetDEMDirectory(otb::ConfigurationFile::GetInstance()->GetDEMDirectory());
}
}
vdreproj->Update();
//Sample list generator
ListSampleGeneratorType::Pointer sampleGenerator = ListSampleGeneratorType::New();
//Set inputs of the sample generator
//TODO the ListSampleGenerator perform UpdateOutputData over the input image (need a persistent implementation)
sampleGenerator->SetInput(image);
sampleGenerator->SetInputVectorData(vdreproj->GetOutput());
sampleGenerator->SetValidationTrainingProportion(1.0); // All in validation
sampleGenerator->SetClassKey("Class");
sampleGenerator->Update();
//Concatenate training and validation samples from the image
concatenateValidationLabels->AddInput(sampleGenerator->GetValidationListLabel());
concatenateValidationSamples->AddInput(sampleGenerator->GetValidationListSample());
}
// Update
concatenateValidationSamples->Update();
concatenateValidationLabels->Update();
if (HasValue("imstat"))
{
StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
statisticsReader->SetFileName(GetParameterString("imstat"));
meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
}
else
{
meanMeasurementVector.SetSize(nbBands);
meanMeasurementVector.Fill(0.);
stddevMeasurementVector.SetSize(nbBands);
stddevMeasurementVector.Fill(1.);
}
ShiftScaleFilterType::Pointer validationShiftScaleFilter = ShiftScaleFilterType::New();
validationShiftScaleFilter->SetInput(concatenateValidationSamples->GetOutput());
validationShiftScaleFilter->SetShifts(meanMeasurementVector);
validationShiftScaleFilter->SetScales(stddevMeasurementVector);
validationShiftScaleFilter->Update();
//--------------------------
// split the data set into training/validation set
ListSampleType::Pointer validationListSample = validationShiftScaleFilter->GetOutputSampleList();
LabelListSampleType::Pointer validationLabeledListSample = concatenateValidationLabels->GetOutputSampleList();
otbAppLogINFO("Size of validation set: " << validationListSample->Size());
otbAppLogINFO("Size of labeled validation set: " << validationLabeledListSample->Size());
//--------------------------
// Load svm model
ModelPointerType modelSVM = ModelType::New();
modelSVM->LoadModel(GetParameterString("svm").c_str());
//--------------------------
// Performances estimation
ClassifierType::Pointer validationClassifier = ClassifierType::New();
validationClassifier->SetSample(validationListSample);
validationClassifier->SetNumberOfClasses(modelSVM->GetNumberOfClasses());
validationClassifier->SetModel(modelSVM);
validationClassifier->Update();
// Estimate performances
ClassifierOutputType::ConstIterator it = validationClassifier->GetOutput()->Begin();
ClassifierOutputType::ConstIterator itEnd = validationClassifier->GetOutput()->End();
LabelListSampleType::Pointer classifierListLabel = LabelListSampleType::New();
while (it != itEnd)
{
// Due to a bug in SVMClassifier, outlier in one-class SVM are labeled with unsigned int max
classifierListLabel->PushBack(
it.GetClassLabel() == itk::NumericTraits<unsigned int>::max() ? 2
: it.GetClassLabel());
++it;
}
ConfusionMatrixCalculatorType::Pointer confMatCalc = ConfusionMatrixCalculatorType::New();
confMatCalc->SetReferenceLabels(validationLabeledListSample);
confMatCalc->SetProducedLabels(classifierListLabel);
confMatCalc->Update();
otbAppLogINFO("*** SVM training performances ***\n" <<"Confusion matrix:\n" << confMatCalc->GetConfusionMatrix() << std::endl);
for (unsigned int itClasses = 0; itClasses < modelSVM->GetNumberOfClasses(); itClasses++)
{
otbAppLogINFO("Precision of class [" << itClasses << "] vs all: " << confMatCalc->GetPrecisions()[itClasses] << std::endl);
otbAppLogINFO("Recall of class [" << itClasses << "] vs all: " << confMatCalc->GetRecalls()[itClasses] << std::endl);
otbAppLogINFO("F-score of class [" << itClasses << "] vs all: " << confMatCalc->GetFScores()[itClasses] << "\n" << std::endl);
}
otbAppLogINFO("Global performance, Kappa index: " << confMatCalc->GetKappaIndex() << std::endl);
//--------------------------
// Save output in a ascii file (if needed)
if (IsParameterEnabled("out"))
{
std::ofstream file;
file.open(GetParameterString("out").c_str());
file << "Precision of the different class: " << confMatCalc->GetPrecisions() << std::endl;
file << "Recall of the different class: " << confMatCalc->GetRecalls() << std::endl;
file << "F-score of the different class: " << confMatCalc->GetFScores() << std::endl;
file << "Kappa index: " << confMatCalc->GetKappaIndex() << std::endl;
file.close();
}
}
};
} // end of namespace Wrapper
} // end of namespace otb
OTB_APPLICATION_EXPORT(otb::Wrapper::ValidateSVMImagesClassifier)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment