Skip to content
Snippets Groups Projects
Commit ae2828ea authored by Charles Peyrega's avatar Charles Peyrega
Browse files

ENH: Removal of the ValidateImagesClassifier OTB application and of the corresponding tests

parent 89bddd10
No related branches found
No related tags found
No related merge requests found
......@@ -33,7 +33,4 @@ IF(OTB_USE_OPENCV)
SOURCES otbImageClassifier.cxx
LINK_LIBRARIES OTBIO;OTBCommon;OTBBasicFilters;OTBFeatureExtraction;OTBLearning;OTBMachineLearning)
OTB_CREATE_APPLICATION(NAME ValidateImagesClassifier
SOURCES otbValidateImagesClassifier.cxx
LINK_LIBRARIES OTBIO;OTBCommon;OTBBasicFilters;OTBFeatureExtraction;OTBLearning;OTBMachineLearning)
ENDIF()
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"
#include <fstream>
// ListSample
#include "itkListSample.h"
#include "itkVariableLengthVector.h"
#include "itkFixedArray.h"
#include "otbListSampleGenerator.h"
// Model estimator
#include "otbMachineLearningModelFactory.h"
#include "otbMachineLearningModel.h"
// Statistic XML Reader
#include "otbStatisticsXMLFileReader.h"
// Validation
#include "otbConfusionMatrixCalculator.h"
// Normalize the samples
#include "otbShiftScaleSampleListFilter.h"
// List sample concatenation
#include "otbConcatenateSampleListFilter.h"
// Extract a ROI of the vectordata
#include "otbVectorDataIntoImageProjectionFilter.h"
// Elevation handler
#include "otbWrapperElevationParametersHandler.h"
namespace otb
{
namespace Wrapper
{
class ValidateImagesClassifier: public Application
{
public:
/** Standard class typedefs. */
typedef ValidateImagesClassifier Self;
typedef Application Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Standard macro */
itkNewMacro(Self);
itkTypeMacro(ValidateImagesClassifier, otb::Application);
typedef otb::Image<FloatVectorImageType::InternalPixelType, 2> ImageReaderType;
typedef FloatVectorImageType::PixelType PixelType;
typedef FloatVectorImageType VectorImageType;
typedef FloatImageType ImageType;
typedef Int32ImageType LabeledImageType;
// Training vectordata
typedef itk::VariableLengthVector<ImageType::PixelType> MeasurementType;
// SampleList manipulation
typedef otb::ListSampleGenerator<VectorImageType, VectorDataType> ListSampleGeneratorType;
typedef ListSampleGeneratorType::ListSampleType ListSampleType;
typedef ListSampleGeneratorType::LabelType LabelType;
typedef ListSampleGeneratorType::ListLabelType LabelListSampleType;
typedef otb::Statistics::ConcatenateSampleListFilter<ListSampleType> ConcatenateListSampleFilterType;
typedef otb::Statistics::ConcatenateSampleListFilter<LabelListSampleType> ConcatenateLabelListSampleFilterType;
// Statistic XML file Reader
typedef otb::StatisticsXMLFileReader<MeasurementType> StatisticsReader;
// Enhance List Sample
typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType;
/// Classification typedefs
typedef otb::MachineLearningModelFactory<FloatVectorImageType::InternalPixelType, ListSampleGeneratorType::ClassLabelType> MachineLearningModelFactoryType;
typedef otb::MachineLearningModel<FloatVectorImageType::InternalPixelType, ListSampleGeneratorType::ClassLabelType> ModelType;
// Estimate performance on validation sample
typedef otb::ConfusionMatrixCalculator<LabelListSampleType, LabelListSampleType> ConfusionMatrixCalculatorType;
// Extract ROI and Project vectorData
typedef otb::VectorDataIntoImageProjectionFilter<VectorDataType, VectorImageType> VectorDataReprojectionType;
private:
void DoInit()
{
SetName("ValidateImagesClassifier");
SetDescription("Estimate the performance of the model with a set of images and validation samples.");
SetDocName("Validate Images Classifier");
SetDocLongDescription("Estimate the performance of the model obtained by the TrainImagesClassifier with a new set of images and validation samples.\n The application asks for images statistics as input (XML file generated with the ComputeImagesStatistics application) and a SVM model (text file) generated with the TrainSVMImagesClassifier application.\n It will compute the global confusion matrix, kappa index and also the precision, recall and F-score of each class. In the validation process, the confusion matrix is organized the following way: rows = reference labels, columns = produced labels.");
SetDocLimitations("None");
SetDocAuthors("OTB-Team");
SetDocSeeAlso(" ");
AddDocTag(Tags::Learning);
AddParameter(ParameterType_InputImageList, "il", "Input Image List");
SetParameterDescription("il", "Input image list filename.");
AddParameter(ParameterType_InputVectorDataList, "vd", "Vector Data List");
SetParameterDescription("vd", "List of vector data to select validation samples.");
AddParameter(ParameterType_String, "vfn", "Name of the discrimination field");
SetParameterDescription("vfn", "Name of the field used to discriminate class in the vector data files.");
SetParameterString("vfn", "Class");
MandatoryOff("vfn");
AddParameter(ParameterType_InputFilename, "imstat", "XML image statistics file");
MandatoryOff("imstat");
SetParameterDescription("imstat", "Filename of an XML file containing mean and standard deviation of input images.");
// Elevation
ElevationParametersHandler::AddElevationParameters(this, "elev");
AddParameter(ParameterType_OutputFilename, "out", "Output filename");
SetParameterDescription("out", "Output file, which contains the performances of the SVM model.");
MandatoryOff("out");
AddParameter(ParameterType_InputFilename, "model", "Model filename");
SetParameterDescription("model",
"Input model to validate (given by TrainImagesClassification for instance).");
AddRANDParameter();
// Doc example parameter settings
SetDocExampleParameterValue("il", "QB_1_ortho.tif");
SetDocExampleParameterValue("vd", "VectorData_QB1.shp");
SetDocExampleParameterValue("imstat", "EstimateImageStatisticsQB1.xml");
SetDocExampleParameterValue("model", "clsvmModelQB1.svm");
SetDocExampleParameterValue("out", "PerformanceEstimationQB1.txt");
}
void DoUpdateParameters()
{
// Nothing to do here : all parameters are independent
}
std::string LogConfusionMatrix(ConfusionMatrixCalculatorType* confMatCalc)
{
ConfusionMatrixCalculatorType::ConfusionMatrixType matrix = confMatCalc->GetConfusionMatrix();
// Compute minimal width
size_t minwidth = 0;
for (unsigned int i = 0; i < matrix.Rows(); i++)
{
for (unsigned int j = 0; j < matrix.Cols(); j++)
{
std::ostringstream os;
os << matrix(i,j);
size_t size = os.str().size();
if (size > minwidth)
{
minwidth = size;
}
}
}
typedef std::map<int, ConfusionMatrixCalculatorType::ClassLabelType> MapOfIndicesType;
MapOfIndicesType mapOfIndices = confMatCalc->GetMapOfIndices();
MapOfIndicesType::const_iterator it = mapOfIndices.begin();
MapOfIndicesType::const_iterator end = mapOfIndices.end();
for(; it != end; ++it)
{
std::ostringstream os;
os << "[" << it->second << "]";
size_t size = os.str().size();
if (size > minwidth)
{
minwidth = size;
}
}
// Generate matrix string, with 'minwidth' as size specifier
std::ostringstream os;
// Header line
for (size_t i = 0; i < minwidth; ++i)
os << " ";
os << " ";
it = mapOfIndices.begin();
end = mapOfIndices.end();
for(; it != end; ++it)
{
os << "[" << it->second << "]" << " ";
}
os << std::endl;
// Each line of confusion matrix
for (unsigned int i = 0; i < matrix.Rows(); i++)
{
ConfusionMatrixCalculatorType::ClassLabelType label = mapOfIndices[i];
os << "[" << std::setw(minwidth - 2) << label << "]" << " ";
for (unsigned int j = 0; j < matrix.Cols(); j++)
{
os << std::setw(minwidth) << matrix(i,j) << " ";
}
os << std::endl;
}
otbAppLogINFO("Confusion matrix (rows = reference labels, columns = produced labels):\n" << os.str());
return os.str();
}
void DoExecute()
{
GetLogger()->Debug("Entering DoExecute\n");
//Create training and validation for list samples and label list samples
ConcatenateLabelListSampleFilterType::Pointer
concatenateTrainingLabels = ConcatenateLabelListSampleFilterType::New();
ConcatenateListSampleFilterType::Pointer concatenateTrainingSamples = ConcatenateListSampleFilterType::New();
ConcatenateLabelListSampleFilterType::Pointer
concatenateValidationLabels = ConcatenateLabelListSampleFilterType::New();
ConcatenateListSampleFilterType::Pointer concatenateValidationSamples = ConcatenateListSampleFilterType::New();
MeasurementType meanMeasurementVector;
MeasurementType stddevMeasurementVector;
//--------------------------
// Load measurements from images
unsigned int nbBands = 0;
//Iterate over all input images
FloatVectorImageListType* imageList = GetParameterImageList("il");
VectorDataListType* vectorDataList = GetParameterVectorDataList("vd");
//Iterate over all input images
for (unsigned int imgIndex = 0; imgIndex < imageList->Size(); ++imgIndex)
{
FloatVectorImageType::Pointer image = imageList->GetNthElement(imgIndex);
image->UpdateOutputInformation();
if (imgIndex == 0)
{
nbBands = image->GetNumberOfComponentsPerPixel();
}
// read the Vectordata
VectorDataType::Pointer vectorData = vectorDataList->GetNthElement(imgIndex);
vectorData->Update();
VectorDataReprojectionType::Pointer vdreproj = VectorDataReprojectionType::New();
vdreproj->SetInputImage(image);
vdreproj->SetInput(vectorData);
vdreproj->SetUseOutputSpacingAndOriginFromImage(false);
// Setup the DEM Handler
otb::Wrapper::ElevationParametersHandler::SetupDEMHandlerFromElevationParameters(this,"elev");
vdreproj->Update();
//Sample list generator
ListSampleGeneratorType::Pointer sampleGenerator = ListSampleGeneratorType::New();
//Set inputs of the sample generator
//TODO the ListSampleGenerator perform UpdateOutputData over the input image (need a persistent implementation)
sampleGenerator->SetInput(image);
sampleGenerator->SetInputVectorData(vdreproj->GetOutput());
sampleGenerator->SetValidationTrainingProportion(1.0); // All in validation
sampleGenerator->SetClassKey(GetParameterString("vfn"));
sampleGenerator->Update();
//Concatenate training and validation samples from the image
concatenateValidationLabels->AddInput(sampleGenerator->GetValidationListLabel());
concatenateValidationSamples->AddInput(sampleGenerator->GetValidationListSample());
}
// Update
concatenateValidationSamples->Update();
concatenateValidationLabels->Update();
if (IsParameterEnabled("imstat"))
{
StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
statisticsReader->SetFileName(GetParameterString("imstat"));
meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
}
else
{
meanMeasurementVector.SetSize(nbBands);
meanMeasurementVector.Fill(0.);
stddevMeasurementVector.SetSize(nbBands);
stddevMeasurementVector.Fill(1.);
}
ShiftScaleFilterType::Pointer validationShiftScaleFilter = ShiftScaleFilterType::New();
validationShiftScaleFilter->SetInput(concatenateValidationSamples->GetOutput());
validationShiftScaleFilter->SetShifts(meanMeasurementVector);
validationShiftScaleFilter->SetScales(stddevMeasurementVector);
validationShiftScaleFilter->Update();
//--------------------------
// split the data set into training/validation set
ListSampleType::Pointer validationListSample = validationShiftScaleFilter->GetOutputSampleList();
LabelListSampleType::Pointer validationLabeledListSample = concatenateValidationLabels->GetOutputSampleList();
otbAppLogINFO("Size of validation set: " << validationListSample->Size());
otbAppLogINFO("Size of labeled validation set: " << validationLabeledListSample->Size());
//--------------------------
// Load svm model
ModelType::Pointer model = MachineLearningModelFactoryType::CreateMachineLearningModel(GetParameterString("model"),
MachineLearningModelFactoryType::ReadMode);
model->Load(GetParameterString("model"));
//--------------------------
// Performances estimation
LabelListSampleType::Pointer classifierListLabel = LabelListSampleType::New();
model->SetInputListSample(validationListSample);
model->SetTargetListSample(classifierListLabel);
model->PredictAll();
// Estimate performances
ConfusionMatrixCalculatorType::Pointer confMatCalc = ConfusionMatrixCalculatorType::New();
confMatCalc->SetReferenceLabels(validationLabeledListSample);
confMatCalc->SetProducedLabels(classifierListLabel);
confMatCalc->Compute();
otbAppLogINFO("Model training performances");
std::string confMatString;
confMatString = LogConfusionMatrix(confMatCalc);
for (unsigned int itClasses = 0; itClasses < confMatCalc->GetNumberOfClasses(); itClasses++)
{
otbAppLogINFO("Precision of class [" << itClasses << "] vs all: " << confMatCalc->GetPrecisions()[itClasses] << std::endl);
otbAppLogINFO("Recall of class [" << itClasses << "] vs all: " << confMatCalc->GetRecalls()[itClasses] << std::endl);
otbAppLogINFO("F-score of class [" << itClasses << "] vs all: " << confMatCalc->GetFScores()[itClasses] << "\n" << std::endl);
}
otbAppLogINFO("Global performance, Kappa index: " << confMatCalc->GetKappaIndex() << std::endl);
otbAppLogINFO("Global performance, Overall accuracy: " << confMatCalc->GetOverallAccuracy() << std::endl);
//--------------------------
// Save output in a ascii file (if needed)
if (IsParameterEnabled("out"))
{
std::ofstream file;
file.open(GetParameterString("out").c_str());
file << "Confusion matrix (rows = reference labels, columns = produced labels):\n" << std::endl;
file << confMatString << std::endl;
file << "Precision of the different class: " << confMatCalc->GetPrecisions() << std::endl;
file << "Recall of the different class: " << confMatCalc->GetRecalls() << std::endl;
file << "F-score of the different class: " << confMatCalc->GetFScores() << std::endl;
file << "Kappa index: " << confMatCalc->GetKappaIndex() << std::endl;
file << "Overall accuracy index: " << confMatCalc->GetOverallAccuracy() << std::endl;
file.close();
}
}
};
} // end of namespace Wrapper
} // end of namespace otb
OTB_APPLICATION_EXPORT(otb::Wrapper::ValidateImagesClassifier)
......@@ -47,21 +47,6 @@ IF(OTB_USE_OPENCV)
SET_TESTS_PROPERTIES(apTvClTrainSVMImagesClassifierQB1_allOpt PROPERTIES DEPENDS apTvClComputeImagesStatisticsQB1)
OTB_TEST_APPLICATION(NAME apTvClValidateSVMImagesClassifierQB1
APP ValidateImagesClassifier
OPTIONS -il ${INPUTDATA}/Classification/QB_1_ortho.tif
-vd ${INPUTDATA}/Classification/VectorData_QB1_bis.shp
-imstat ${TEMP}/apTvClEstimateImageStatisticsQB1.xml
-model ${TEMP}/clsvmModelQB1.svm
-out ${TEMP}/clsvmPerformanceEstimationQB1.txt
-rand 121212
VALID --compare-ascii ${NOTOL}
${OTBAPP_BASELINE_FILES}/clsvmPerformanceEstimationQB1.txt
${TEMP}/clsvmPerformanceEstimationQB1.txt)
SET_TESTS_PROPERTIES(apTvClValidateSVMImagesClassifierQB1 PROPERTIES DEPENDS apTvClTrainSVMImagesClassifierQB1)
OTB_TEST_APPLICATION(NAME apTvClImageSVMClassifierQB1
APP ImageClassifier
OPTIONS -in ${INPUTDATA}/Classification/QB_1_ortho.tif
......@@ -137,21 +122,6 @@ IF(OTB_USE_OPENCV)
SET_TESTS_PROPERTIES(apTvClTrainSVMImagesClassifierQB123 PROPERTIES DEPENDS apTvClComputeImagesStatisticsQB123)
OTB_TEST_APPLICATION(NAME apTvClValidateSVMImagesClassifierQB123
APP ValidateImagesClassifier
OPTIONS -il ${INPUTDATA}/Classification/QB_1_ortho.tif
-vd ${INPUTDATA}/Classification/VectorData_QB1_bis.shp
-imstat ${TEMP}/apTvClEstimateImageStatisticsQB123.xml
-model ${TEMP}/clsvmModelQB123.svm
-out ${TEMP}/clsvmPerformanceEstimationQB123.txt
-rand 121212
VALID --compare-ascii ${NOTOL}
${OTBAPP_BASELINE_FILES}/clsvmPerformanceEstimationQB123.txt
${TEMP}/clsvmPerformanceEstimationQB123.txt)
SET_TESTS_PROPERTIES(apTvClValidateSVMImagesClassifierQB123 PROPERTIES DEPENDS apTvClTrainSVMImagesClassifierQB123)
OTB_TEST_APPLICATION(NAME apTvClImageSVMClassifierQB123_1
APP ImageClassifier
OPTIONS -in ${INPUTDATA}/Classification/QB_1_ortho.tif
......@@ -252,21 +222,6 @@ IF(OTB_USE_OPENCV)
SET_TESTS_PROPERTIES(apTvClTrainSVMImagesClassifierQB456 PROPERTIES DEPENDS apTvClComputeImagesStatisticsQB456)
OTB_TEST_APPLICATION(NAME apTvClValidateSVMImagesClassifierQB456
APP ValidateImagesClassifier
OPTIONS -il ${INPUTDATA}/Classification/QB_5_extract.tif
-vd ${INPUTDATA}/Classification/VectorData_QB5_bis.shp
-imstat ${TEMP}/apTvClEstimateImageStatisticsQB456.xml
-model ${TEMP}/clsvmModelQB456.svm
-rand 121212
-out ${TEMP}/clsvmPerformanceEstimationQB456.txt
VALID --compare-ascii ${NOTOL}
${OTBAPP_BASELINE_FILES}/clsvmPerformanceEstimationQB456.txt
${TEMP}/clsvmPerformanceEstimationQB456.txt)
SET_TESTS_PROPERTIES(apTvClValidateSVMImagesClassifierQB456 PROPERTIES DEPENDS apTvClTrainSVMImagesClassifierQB456)
OTB_TEST_APPLICATION(NAME apTvClImageSVMClassifierQB456_4
APP ImageClassifier
OPTIONS -in ${INPUTDATA}/Classification/QB_4_extract.tif
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment