Skip to content
Snippets Groups Projects
Commit 2c33658b authored by Julien Michel's avatar Julien Michel
Browse files

TEST: Adding two tests to show that SVM does not behave correctly with probability estimates

parent a0573d3f
No related branches found
No related tags found
No related merge requests found
......@@ -544,6 +544,12 @@ SET_TESTS_PROPERTIES(leTvConfusionMatrixCalculatorWrongSize PROPERTIES WILL_FAIL
ADD_TEST(leTvConfusionMatrixCalculatorUpdate ${LEARNING_TESTS4}
otbConfusionMatrixCalculatorUpdate 1000 4)
# ------- SVM Validation -------------------------
ADD_TEST(leTvSVMValidationLinearlySeparableWithoutProbEstimate ${LEARNING_TESTS4}
otbSVMValidation 500 500 0.0025 0.0075 0.0075 0.0025 0. 0.0025 0. 0.0025 0 0)
ADD_TEST(leTvSVMValidationLinearlySeparableWithProbEstimate ${LEARNING_TESTS4}
otbSVMValidation 500 500 0.0025 0.0075 0.0075 0.0025 0. 0.0025 0. 0.0025 0 1)
# A enrichir
SET(BasicLearning_SRCS1
......@@ -609,6 +615,7 @@ otbSVMCrossValidationCostFunctionNew.cxx
otbExhaustiveExponentialOptimizerNew.cxx
otbListSampleGeneratorTest.cxx
otbConfusionMatrixCalculatorTest.cxx
otbSVMValidation.cxx
)
OTB_ADD_EXECUTABLE(otbLearningTests1 "${BasicLearning_SRCS1}" "OTBLearning;OTBIO;OTBTesting")
......
......@@ -36,4 +36,5 @@ void RegisterTests()
REGISTER_TEST(otbConfusionMatrixCalculatorSetListSamples);
REGISTER_TEST(otbConfusionMatrixCalculatorWrongSize);
REGISTER_TEST(otbConfusionMatrixCalculatorUpdate);
REGISTER_TEST(otbSVMValidation);
}
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#if defined(_MSC_VER)
#pragma warning ( disable : 4786 )
#endif
#include "itkExceptionObject.h"
#include "itkListSample.h"
#include <iostream>
#include "otbSVMSampleListModelEstimator.h"
#include "otbSVMClassifier.h"
#include "otbSVMKernels.h"
#include "itkMersenneTwisterRandomVariateGenerator.h"
#include "otbSVMClassifier.h"
#include "otbConfusionMatrixCalculator.h"
#include <fstream>
int otbSVMValidation(int argc, char* argv[])
{
if(argc != 13)
{
std::cerr<<"Usage: "<<argv[0]<<" nbTrainingSamples nbValidationSamples positiveCenterX positiveCenterY negativeCenterX negativeCenterY positiveRadiusMin positiveRadiusMax negativeRadiusMin negativeRadiusMax kernel probEstimate"<<std::endl;
return EXIT_FAILURE;
}
unsigned int nbTrainingSamples = atoi(argv[1]);
unsigned int nbValidationSamples = atoi(argv[2]);
double cpx = atof(argv[3]);
double cpy = atof(argv[4]);
double cnx = atof(argv[5]);
double cny = atof(argv[6]);
double prmin = atof(argv[7]);
double prmax = atof(argv[8]);
double nrmin = atof(argv[9]);
double nrmax = atof(argv[10]);
unsigned int kernel = atoi(argv[11]);
bool probEstimate = atoi(argv[12]);
typedef double InputPixelType;
typedef unsigned short LabelType;
typedef itk::VariableLengthVector<InputPixelType> SampleType;
typedef itk::Statistics::ListSample<SampleType> ListSampleType;
typedef itk::FixedArray<LabelType, 1> TrainingSampleType;
typedef itk::Statistics::ListSample<TrainingSampleType> TrainingListSampleType;
typedef itk::Statistics::MersenneTwisterRandomVariateGenerator RandomGeneratorType;
typedef otb::SVMSampleListModelEstimator<ListSampleType,TrainingListSampleType> EstimatorType;
typedef otb::SVMClassifier<ListSampleType, LabelType> ClassifierType;
typedef ClassifierType::OutputType ClassifierOutputType;
typedef otb::ConfusionMatrixCalculator
<TrainingListSampleType,TrainingListSampleType> ConfusionMatrixCalculatorType;
RandomGeneratorType::Pointer random = RandomGeneratorType::New();
random->SetSeed((unsigned int)0);
// First, generate training and validation sets
ListSampleType::Pointer trainingSamples = ListSampleType::New();
TrainingListSampleType::Pointer trainingLabels = TrainingListSampleType::New();
ListSampleType::Pointer validationSamples = ListSampleType::New();
TrainingListSampleType::Pointer validationLabels = TrainingListSampleType::New();
// Generate training set
// std::ofstream training("training.csv");
for(unsigned int i =0; i < nbTrainingSamples; ++i)
{
// Generate a positive sample
double angle = random->GetVariateWithOpenUpperRange(2*M_PI);
double radius = random->GetUniformVariate(prmin,prmax);
SampleType pSample(2);
pSample[0] = cpx+radius*vcl_sin(angle);
pSample[1] = cpy+radius*vcl_cos(angle);
TrainingSampleType label;
label[0]=1;
trainingSamples->PushBack(pSample);
trainingLabels->PushBack(label);
// training<<"1 1:"<<pSample[0]<<" 2:"<<pSample[1]<<std::endl;
// Generate a negative sample
angle = random->GetVariateWithOpenUpperRange(2*M_PI);
radius = random->GetUniformVariate(nrmin,nrmax);
SampleType nSample(2);
nSample[0] = cnx+radius*vcl_sin(angle);
nSample[1] = cny+radius*vcl_cos(angle);
label[0]=2;
trainingSamples->PushBack(nSample);
trainingLabels->PushBack(label);
// training<<"2 1:"<<nSample[0]<<" 2:"<<nSample[1]<<std::endl;
}
// training.close();
// Generate validation set
// std::ofstream validation("validation.csv");
for(unsigned int i =0; i < nbValidationSamples; ++i)
{
// Generate a positive sample
double angle = random->GetVariateWithOpenUpperRange(2*M_PI);
double radius = random->GetUniformVariate(prmin,prmax);
SampleType pSample(2);
pSample[0] = cpx+radius*vcl_sin(angle);
pSample[1] = cpy+radius*vcl_cos(angle);
TrainingSampleType label;
label[0]=1;
validationSamples->PushBack(pSample);
validationLabels->PushBack(label);
// validation<<"1 1:"<<pSample[0]<<" 2:"<<pSample[1]<<std::endl;
// Generate a negative sample
angle = random->GetVariateWithOpenUpperRange(2*M_PI);
radius = random->GetUniformVariate(nrmin,nrmax);
SampleType nSample(2);
nSample[0] = cnx+radius*vcl_sin(angle);
nSample[1] = cny+radius*vcl_cos(angle);
label[0]=2;
validationSamples->PushBack(nSample);
validationLabels->PushBack(label);
// validation<<"2 1:"<<nSample[0]<<" 2:"<<nSample[1]<<std::endl;
}
// validation.close();
// Learn
EstimatorType::Pointer estimator = EstimatorType::New();
estimator->SetInputSampleList(trainingSamples);
estimator->SetTrainingSampleList(trainingLabels);
estimator->SetKernelType(kernel);
estimator->DoProbabilityEstimates(probEstimate);
// estimator->SetParametersOptimization(true);
estimator->Update();
// estimator->SaveModel("model.svm");
// Classify
ClassifierType::Pointer validationClassifier = ClassifierType::New();
validationClassifier->SetSample(validationSamples);
validationClassifier->SetNumberOfClasses(2);
validationClassifier->SetModel(estimator->GetModel());
validationClassifier->Update();
// Confusion
ClassifierOutputType::ConstIterator it = validationClassifier->GetOutput()->Begin();
ClassifierOutputType::ConstIterator itEnd = validationClassifier->GetOutput()->End();
TrainingListSampleType::Pointer classifierListLabel = TrainingListSampleType::New();
while (it != itEnd)
{
classifierListLabel->PushBack(it.GetClassLabel());
++it;
}
ConfusionMatrixCalculatorType::Pointer confMatCalc = ConfusionMatrixCalculatorType::New();
confMatCalc->SetReferenceLabels(validationLabels);
confMatCalc->SetProducedLabels(classifierListLabel);
confMatCalc->Update();
std::cout<<std::endl;
std::cout<<"Confusion matrix: "<<std::endl<< confMatCalc->GetConfusionMatrix()<<std::endl<<std::endl;
std::cout<<"Kappa Index: "<<std::endl<< confMatCalc->GetKappaIndex()<<std::endl<<std::endl;
if(confMatCalc->GetKappaIndex()!=1)
{
std::cerr<<"Kappa index should be 1."<<std::endl;
return EXIT_FAILURE;
}
else
{
return EXIT_SUCCESS;
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment