diff --git a/Modules/Radiometry/Simulation/test/otbImageSimulationMethodSVMClassif.cxx b/Modules/Radiometry/Simulation/test/otbImageSimulationMethodSVMClassif.cxx index 6148679d2112bd2a43b2e4e4c1eb6e7398a4c723..639950e4b42b4455d56c4c1f3aa233e8ad4dce28 100644 --- a/Modules/Radiometry/Simulation/test/otbImageSimulationMethodSVMClassif.cxx +++ b/Modules/Radiometry/Simulation/test/otbImageSimulationMethodSVMClassif.cxx @@ -32,6 +32,7 @@ #include "otbLibSVMMachineLearningModel.h" #include "otbImageClassificationFilter.h" #include "otbImageFileReader.h" +#include "itkImageToListSampleAdaptor.h" int otbImageSimulationMethodSVMClassif(int itkNotUsed(argc), char * argv[]) { @@ -138,38 +139,18 @@ int otbImageSimulationMethodSVMClassif(int itkNotUsed(argc), char * argv[]) //~ svmEstimator->DoProbabilityEstimates(true); //~ svmEstimator->Update(); - OutputImageType::Pointer outReflectance = imageSimulation->GetOutputReflectanceImage(); - LabelImageType::Pointer outLabels = imageSimulation->GetOutputLabelImage(); - - typedef SVMType::InputListSampleType InputListSampleType; - typedef SVMType::TargetListSampleType TargetListSampleType; - InputListSampleType::Pointer inputSamples = InputListSampleType::New(); - TargetListSampleType::Pointer trainSamples = TargetListSampleType::New(); - inputSamples->SetMeasurementVectorSize(nbBands); - trainSamples->SetMeasurementVectorSize(1); - itk::ImageRegionConstIterator<OutputImageType> itIn(outReflectance,outReflectance->GetLargestPossibleRegion() ); - itk::ImageRegionConstIterator<LabelImageType> itLabel(outLabels, outLabels->GetLargestPossibleRegion()); - itIn.GoToBegin(); - itLabel.GoToBegin(); - while (!itIn.IsAtEnd()) - { - SVMType::InputSampleType sample; - SVMType::TargetSampleType target; - sample.SetSize(nbBands); - for (unsigned int i=0 ; i<nbBands ; i++) - { - sample[i] = itIn.Get()[i]; - } - target[0] = itLabel.Value(); - inputSamples->PushBack(sample); - trainSamples->PushBack(target); - ++itIn; - ++itLabel; - } - - model->SetInputListSample(inputSamples); - model->SetTargetListSample(trainSamples); - model->DoProbabilityEstimates(true); + typedef itk::Statistics::ImageToListSampleAdaptor<OutputImageType> ListSampleAdaptorType; + typedef itk::Statistics::ImageToListSampleAdaptor<LabelImageType> TargetListSampleAdaptorType; + + ListSampleAdaptorType::Pointer listSample = ListSampleAdaptorType::New(); + listSample->SetImage(imageSimulation->GetOutputReflectanceImage()); + + TargetListSampleAdaptorType::Pointer targetListSample = TargetListSampleAdaptorType::New(); + targetListSample->SetImage(imageSimulation->GetOutputLabelImage()); + + model->SetInputListSample(listSample); + model->SetTargetListSample(targetListSample); + model->SetDoProbabilityEstimates(true); model->Train(); classifier->SetModel(model);