Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
otb
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container Registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Main Repositories
otb
Commits
b30c6050
Commit
b30c6050
authored
13 years ago
by
Jonathan Guinet
Browse files
Options
Downloads
Plain Diff
MRG
parents
c9e79ef6
f1570e6f
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
Applications/Classification/otbValidateSVMImagesClassifier.cxx
+316
-0
316 additions, 0 deletions
...cations/Classification/otbValidateSVMImagesClassifier.cxx
with
316 additions
and
0 deletions
Applications/Classification/otbValidateSVMImagesClassifier.cxx
0 → 100644
+
316
−
0
View file @
b30c6050
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include
"otbWrapperApplication.h"
#include
"otbWrapperApplicationFactory.h"
#include
<fstream>
//Image
#include
"otbListSampleGenerator.h"
// ListSample
#include
"itkListSample.h"
#include
"itkVariableLengthVector.h"
#include
"itkFixedArray.h"
// SVM estimator
#include
"otbSVMSampleListModelEstimator.h"
// Statistic XML Reader
#include
"otbStatisticsXMLFileReader.h"
// Validation
#include
"otbSVMClassifier.h"
#include
"otbConfusionMatrixCalculator.h"
// Normalize the samples
#include
"otbShiftScaleSampleListFilter.h"
// List sample concatenation
#include
"otbConcatenateSampleListFilter.h"
// Classification filter
#include
"otbSVMImageClassificationFilter.h"
// Extract a ROI of the vectordata
#include
"otbVectorDataIntoImageProjectionFilter.h"
namespace
otb
{
namespace
Wrapper
{
class
ValidateSVMImagesClassifier
:
public
Application
{
public:
/** Standard class typedefs. */
typedef
ValidateSVMImagesClassifier
Self
;
typedef
Application
Superclass
;
typedef
itk
::
SmartPointer
<
Self
>
Pointer
;
typedef
itk
::
SmartPointer
<
const
Self
>
ConstPointer
;
/** Standard macro */
itkNewMacro
(
Self
)
;
itkTypeMacro
(
ValidateSVMImagesClassifier
,
otb
::
Application
)
;
typedef
otb
::
Image
<
FloatVectorImageType
::
InternalPixelType
,
2
>
ImageReaderType
;
typedef
FloatVectorImageType
::
PixelType
PixelType
;
typedef
FloatVectorImageType
VectorImageType
;
typedef
FloatImageType
ImageType
;
typedef
Int32ImageType
LabeledImageType
;
// Training vectordata
typedef
itk
::
VariableLengthVector
<
ImageType
::
PixelType
>
MeasurementType
;
// SampleList manipulation
typedef
otb
::
ListSampleGenerator
<
VectorImageType
,
VectorDataType
>
ListSampleGeneratorType
;
typedef
ListSampleGeneratorType
::
ListSampleType
ListSampleType
;
typedef
ListSampleGeneratorType
::
LabelType
LabelType
;
typedef
ListSampleGeneratorType
::
ListLabelType
LabelListSampleType
;
typedef
otb
::
Statistics
::
ConcatenateSampleListFilter
<
ListSampleType
>
ConcatenateListSampleFilterType
;
typedef
otb
::
Statistics
::
ConcatenateSampleListFilter
<
LabelListSampleType
>
ConcatenateLabelListSampleFilterType
;
// Statistic XML file Reader
typedef
otb
::
StatisticsXMLFileReader
<
MeasurementType
>
StatisticsReader
;
// Enhance List Sample
typedef
otb
::
Statistics
::
ShiftScaleSampleListFilter
<
ListSampleType
,
ListSampleType
>
ShiftScaleFilterType
;
/// Classification typedefs
typedef
otb
::
SVMImageClassificationFilter
<
VectorImageType
,
LabeledImageType
>
ClassificationFilterType
;
typedef
ClassificationFilterType
::
Pointer
ClassificationFilterPointerType
;
typedef
ClassificationFilterType
::
ModelType
ModelType
;
typedef
ModelType
::
Pointer
ModelPointerType
;
typedef
otb
::
Functor
::
VariableLengthVectorToMeasurementVectorFunctor
<
MeasurementType
>
MeasurementVectorFunctorType
;
typedef
otb
::
SVMSampleListModelEstimator
<
ListSampleType
,
LabelListSampleType
,
MeasurementVectorFunctorType
>
SVMEstimatorType
;
// Estimate performance on validation sample
typedef
otb
::
SVMClassifier
<
ListSampleType
,
LabelType
::
ValueType
>
ClassifierType
;
typedef
otb
::
ConfusionMatrixCalculator
<
LabelListSampleType
,
LabelListSampleType
>
ConfusionMatrixCalculatorType
;
typedef
ClassifierType
::
OutputType
ClassifierOutputType
;
// Extract ROI and Project vectorData
typedef
otb
::
VectorDataIntoImageProjectionFilter
<
VectorDataType
,
VectorImageType
>
VectorDataReprojectionType
;
private:
ValidateSVMImagesClassifier
()
{
SetName
(
"ValidateSVMImagesClassifier"
);
SetDescription
(
"Perform SVM validation from multiple input images and multiple vector data."
);
}
virtual
~
ValidateSVMImagesClassifier
()
{
}
void
DoCreateParameters
()
{
AddParameter
(
ParameterType_InputImageList
,
"il"
,
"Input Image List"
);
AddParameter
(
ParameterType_InputVectorDataList
,
"vd"
,
"Vector Data of sample used to validate the estimator"
);
AddParameter
(
ParameterType_Filename
,
"dem"
,
"A DEM repository"
);
MandatoryOff
(
"dem"
);
AddParameter
(
ParameterType_Filename
,
"imstat"
,
"XML file containing mean and standard deviation of input images."
);
MandatoryOff
(
"imstat"
);
AddParameter
(
ParameterType_Filename
,
"svm"
,
"SVM model to validate its performances."
);
AddParameter
(
ParameterType_Filename
,
"out"
,
"File which will contain the performance of the SVM model."
);
}
void
DoUpdateParameters
()
{
// Nothing to do here : all parameters are independent
}
void
DoExecute
()
{
GetLogger
()
->
Debug
(
"Entering DoExecute
\n
"
);
//Create training and validation for list samples and label list samples
ConcatenateLabelListSampleFilterType
::
Pointer
concatenateTrainingLabels
=
ConcatenateLabelListSampleFilterType
::
New
();
ConcatenateListSampleFilterType
::
Pointer
concatenateTrainingSamples
=
ConcatenateListSampleFilterType
::
New
();
ConcatenateLabelListSampleFilterType
::
Pointer
concatenateValidationLabels
=
ConcatenateLabelListSampleFilterType
::
New
();
ConcatenateListSampleFilterType
::
Pointer
concatenateValidationSamples
=
ConcatenateListSampleFilterType
::
New
();
MeasurementType
meanMeasurementVector
;
MeasurementType
stddevMeasurementVector
;
//--------------------------
// Load measurements from images
unsigned
int
nbBands
=
0
;
//Iterate over all input images
FloatVectorImageListType
*
imageList
=
GetParameterImageList
(
"il"
);
VectorDataListType
*
vectorDataList
=
GetParameterVectorDataList
(
"vd"
);
//Iterate over all input images
for
(
unsigned
int
imgIndex
=
0
;
imgIndex
<
imageList
->
Size
();
++
imgIndex
)
{
FloatVectorImageType
::
Pointer
image
=
imageList
->
GetNthElement
(
imgIndex
);
image
->
UpdateOutputInformation
();
if
(
imgIndex
==
0
)
{
nbBands
=
image
->
GetNumberOfComponentsPerPixel
();
}
// read the Vectordata
VectorDataType
::
Pointer
vectorData
=
vectorDataList
->
GetNthElement
(
imgIndex
);
vectorData
->
Update
();
VectorDataReprojectionType
::
Pointer
vdreproj
=
VectorDataReprojectionType
::
New
();
vdreproj
->
SetInputImage
(
image
);
vdreproj
->
SetInput
(
vectorData
);
vdreproj
->
SetUseOutputSpacingAndOriginFromImage
(
false
);
// Configure DEM directory
if
(
HasUserValue
(
"dem"
))
{
vdreproj
->
SetDEMDirectory
(
GetParameterString
(
"dem"
));
}
else
{
if
(
otb
::
ConfigurationFile
::
GetInstance
()
->
IsValid
())
{
vdreproj
->
SetDEMDirectory
(
otb
::
ConfigurationFile
::
GetInstance
()
->
GetDEMDirectory
());
}
}
vdreproj
->
Update
();
//Sample list generator
ListSampleGeneratorType
::
Pointer
sampleGenerator
=
ListSampleGeneratorType
::
New
();
//Set inputs of the sample generator
//TODO the ListSampleGenerator perform UpdateOutputData over the input image (need a persistent implementation)
sampleGenerator
->
SetInput
(
image
);
sampleGenerator
->
SetInputVectorData
(
vdreproj
->
GetOutput
());
sampleGenerator
->
SetValidationTrainingProportion
(
1.0
);
// All in validation
sampleGenerator
->
SetClassKey
(
"Class"
);
sampleGenerator
->
Update
();
//Concatenate training and validation samples from the image
concatenateValidationLabels
->
AddInput
(
sampleGenerator
->
GetValidationListLabel
());
concatenateValidationSamples
->
AddInput
(
sampleGenerator
->
GetValidationListSample
());
}
// Update
concatenateValidationSamples
->
Update
();
concatenateValidationLabels
->
Update
();
if
(
HasValue
(
"imstat"
))
{
StatisticsReader
::
Pointer
statisticsReader
=
StatisticsReader
::
New
();
statisticsReader
->
SetFileName
(
GetParameterString
(
"imstat"
));
meanMeasurementVector
=
statisticsReader
->
GetStatisticVectorByName
(
"mean"
);
stddevMeasurementVector
=
statisticsReader
->
GetStatisticVectorByName
(
"stddev"
);
}
else
{
meanMeasurementVector
.
SetSize
(
nbBands
);
meanMeasurementVector
.
Fill
(
0.
);
stddevMeasurementVector
.
SetSize
(
nbBands
);
stddevMeasurementVector
.
Fill
(
1.
);
}
ShiftScaleFilterType
::
Pointer
validationShiftScaleFilter
=
ShiftScaleFilterType
::
New
();
validationShiftScaleFilter
->
SetInput
(
concatenateValidationSamples
->
GetOutput
());
validationShiftScaleFilter
->
SetShifts
(
meanMeasurementVector
);
validationShiftScaleFilter
->
SetScales
(
stddevMeasurementVector
);
validationShiftScaleFilter
->
Update
();
//--------------------------
// split the data set into training/validation set
ListSampleType
::
Pointer
validationListSample
=
validationShiftScaleFilter
->
GetOutputSampleList
();
LabelListSampleType
::
Pointer
validationLabeledListSample
=
concatenateValidationLabels
->
GetOutputSampleList
();
otbAppLogINFO
(
"Size of validation set: "
<<
validationListSample
->
Size
());
otbAppLogINFO
(
"Size of labeled validation set: "
<<
validationLabeledListSample
->
Size
());
//--------------------------
// Load svm model
ModelPointerType
modelSVM
=
ModelType
::
New
();
modelSVM
->
LoadModel
(
GetParameterString
(
"svm"
).
c_str
());
//--------------------------
// Performances estimation
ClassifierType
::
Pointer
validationClassifier
=
ClassifierType
::
New
();
validationClassifier
->
SetSample
(
validationListSample
);
validationClassifier
->
SetNumberOfClasses
(
modelSVM
->
GetNumberOfClasses
());
validationClassifier
->
SetModel
(
modelSVM
);
validationClassifier
->
Update
();
// Estimate performances
ClassifierOutputType
::
ConstIterator
it
=
validationClassifier
->
GetOutput
()
->
Begin
();
ClassifierOutputType
::
ConstIterator
itEnd
=
validationClassifier
->
GetOutput
()
->
End
();
LabelListSampleType
::
Pointer
classifierListLabel
=
LabelListSampleType
::
New
();
while
(
it
!=
itEnd
)
{
// Due to a bug in SVMClassifier, outlier in one-class SVM are labeled with unsigned int max
classifierListLabel
->
PushBack
(
it
.
GetClassLabel
()
==
itk
::
NumericTraits
<
unsigned
int
>::
max
()
?
2
:
it
.
GetClassLabel
());
++
it
;
}
ConfusionMatrixCalculatorType
::
Pointer
confMatCalc
=
ConfusionMatrixCalculatorType
::
New
();
confMatCalc
->
SetReferenceLabels
(
validationLabeledListSample
);
confMatCalc
->
SetProducedLabels
(
classifierListLabel
);
confMatCalc
->
Update
();
otbAppLogINFO
(
"*** SVM training performances ***
\n
"
<<
"Confusion matrix:
\n
"
<<
confMatCalc
->
GetConfusionMatrix
()
<<
std
::
endl
);
for
(
unsigned
int
itClasses
=
0
;
itClasses
<
modelSVM
->
GetNumberOfClasses
();
itClasses
++
)
{
otbAppLogINFO
(
"Precision of class ["
<<
itClasses
<<
"] vs all: "
<<
confMatCalc
->
GetPrecisions
()[
itClasses
]
<<
std
::
endl
);
otbAppLogINFO
(
"Recall of class ["
<<
itClasses
<<
"] vs all: "
<<
confMatCalc
->
GetRecalls
()[
itClasses
]
<<
std
::
endl
);
otbAppLogINFO
(
"F-score of class ["
<<
itClasses
<<
"] vs all: "
<<
confMatCalc
->
GetFScores
()[
itClasses
]
<<
"
\n
"
<<
std
::
endl
);
}
otbAppLogINFO
(
"Global performance, Kappa index: "
<<
confMatCalc
->
GetKappaIndex
()
<<
std
::
endl
);
//--------------------------
// Save output in a ascii file (if needed)
if
(
IsParameterEnabled
(
"out"
))
{
std
::
ofstream
file
;
file
.
open
(
GetParameterString
(
"out"
).
c_str
());
file
<<
"Precision of the different class: "
<<
confMatCalc
->
GetPrecisions
()
<<
std
::
endl
;
file
<<
"Recall of the different class: "
<<
confMatCalc
->
GetRecalls
()
<<
std
::
endl
;
file
<<
"F-score of the different class: "
<<
confMatCalc
->
GetFScores
()
<<
std
::
endl
;
file
<<
"Kappa index: "
<<
confMatCalc
->
GetKappaIndex
()
<<
std
::
endl
;
file
.
close
();
}
}
};
}
// end of namespace Wrapper
}
// end of namespace otb
OTB_APPLICATION_EXPORT
(
otb
::
Wrapper
::
ValidateSVMImagesClassifier
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment