diff --git a/Testing/Code/ObjectDetection/otbDescriptorsListSampleGenerator.cxx b/Testing/Code/ObjectDetection/otbDescriptorsListSampleGenerator.cxx index 70fdb7bc3ba4868a7c8998740a8236c2327827ab..b8180a86c51ee665f7cb3311b49780640a966681 100644 --- a/Testing/Code/ObjectDetection/otbDescriptorsListSampleGenerator.cxx +++ b/Testing/Code/ObjectDetection/otbDescriptorsListSampleGenerator.cxx @@ -32,6 +32,8 @@ #include "otbVectorDataFileReader.h" #include "otbImageFunctionAdaptor.h" +#include "otbStatisticsXMLFileReader.h" +#include "otbShiftScaleSampleListFilter.h" #include "otbSVMSampleListModelEstimator.h" const unsigned int Dimension = 2; @@ -65,6 +67,9 @@ typedef otb::VectorDataFileReader<VectorDataType> VectorDataReaderType; typedef otb::Functor::VariableLengthVectorToMeasurementVectorFunctor<SampleType> MeasurementVectorFunctorType; +typedef otb::StatisticsXMLFileReader<SampleType> StatisticsReader; +typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType> ShiftScaleListSampleFilterType; + typedef otb::SVMSampleListModelEstimator< ListSampleType, LabelListSampleType, @@ -204,7 +209,7 @@ int otbDescriptorsListSampleGenerator(int argc, char* argv[]) int otbDescriptorsSVMModelCreation(int argc, char* argv[]) { - if (argc != 6) + if (argc != 7) { std::cerr << "Wrong number of arguments" << std::endl; return EXIT_FAILURE; @@ -212,9 +217,10 @@ int otbDescriptorsSVMModelCreation(int argc, char* argv[]) const char* inputImageFileName = argv[1]; const char* inputSamplesLocation = argv[2]; - const char* outputFileName = argv[3]; - int streaming = atoi(argv[4]); - int neighborhood = atoi(argv[5]); + const char* featureStatisticsFileName = argv[3]; + const char* outputFileName = argv[4]; + int streaming = atoi(argv[5]); + int neighborhood = atoi(argv[6]); ImageReaderType::Pointer imageReader = ImageReaderType::New(); imageReader->SetFileName(inputImageFileName); @@ -246,9 +252,22 @@ int otbDescriptorsSVMModelCreation(int argc, char* argv[]) descriptorsGenerator->Update(); - SVMEstimatorType::Pointer svmEstimator = SVMEstimatorType::New(); + // Normalize the samples + // Read the mean and variance form the XML file + StatisticsReader::Pointer statisticsReader = StatisticsReader::New(); + statisticsReader->SetFileName(featureStatisticsFileName); + SampleType meanMeasurentVector = statisticsReader->GetStatisticVectorByName("mean"); + SampleType varianceMeasurentVector = statisticsReader->GetStatisticVectorByName("stddev"); - svmEstimator->SetInputSampleList(descriptorsGenerator->GetListSample()); + // Shift scale the samples + ShiftScaleListSampleFilterType::Pointer shiftscaleFilter = ShiftScaleListSampleFilterType::New(); + shiftscaleFilter->SetInput(descriptorsGenerator->GetListSample()); + shiftscaleFilter->SetShifts(meanMeasurentVector); + shiftscaleFilter->SetScales(varianceMeasurentVector); + shiftscaleFilter->Update(); + + SVMEstimatorType::Pointer svmEstimator = SVMEstimatorType::New(); + svmEstimator->SetInputSampleList(shiftscaleFilter->GetOutputSampleList()); svmEstimator->SetTrainingSampleList(descriptorsGenerator->GetLabelListSample()); svmEstimator->Update(); svmEstimator->GetModel()->SaveModel(outputFileName);