Commit 16e6a638 authored by Guillaume Pasero's avatar Guillaume Pasero

REFAC: use new LibSVM classifier in Simulation module

parent c5865915
......@@ -47,7 +47,7 @@ ENABLE_SHARED
OTBTestKernel
OTBLearningBase
OTBSupervised
OTBSVMLearning
OTBLibSVM
OTBVectorDataIO
DESCRIPTION
......
......@@ -24,8 +24,7 @@
#include "otbSatelliteRSR.h"
#include "otbReduceSpectralResponse.h"
#include "otbSVMSampleListModelEstimator.h"
#include "otbSVMClassifier.h"
#include "otbLibSVMMachineLearningModel.h"
#include "otbConfusionMatrixCalculator.h"
#include "itkMersenneTwisterRandomVariateGenerator.h"
......@@ -66,14 +65,12 @@ int otbAtmosphericCorrectionsRSRSVMClassifier(int argc, char * argv[])
typedef itk::VariableLengthVector<double> SampleType;
typedef itk::Statistics::ListSample<SampleType> SampleListType;
typedef itk::FixedArray<unsigned long, 1> TrainingSampleType;
typedef itk::Statistics::ListSample<TrainingSampleType> TrainingSampleListType;
typedef itk::FixedArray<unsigned long, 1> TargetSampleType;
typedef itk::Statistics::ListSample<TargetSampleType> TargetSampleListType;
typedef otb::SVMSampleListModelEstimator<SampleListType, TrainingSampleListType> SVMModelEstimatorType;
typedef otb::SVMClassifier<SampleListType, unsigned long> SVMClassifierType;
typedef SVMClassifierType::OutputType ClassifierOutputType;
typedef otb::LibSVMMachineLearningModel<double, unsigned long> SVMType;
typedef otb::ConfusionMatrixCalculator<TrainingSampleListType, TrainingSampleListType> ConfusionMatrixCalculatorType;
typedef otb::ConfusionMatrixCalculator<TargetSampleListType, TargetSampleListType> ConfusionMatrixCalculatorType;
if (argc != 20)
{
......@@ -223,7 +220,7 @@ int otbAtmosphericCorrectionsRSRSVMClassifier(int argc, char * argv[])
//compute spectral response for all training files
SampleListType::Pointer sampleList = SampleListType::New();
TrainingSampleListType::Pointer trainingList = TrainingSampleListType::New();
TargetSampleListType::Pointer trainingList = TargetSampleListType::New();
for (unsigned int i = 0; i < trainingFiles.size(); ++i)
{
......@@ -246,7 +243,7 @@ int otbAtmosphericCorrectionsRSRSVMClassifier(int argc, char * argv[])
//Get the response in an itk::VariableLengthVector and add it to the sample list for SVMModelEstimator
SampleType sample;
TrainingSampleType trainingSample;
TargetSampleType trainingSample;
sample.SetSize(atmosphericEffectsFilter->GetCorrectedSpectralResponse()->Size());
std::cout << "corrected response : [";
for (unsigned int j = 0; j < atmosphericEffectsFilter->GetCorrectedSpectralResponse()->Size(); ++j)
......@@ -262,16 +259,16 @@ int otbAtmosphericCorrectionsRSRSVMClassifier(int argc, char * argv[])
}
//SVM model estimator
SVMModelEstimatorType::Pointer estimator = SVMModelEstimatorType::New();
estimator->SetInputSampleList(sampleList);
estimator->SetTrainingSampleList(trainingList);
estimator->DoProbabilityEstimates(true);
estimator->Update();
estimator->GetModel()->SaveModel("model.txt");
SVMType::Pointer classifier = SVMType::New();
classifier->SetInputListSample(sampleList);
classifier->SetTargetListSample(trainingList);
classifier->DoProbabilityEstimates(true);
classifier->Train();
classifier->Save("model.txt");
//compute spectral response for testing files
sampleList->Clear(); //clear the sample list to re use it for testing samples
TrainingSampleListType::Pointer groundTruthClassList = TrainingSampleListType::New();
TargetSampleListType::Pointer groundTruthClassList = TargetSampleListType::New();
for (unsigned int i = 0; i < testingFiles.size(); ++i)
{
SpectralResponsePointerType spectralResponse = SpectralResponseType::New();
......@@ -293,7 +290,7 @@ int otbAtmosphericCorrectionsRSRSVMClassifier(int argc, char * argv[])
//Get the response in an itk::VariableLengthVector and add it to the sample list for SVMClassifier
SampleType sample;
TrainingSampleType gtClass;
TargetSampleType gtClass;
sample.SetSize(atmosphericEffectsFilter->GetCorrectedSpectralResponse()->Size());
for (unsigned int j = 0; j < atmosphericEffectsFilter->GetCorrectedSpectralResponse()->Size(); ++j)
{
......@@ -305,21 +302,8 @@ int otbAtmosphericCorrectionsRSRSVMClassifier(int argc, char * argv[])
}
//SVM Classifier
SVMClassifierType::Pointer classifier = SVMClassifierType::New();
classifier->SetModel(estimator->GetModel());
classifier->SetInput(sampleList);
classifier->SetNumberOfClasses(dirSR.size());
classifier->Update();
TargetSampleListType::Pointer classifierListLabel = classifier->PredictBatch(sampleList);
ClassifierOutputType::ConstIterator it = classifier->GetOutput()->Begin();
TrainingSampleListType::Pointer classifierListLabel = TrainingSampleListType::New();
while (it != classifier->GetOutput()->End())
{
std::cout << "class : " << it.GetClassLabel() << std::endl;
classifierListLabel->PushBack(it.GetClassLabel());
++it;
}
for (unsigned int i = 0; i < testingFiles.size(); ++i)
{
std::cout << "ground truth class : " << testingGTClasses[i] << std::endl;
......
......@@ -29,8 +29,8 @@
#include "otbSpatialisationFilter.h"
#include "otbImageSimulationMethod.h"
#include "otbAttributesMapLabelObject.h"
#include "otbSVMImageModelEstimator.h"
#include "otbSVMImageClassificationFilter.h"
#include "otbLibSVMMachineLearningModel.h"
#include "otbImageClassificationFilter.h"
#include "otbImageFileReader.h"
int otbImageSimulationMethodSVMClassif(int itkNotUsed(argc), char * argv[])
......@@ -62,17 +62,16 @@ int otbImageSimulationMethodSVMClassif(int itkNotUsed(argc), char * argv[])
typedef otb::ImageSimulationMethod<VectorDataType, SpatialisationFilterType,
SimulationStep1Type, SimulationStep2Type, FTMType , OutputImageType> ImageSimulationMethodType;
typedef otb::SVMImageModelEstimator<OutputImageType, LabelImageType> SVMEstimatorType;
typedef otb::SVMImageClassificationFilter<OutputImageType, LabelImageType> SVMClassificationFilterType;
typedef otb::LibSVMMachineLearningModel<double, unsigned short> SVMType;
typedef otb::ImageClassificationFilter<OutputImageType,LabelImageType> ClassificationFilterType;
/** Instantiation of pointer objects*/
ImageWriterType::Pointer writer = ImageWriterType::New();
LabelImageWriterType::Pointer labelWriter = LabelImageWriterType::New();
ImageSimulationMethodType::Pointer imageSimulation = ImageSimulationMethodType::New();
SpatialisationFilterType::Pointer spatialisationFilter = SpatialisationFilterType::New();
SVMEstimatorType::Pointer svmEstimator = SVMEstimatorType::New();
SVMClassificationFilterType::Pointer classifier = SVMClassificationFilterType::New();
SVMType::Pointer model = SVMType::New();
ClassificationFilterType::Pointer classifier = ClassificationFilterType::New();
SpatialisationFilterType::SizeType objectSize;
......@@ -132,15 +131,48 @@ int otbImageSimulationMethodSVMClassif(int itkNotUsed(argc), char * argv[])
// imageSimulation->SetVariance();
imageSimulation->UpdateData();
svmEstimator->SetInputImage(imageSimulation->GetOutputReflectanceImage());
svmEstimator->SetTrainingImage(imageSimulation->GetOutputLabelImage());
svmEstimator->SetParametersOptimization(false);
svmEstimator->DoProbabilityEstimates(true);
svmEstimator->Update();
classifier->SetModel(svmEstimator->GetModel());
//~ svmEstimator->SetInputImage(imageSimulation->GetOutputReflectanceImage());
//~ svmEstimator->SetTrainingImage(imageSimulation->GetOutputLabelImage());
//~ svmEstimator->SetParametersOptimization(false);
//~ svmEstimator->DoProbabilityEstimates(true);
//~ svmEstimator->Update();
OutputImageType::Pointer outReflectance = imageSimulation->GetOutputReflectanceImage();
LabelImageType::Pointer outLabels = imageSimulation->GetOutputLabelImage();
typedef SVMType::InputListSampleType InputListSampleType;
typedef SVMType::TargetListSampleType TargetListSampleType;
InputListSampleType::Pointer inputSamples = InputListSampleType::New();
TargetListSampleType::Pointer trainSamples = TargetListSampleType::New();
inputSamples->SetMeasurementVectorSize(nbBands);
trainSamples->SetMeasurementVectorSize(1);
itk::ImageRegionConstIterator<OutputImageType> itIn(outReflectance,outReflectance->GetLargestPossibleRegion() );
itk::ImageRegionConstIterator<LabelImageType> itLabel(outLabels, outLabels->GetLargestPossibleRegion());
itIn.GoToBegin();
itLabel.GoToBegin();
while (!itIn.IsAtEnd())
{
SVMType::InputSampleType sample;
SVMType::TargetSampleType target;
sample.SetSize(nbBands);
for (unsigned int i=0 ; i<nbBands ; i++)
{
sample[i] = itIn.Get()[i];
}
target[0] = itLabel.Value();
inputSamples->PushBack(sample);
trainSamples->PushBack(target);
++itIn;
++itLabel;
}
model->SetInputListSample(inputSamples);
model->SetTargetListSample(trainSamples);
model->DoProbabilityEstimates(true);
model->Train();
classifier->SetModel(model);
classifier->SetInput(imageSimulation->GetOutput());
//Write the result to an image file
......
......@@ -23,8 +23,7 @@
#include "otbSatelliteRSR.h"
#include "otbReduceSpectralResponse.h"
#include "otbSVMSampleListModelEstimator.h"
#include "otbSVMClassifier.h"
#include "otbLibSVMMachineLearningModel.h"
#include "otbConfusionMatrixCalculator.h"
#include "itkMersenneTwisterRandomVariateGenerator.h"
......@@ -47,9 +46,7 @@ int otbReduceSpectralResponseSVMClassifier(int argc, char * argv[])
typedef itk::FixedArray<unsigned long, 1> TrainingSampleType;
typedef itk::Statistics::ListSample<TrainingSampleType> TrainingSampleListType;
typedef otb::SVMSampleListModelEstimator<SampleListType, TrainingSampleListType> SVMModelEstimatorType;
typedef otb::SVMClassifier<SampleListType, unsigned long> SVMClassifierType;
typedef SVMClassifierType::OutputType ClassifierOutputType;
typedef otb::LibSVMMachineLearningModel<double, unsigned long> SVMType;
typedef otb::ConfusionMatrixCalculator<TrainingSampleListType, TrainingSampleListType> ConfusionMatrixCalculatorType;
......@@ -174,19 +171,19 @@ int otbReduceSpectralResponseSVMClassifier(int argc, char * argv[])
}
//SVM model estimator
SVMModelEstimatorType::Pointer estimator = SVMModelEstimatorType::New();
estimator->SetInputSampleList(sampleList);
estimator->SetTrainingSampleList(trainingList);
estimator->SetNu(0.5);
estimator->SetKernelGamma(1);
estimator->SetKernelCoef0(1);
estimator->SetC(1);
estimator->SetEpsilon(0.001);
estimator->SetP(0.1);
estimator->DoProbabilityEstimates(true);
estimator->Update();
estimator->GetModel()->SaveModel("model.txt");
SVMType::Pointer model = SVMType::New();
model->SetInputListSample(sampleList);
model->SetTargetListSample(trainingList);
model->SetNu(0.5);
model->SetKernelGamma(1);
model->SetKernelCoef0(1);
model->SetC(1);
model->SetEpsilon(0.001);
model->SetP(0.1);
model->DoProbabilityEstimates(true);
model->Train();
model->Save("model.txt");
//compute spectral response for testing files
sampleList->Clear(); //clear the sample list to re use it for testing samples
......@@ -219,19 +216,13 @@ int otbReduceSpectralResponseSVMClassifier(int argc, char * argv[])
}
//SVM Classifier
SVMClassifierType::Pointer classifier = SVMClassifierType::New();
classifier->SetModel(estimator->GetModel());
classifier->SetInput(sampleList);
classifier->SetNumberOfClasses(dirSR.size());
classifier->Update();
TrainingSampleListType::Pointer classifierListLabel =
model->PredictBatch(sampleList);
ClassifierOutputType::ConstIterator it = classifier->GetOutput()->Begin();
TrainingSampleListType::Pointer classifierListLabel = TrainingSampleListType::New();
while (it != classifier->GetOutput()->End())
TrainingSampleListType::ConstIterator it = classifierListLabel->Begin();
while (it != classifierListLabel->End())
{
std::cout << "class : " << it.GetClassLabel() << std::endl;
classifierListLabel->PushBack(it.GetClassLabel());
std::cout << "class : " << it.GetMeasurementVector()[0] << std::endl;
++it;
}
for (unsigned int i = 0; i < testingFiles.size(); ++i)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment