Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"
#include <fstream>
// ListSample
#include "itkListSample.h"
#include "itkVariableLengthVector.h"
#include "itkFixedArray.h"
#include "otbListSampleGenerator.h"
// Model estimator
#include "otbMachineLearningModelFactory.h"
#include "otbMachineLearningModel.h"
// Statistic XML Reader
#include "otbStatisticsXMLFileReader.h"
// Validation
#include "otbConfusionMatrixCalculator.h"
// Normalize the samples
#include "otbShiftScaleSampleListFilter.h"
// List sample concatenation
#include "otbConcatenateSampleListFilter.h"
// Extract a ROI of the vectordata
#include "otbVectorDataIntoImageProjectionFilter.h"
// Elevation handler
#include "otbWrapperElevationParametersHandler.h"
namespace otb
{
namespace Wrapper
{
class ValidateImagesClassifier: public Application
{
public:
/** Standard class typedefs. */
typedef ValidateImagesClassifier Self;
typedef Application Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Standard macro */
itkNewMacro(Self);
itkTypeMacro(ValidateImagesClassifier, otb::Application);
typedef otb::Image<FloatVectorImageType::InternalPixelType, 2> ImageReaderType;
typedef FloatVectorImageType::PixelType PixelType;
typedef FloatVectorImageType VectorImageType;
typedef FloatImageType ImageType;
typedef Int32ImageType LabeledImageType;
// Training vectordata
typedef itk::VariableLengthVector<ImageType::PixelType> MeasurementType;
// SampleList manipulation
typedef otb::ListSampleGenerator<VectorImageType, VectorDataType> ListSampleGeneratorType;
typedef ListSampleGeneratorType::ListSampleType ListSampleType;
typedef ListSampleGeneratorType::LabelType LabelType;
typedef ListSampleGeneratorType::ListLabelType LabelListSampleType;
typedef otb::Statistics::ConcatenateSampleListFilter<ListSampleType> ConcatenateListSampleFilterType;
typedef otb::Statistics::ConcatenateSampleListFilter<LabelListSampleType> ConcatenateLabelListSampleFilterType;
// Statistic XML file Reader
typedef otb::StatisticsXMLFileReader<MeasurementType> StatisticsReader;
// Enhance List Sample
typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType;
/// Classification typedefs
typedef otb::MachineLearningModelFactory<FloatVectorImageType::InternalPixelType, ListSampleGeneratorType::ClassLabelType> MachineLearningModelFactoryType;
typedef otb::MachineLearningModel<FloatVectorImageType::InternalPixelType, ListSampleGeneratorType::ClassLabelType> ModelType;
// Estimate performance on validation sample
typedef otb::ConfusionMatrixCalculator<LabelListSampleType, LabelListSampleType> ConfusionMatrixCalculatorType;
// Extract ROI and Project vectorData
typedef otb::VectorDataIntoImageProjectionFilter<VectorDataType, VectorImageType> VectorDataReprojectionType;
private:
void DoInit()
{
SetName("ValidateImagesClassifier");
SetDescription("Estimate the performance of the model with a set of images and validation samples.");
SetDocName("Validate Images Classifier");
SetDocLongDescription("Estimate the performance of the model obtained by the TrainImagesClassifier with a new set of images and validation samples.\n The application asks for images statistics as input (XML file generated with the ComputeImagesStatistics application) and a SVM model (text file) generated with the TrainSVMImagesClassifier application.\n It will compute the global confusion matrix, kappa index and also the precision, recall and F-score of each class. In the validation process, the confusion matrix is organized the following way: rows = reference labels, columns = produced labels.");
SetDocLimitations("None");
SetDocAuthors("OTB-Team");
SetDocSeeAlso(" ");
AddDocTag(Tags::Learning);
AddParameter(ParameterType_InputImageList, "il", "Input Image List");
SetParameterDescription("il", "Input image list filename.");
AddParameter(ParameterType_InputVectorDataList, "vd", "Vector Data List");
SetParameterDescription("vd", "List of vector data to select validation samples.");
Jonathan Guinet
committed
AddParameter(ParameterType_String, "vfn", "Name of the discrimination field");
SetParameterDescription("vfn", "Name of the field used to discriminate class in the vector data files.");
SetParameterString("vfn", "Class");
MandatoryOff("vfn");
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
AddParameter(ParameterType_InputFilename, "imstat", "XML image statistics file");
MandatoryOff("imstat");
SetParameterDescription("imstat", "Filename of an XML file containing mean and standard deviation of input images.");
// Elevation
ElevationParametersHandler::AddElevationParameters(this, "elev");
AddParameter(ParameterType_OutputFilename, "out", "Output filename");
SetParameterDescription("out", "Output file, which contains the performances of the SVM model.");
MandatoryOff("out");
AddParameter(ParameterType_InputFilename, "model", "Model filename");
SetParameterDescription("model",
"Input model to validate (given by TrainImagesClassification for instance).");
AddRANDParameter();
// Doc example parameter settings
SetDocExampleParameterValue("il", "QB_1_ortho.tif");
SetDocExampleParameterValue("vd", "VectorData_QB1.shp");
SetDocExampleParameterValue("imstat", "EstimateImageStatisticsQB1.xml");
SetDocExampleParameterValue("model", "clsvmModelQB1.svm");
SetDocExampleParameterValue("out", "PerformanceEstimationQB1.txt");
}
/** Intentionally empty: every parameter of this application is independent. */
void DoUpdateParameters()
{
  // No parameter value is derived from another one, so there is nothing to refresh.
}
/**
 * Format the confusion matrix held by \a confMatCalc as a text table,
 * write it to the application log and return it.
 *
 * Layout: a header line with the class labels in brackets, then one line
 * per matrix row starting with the bracketed label of that row. Every
 * column (labels and cells alike) is padded to the same width so that the
 * header lines up with the values underneath it.
 *
 * \param confMatCalc calculator whose confusion matrix and index-to-label
 *                    map are read; it is dereferenced without a null check.
 * \return the formatted table (without the introductory log sentence).
 */
std::string LogConfusionMatrix(ConfusionMatrixCalculatorType* confMatCalc)
{
  typedef std::map<int, ConfusionMatrixCalculatorType::ClassLabelType> MapOfIndicesType;

  ConfusionMatrixCalculatorType::ConfusionMatrixType matrix = confMatCalc->GetConfusionMatrix();
  MapOfIndicesType mapOfIndices = confMatCalc->GetMapOfIndices();

  // Column width = widest printed item, cells and bracketed labels included.
  size_t minwidth = 0;
  for (unsigned int i = 0; i < matrix.Rows(); i++)
  {
    for (unsigned int j = 0; j < matrix.Cols(); j++)
    {
      std::ostringstream cell;
      cell << matrix(i, j);
      const size_t size = cell.str().size();
      if (size > minwidth)
      {
        minwidth = size;
      }
    }
  }
  for (MapOfIndicesType::const_iterator it = mapOfIndices.begin(); it != mapOfIndices.end(); ++it)
  {
    std::ostringstream bracketed;
    bracketed << "[" << it->second << "]";
    const size_t size = bracketed.str().size();
    if (size > minwidth)
    {
      minwidth = size;
    }
  }

  // Widths handed to std::setw. The label width leaves room for the two
  // brackets; the guard keeps the unsigned subtraction from wrapping around.
  const int cellWidth = static_cast<int>(minwidth);
  const int labelWidth = (minwidth > 2) ? static_cast<int>(minwidth - 2) : 0;

  std::ostringstream os;

  // Header line: blank corner as wide as a row label, then one padded label
  // per column so the headers sit above the cells.
  os << std::string(minwidth, ' ') << " ";
  for (MapOfIndicesType::const_iterator it = mapOfIndices.begin(); it != mapOfIndices.end(); ++it)
  {
    os << "[" << std::setw(labelWidth) << it->second << "]" << " ";
  }
  os << std::endl;

  // One line per row of the confusion matrix
  for (unsigned int i = 0; i < matrix.Rows(); i++)
  {
    ConfusionMatrixCalculatorType::ClassLabelType label = mapOfIndices[i];
    os << "[" << std::setw(labelWidth) << label << "]" << " ";
    for (unsigned int j = 0; j < matrix.Cols(); j++)
    {
      os << std::setw(cellWidth) << matrix(i, j) << " ";
    }
    os << std::endl;
  }

  otbAppLogINFO("Confusion matrix (rows = reference labels, columns = produced labels):\n" << os.str());
  return os.str();
}
void DoExecute()
{
GetLogger()->Debug("Entering DoExecute\n");
//Create training and validation for list samples and label list samples
ConcatenateLabelListSampleFilterType::Pointer
concatenateTrainingLabels = ConcatenateLabelListSampleFilterType::New();
ConcatenateListSampleFilterType::Pointer concatenateTrainingSamples = ConcatenateListSampleFilterType::New();
ConcatenateLabelListSampleFilterType::Pointer
concatenateValidationLabels = ConcatenateLabelListSampleFilterType::New();
ConcatenateListSampleFilterType::Pointer concatenateValidationSamples = ConcatenateListSampleFilterType::New();
MeasurementType meanMeasurementVector;
MeasurementType stddevMeasurementVector;
//--------------------------
// Load measurements from images
unsigned int nbBands = 0;
//Iterate over all input images
FloatVectorImageListType* imageList = GetParameterImageList("il");
VectorDataListType* vectorDataList = GetParameterVectorDataList("vd");
//Iterate over all input images
for (unsigned int imgIndex = 0; imgIndex < imageList->Size(); ++imgIndex)
{
FloatVectorImageType::Pointer image = imageList->GetNthElement(imgIndex);
image->UpdateOutputInformation();
if (imgIndex == 0)
{
nbBands = image->GetNumberOfComponentsPerPixel();
}
// read the Vectordata
VectorDataType::Pointer vectorData = vectorDataList->GetNthElement(imgIndex);
vectorData->Update();
VectorDataReprojectionType::Pointer vdreproj = VectorDataReprojectionType::New();
vdreproj->SetInputImage(image);
vdreproj->SetInput(vectorData);
vdreproj->SetUseOutputSpacingAndOriginFromImage(false);
// Setup the DEM Handler
otb::Wrapper::ElevationParametersHandler::SetupDEMHandlerFromElevationParameters(this,"elev");
vdreproj->Update();
//Sample list generator
ListSampleGeneratorType::Pointer sampleGenerator = ListSampleGeneratorType::New();
//Set inputs of the sample generator
//TODO the ListSampleGenerator perform UpdateOutputData over the input image (need a persistent implementation)
sampleGenerator->SetInput(image);
sampleGenerator->SetInputVectorData(vdreproj->GetOutput());
sampleGenerator->SetValidationTrainingProportion(1.0); // All in validation
Jonathan Guinet
committed
sampleGenerator->SetClassKey(GetParameterString("vfn"));
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
sampleGenerator->Update();
//Concatenate training and validation samples from the image
concatenateValidationLabels->AddInput(sampleGenerator->GetValidationListLabel());
concatenateValidationSamples->AddInput(sampleGenerator->GetValidationListSample());
}
// Update
concatenateValidationSamples->Update();
concatenateValidationLabels->Update();
if (IsParameterEnabled("imstat"))
{
StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
statisticsReader->SetFileName(GetParameterString("imstat"));
meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
}
else
{
meanMeasurementVector.SetSize(nbBands);
meanMeasurementVector.Fill(0.);
stddevMeasurementVector.SetSize(nbBands);
stddevMeasurementVector.Fill(1.);
}
ShiftScaleFilterType::Pointer validationShiftScaleFilter = ShiftScaleFilterType::New();
validationShiftScaleFilter->SetInput(concatenateValidationSamples->GetOutput());
validationShiftScaleFilter->SetShifts(meanMeasurementVector);
validationShiftScaleFilter->SetScales(stddevMeasurementVector);
validationShiftScaleFilter->Update();
//--------------------------
// split the data set into training/validation set
ListSampleType::Pointer validationListSample = validationShiftScaleFilter->GetOutputSampleList();
LabelListSampleType::Pointer validationLabeledListSample = concatenateValidationLabels->GetOutputSampleList();
otbAppLogINFO("Size of validation set: " << validationListSample->Size());
otbAppLogINFO("Size of labeled validation set: " << validationLabeledListSample->Size());
//--------------------------
// Load svm model
ModelType::Pointer model = MachineLearningModelFactoryType::CreateMachineLearningModel(GetParameterString("model"),
MachineLearningModelFactoryType::ReadMode);
model->Load(GetParameterString("model"));
//--------------------------
// Performances estimation
LabelListSampleType::Pointer classifierListLabel = LabelListSampleType::New();
model->SetInputListSample(validationListSample);
model->SetTargetListSample(classifierListLabel);
model->PredictAll();
// Estimate performances
ConfusionMatrixCalculatorType::Pointer confMatCalc = ConfusionMatrixCalculatorType::New();
confMatCalc->SetReferenceLabels(validationLabeledListSample);
confMatCalc->SetProducedLabels(classifierListLabel);
confMatCalc->Compute();
otbAppLogINFO("Model training performances");
std::string confMatString;
confMatString = LogConfusionMatrix(confMatCalc);
for (unsigned int itClasses = 0; itClasses < confMatCalc->GetNumberOfClasses(); itClasses++)
{
otbAppLogINFO("Precision of class [" << itClasses << "] vs all: " << confMatCalc->GetPrecisions()[itClasses] << std::endl);
otbAppLogINFO("Recall of class [" << itClasses << "] vs all: " << confMatCalc->GetRecalls()[itClasses] << std::endl);
otbAppLogINFO("F-score of class [" << itClasses << "] vs all: " << confMatCalc->GetFScores()[itClasses] << "\n" << std::endl);
}
otbAppLogINFO("Global performance, Kappa index: " << confMatCalc->GetKappaIndex() << std::endl);
otbAppLogINFO("Global performance, Overall accuracy: " << confMatCalc->GetOverallAccuracy() << std::endl);
//--------------------------
// Save output in a ascii file (if needed)
if (IsParameterEnabled("out"))
{
std::ofstream file;
file.open(GetParameterString("out").c_str());
file << "Confusion matrix (rows = reference labels, columns = produced labels):\n" << std::endl;
file << confMatString << std::endl;
file << "Precision of the different class: " << confMatCalc->GetPrecisions() << std::endl;
file << "Recall of the different class: " << confMatCalc->GetRecalls() << std::endl;
file << "F-score of the different class: " << confMatCalc->GetFScores() << std::endl;
file << "Kappa index: " << confMatCalc->GetKappaIndex() << std::endl;
file << "Overall accuracy index: " << confMatCalc->GetOverallAccuracy() << std::endl;
file.close();
}
}
};
} // end of namespace Wrapper
} // end of namespace otb
OTB_APPLICATION_EXPORT(otb::Wrapper::ValidateImagesClassifier)