Skip to content
Snippets Groups Projects
Commit 536dcdbd authored by Julien Malik's avatar Julien Malik
Browse files

ADD: add an svm creation test for Object Detection

parent 0aaab2a3
No related merge requests found
......@@ -32,6 +32,8 @@
#include "otbVectorDataFileReader.h"
#include "otbImageFunctionAdaptor.h"
#include "otbSVMSampleListModelEstimator.h"
const unsigned int Dimension = 2;
typedef int LabelType;
typedef double PixelType;
......@@ -41,14 +43,13 @@ typedef double CoordRepType;
typedef otb::Image<PixelType, Dimension> ImageType;
typedef otb::VectorData<> VectorDataType;
typedef otb::RadiometricMomentsImageFunction<ImageType, CoordRepType> FunctionType;
typedef otb::ImageFunctionAdaptor
<FunctionType> AdapatedFunctionType;
typedef otb::ImageFunctionAdaptor<FunctionType> AdapatedFunctionType;
//typedef FunctionType::OutputType SampleType;
typedef itk::VariableLengthVector<CoordRepType> SampleType;
typedef itk::Statistics::ListSample<SampleType> ListSampleType;
typedef itk::FixedArray<LabelType> LabelSampleType;
typedef itk::Statistics::ListSample<LabelSampleType> LabelListSampleType;
typedef itk::VariableLengthVector<FunctionPrecisionType> SampleType;
typedef itk::Statistics::ListSample<SampleType> ListSampleType;
typedef itk::FixedArray<LabelType> LabelSampleType;
typedef itk::Statistics::ListSample<LabelSampleType> LabelListSampleType;
typedef otb::DescriptorsListSampleGenerator
< ImageType,
......@@ -61,10 +62,18 @@ typedef otb::DescriptorsListSampleGenerator
typedef otb::ImageFileReader<ImageType> ImageReaderType;
typedef otb::VectorDataFileReader<VectorDataType> VectorDataReaderType;
typedef otb::Functor::VariableLengthVectorToMeasurementVectorFunctor<SampleType>
MeasurementVectorFunctorType;
typedef otb::SVMSampleListModelEstimator<
ListSampleType,
LabelListSampleType,
MeasurementVectorFunctorType> SVMEstimatorType;
typedef FunctionType::PointType PointType;
typedef DescriptorsListSampleGeneratorType::SamplesPositionType SamplesPositionType;
struct SampleEntry
{
PointType position;
......@@ -112,7 +121,7 @@ int otbDescriptorsListSampleGeneratorNew(int itkNotUsed(argc), char* itkNotUsed(
int otbDescriptorsListSampleGenerator(int argc, char* argv[])
{
if (argc != 5)
if (argc != 6)
{
std::cerr << "Wrong number of arguments" << std::endl;
return EXIT_FAILURE;
......@@ -122,6 +131,7 @@ int otbDescriptorsListSampleGenerator(int argc, char* argv[])
const char* inputSamplesLocation = argv[2];
const char* outputFileName = argv[3];
int streaming = atoi(argv[4]);
int neighborhood = atoi(argv[5]);
ImageReaderType::Pointer imageReader = ImageReaderType::New();
imageReader->SetFileName(inputImageFileName);
......@@ -140,7 +150,7 @@ int otbDescriptorsListSampleGenerator(int argc, char* argv[])
descriptorsGenerator->SetInputImage(imageReader->GetOutput());
descriptorsGenerator->SetSamplesLocations(vectorDataReader->GetOutput());
descriptorsGenerator->SetDescriptorsFunction(descriptorsFunction.GetPointer());
descriptorsGenerator->SetNeighborhoodRadius(5);
descriptorsGenerator->SetNeighborhoodRadius(neighborhood);
if (streaming == 0)
{
......@@ -189,3 +199,61 @@ int otbDescriptorsListSampleGenerator(int argc, char* argv[])
return EXIT_SUCCESS;
}
int otbDescriptorsSVMModelCreation(int argc, char* argv[])
{
if (argc != 6)
{
std::cerr << "Wrong number of arguments" << std::endl;
return EXIT_FAILURE;
}
const char* inputImageFileName = argv[1];
const char* inputSamplesLocation = argv[2];
const char* outputFileName = argv[3];
int streaming = atoi(argv[4]);
int neighborhood = atoi(argv[5]);
ImageReaderType::Pointer imageReader = ImageReaderType::New();
imageReader->SetFileName(inputImageFileName);
VectorDataReaderType::Pointer vectorDataReader = VectorDataReaderType::New();
vectorDataReader->SetFileName(inputSamplesLocation);
//imageReader->Update();
//vectorDataReader->Update();
AdapatedFunctionType::Pointer descriptorsFunction = AdapatedFunctionType::New();
descriptorsFunction->SetInputImage(imageReader->GetOutput());
descriptorsFunction->GetInternalImageFunction()->SetNeighborhoodRadius(neighborhood);
DescriptorsListSampleGeneratorType::Pointer descriptorsGenerator = DescriptorsListSampleGeneratorType::New();
descriptorsGenerator->SetInputImage(imageReader->GetOutput());
descriptorsGenerator->SetSamplesLocations(vectorDataReader->GetOutput());
descriptorsGenerator->SetDescriptorsFunction(descriptorsFunction.GetPointer());
descriptorsGenerator->SetNeighborhoodRadius(5);
if (streaming == 0)
{
descriptorsGenerator->GetStreamer()->SetNumberOfStreamDivisions(1);
}
else
{
descriptorsGenerator->GetStreamer()->SetNumberOfStreamDivisions(streaming);
}
descriptorsGenerator->Update();
SVMEstimatorType::Pointer svmEstimator = SVMEstimatorType::New();
svmEstimator->SetInputSampleList(descriptorsGenerator->GetListSample());
svmEstimator->SetTrainingSampleList(descriptorsGenerator->GetLabelListSample());
svmEstimator->Update();
svmEstimator->GetModel()->SaveModel(outputFileName);
return EXIT_SUCCESS;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment