Skip to content
Snippets Groups Projects
Commit c28535a5 authored by Cyrille Valladeau's avatar Cyrille Valladeau
Browse files

Embellissement des test EstimatorTrain.

parent ecc36b42
No related branches found
No related tags found
No related merge requests found
......@@ -241,6 +241,9 @@ ADD_TEST(leTvSVMModelCopyComposedKernel ${LEARNING_TESTS3}
${INPUTDATA}/svm_model_composed
${TEMP}/svmcopycomposed_test)
ADD_TEST(teststephanie ${LEARNING_TESTS3}
teststephanie)
# A enrichir
SET(BasicLearning_SRCS1
......@@ -279,6 +282,7 @@ otbSVMModelGenericKernelsTest.cxx
otbSVMModelCopyTest.cxx
otbSVMModelCopyGenericKernelTest.cxx
otbSVMModelCopyComposedKernelTest.cxx
teststephanie.cxx
)
......
......@@ -37,4 +37,5 @@ REGISTER_TEST(otbSVMModelGenericKernelsTest);
REGISTER_TEST(otbSVMModelCopyTest);
REGISTER_TEST(otbSVMModelCopyGenericKernelTest);
REGISTER_TEST(otbSVMModelCopyComposedKernelTest);
REGISTER_TEST(teststephanie);
}
......@@ -33,67 +33,39 @@
int otbSVMImageModelEstimatorTrain( int argc, char* argv[] )
{
try
{
const char* inputImageFileName = argv[1];
const char* trainingImageFileName = argv[2];
const char* outputModelFileName = argv[3];
typedef double InputPixelType;
const unsigned int Dimension = 2;
typedef otb::VectorImage< InputPixelType, Dimension > InputImageType;
typedef otb::Image< int, Dimension > TrainingImageType;
typedef std::vector<double> VectorType;
typedef otb::SVMImageModelEstimator< InputImageType,
TrainingImageType > EstimatorType;
typedef otb::ImageFileReader< InputImageType > InputReaderType;
typedef otb::ImageFileReader< TrainingImageType > TrainingReaderType;
InputReaderType::Pointer inputReader = InputReaderType::New();
TrainingReaderType::Pointer trainingReader = TrainingReaderType::New();
inputReader->SetFileName( inputImageFileName );
trainingReader->SetFileName( trainingImageFileName );
inputReader->Update();
trainingReader->Update();
EstimatorType::Pointer svmEstimator = EstimatorType::New();
svmEstimator->SetInputImage( inputReader->GetOutput() );
svmEstimator->SetTrainingImage( trainingReader->GetOutput() );
svmEstimator->SetNumberOfClasses( 2 );
svmEstimator->Update();
std::cout << "Saving model" << std::endl;
svmEstimator->SaveModel(outputModelFileName);
}
catch( itk::ExceptionObject & err )
{
std::cout << "Exception itk::ExceptionObject levee !" << std::endl;
std::cout << err << std::endl;
return EXIT_FAILURE;
}
catch( ... )
{
std::cout << "Unknown exception !" << std::endl;
return EXIT_FAILURE;
}
// Software Guide : EndCodeSnippet
//#endif
const char* inputImageFileName = argv[1];
const char* trainingImageFileName = argv[2];
const char* outputModelFileName = argv[3];
typedef double InputPixelType;
const unsigned int Dimension = 2;
typedef otb::VectorImage< InputPixelType, Dimension > InputImageType;
typedef otb::Image< int, Dimension > TrainingImageType;
typedef std::vector<double> VectorType;
typedef otb::SVMImageModelEstimator< InputImageType,
TrainingImageType > EstimatorType;
typedef otb::ImageFileReader< InputImageType > InputReaderType;
typedef otb::ImageFileReader< TrainingImageType > TrainingReaderType;
InputReaderType::Pointer inputReader = InputReaderType::New();
TrainingReaderType::Pointer trainingReader = TrainingReaderType::New();
EstimatorType::Pointer svmEstimator = EstimatorType::New();
inputReader->SetFileName( inputImageFileName );
trainingReader->SetFileName( trainingImageFileName );
inputReader->Update();
trainingReader->Update();
svmEstimator->SetInputImage( inputReader->GetOutput() );
svmEstimator->SetTrainingImage( trainingReader->GetOutput() );
svmEstimator->SetNumberOfClasses( 2 );
svmEstimator->Update();
itkGenericExceptionMacro(<<"Saving model");
svmEstimator->SaveModel(outputModelFileName);
return EXIT_SUCCESS;
}
......
......@@ -30,102 +30,64 @@
int otbSVMPointSetModelEstimatorTrain( int argc, char* argv[] )
{
try
const char* outputModelFileName = argv[1];
typedef std::vector<double> InputPixelType;
typedef double LabelPixelType;
const unsigned int Dimension = 2;
typedef itk::PointSet< InputPixelType, Dimension > MeasurePointSetType;
typedef itk::PointSet< LabelPixelType, Dimension > LabelPointSetType;
typedef MeasurePointSetType::PointType MeasurePointType;
typedef LabelPointSetType::PointType LabelPointType;
typedef MeasurePointSetType::PointsContainer MeasurePointsContainer;
typedef LabelPointSetType::PointsContainer LabelPointsContainer;
MeasurePointSetType::Pointer mPSet = MeasurePointSetType::New();
LabelPointSetType::Pointer lPSet = LabelPointSetType::New();
MeasurePointsContainer::Pointer mCont = MeasurePointsContainer::New();
LabelPointsContainer::Pointer lCont = LabelPointsContainer::New();
for(unsigned int pointId = 0; pointId<20; pointId++)
{
const char* outputModelFileName = argv[1];
typedef std::vector<double> InputPixelType;
typedef double LabelPixelType;
const unsigned int Dimension = 2;
typedef itk::PointSet< InputPixelType, Dimension >
MeasurePointSetType;
typedef itk::PointSet< LabelPixelType, Dimension > LabelPointSetType;
MeasurePointSetType::Pointer mPSet = MeasurePointSetType::New();
LabelPointSetType::Pointer lPSet = LabelPointSetType::New();
typedef MeasurePointSetType::PointType MeasurePointType;
typedef LabelPointSetType::PointType LabelPointType;
typedef MeasurePointSetType::PointsContainer MeasurePointsContainer;
typedef LabelPointSetType::PointsContainer LabelPointsContainer;
MeasurePointsContainer::Pointer mCont = MeasurePointsContainer::New();
LabelPointsContainer::Pointer lCont = LabelPointsContainer::New();
for(unsigned int pointId = 0; pointId<20; pointId++)
{
MeasurePointType mP;
LabelPointType lP;
mP[0] = pointId;
mP[1] = pointId;
lP[0] = pointId;
lP[1] = pointId;
InputPixelType measure;
// measure.push_back(vcl_pow(pointId,2.0));
measure.push_back(double(2.0*pointId));
measure.push_back(double(-10));
LabelPixelType label = static_cast<LabelPixelType>(
(measure[0]+measure[1])>0); //2x-10>0
std::cout << "Label : " << label << std::endl;
LabelPixelType label = static_cast<LabelPixelType>( (measure[0]+measure[1])>0 ); //2x-10>0
mCont->InsertElement( pointId , mP );
mPSet->SetPointData( pointId, measure );
mPSet->SetPointData( pointId, measure );
lCont->InsertElement( pointId , lP );
lPSet->SetPointData( pointId, label );
}
mPSet->SetPoints( mCont );
lPSet->SetPoints( lCont );
typedef otb::SVMPointSetModelEstimator< MeasurePointSetType,
LabelPointSetType > EstimatorType;
EstimatorType::Pointer estimator = EstimatorType::New();
estimator->SetInputPointSet( mPSet );
estimator->SetTrainingPointSet( lPSet );
estimator->SetNumberOfClasses( 2 );
estimator->Update();
std::cout << "Saving model" << std::endl;
estimator->SaveModel(outputModelFileName);
}
catch( itk::ExceptionObject & err )
{
std::cout << "Exception itk::ExceptionObject levee !" << std::endl;
std::cout << err << std::endl;
return EXIT_FAILURE;
}
catch( ... )
{
std::cout << "Unknown exception !" << std::endl;
return EXIT_FAILURE;
}
// Software Guide : EndCodeSnippet
//#endif
}
mPSet->SetPoints( mCont );
lPSet->SetPoints( lCont );
typedef otb::SVMPointSetModelEstimator< MeasurePointSetType, LabelPointSetType > EstimatorType;
EstimatorType::Pointer estimator = EstimatorType::New();
estimator->SetInputPointSet( mPSet );
estimator->SetTrainingPointSet( lPSet );
estimator->SetNumberOfClasses( 2 );
estimator->Update();
estimator->SaveModel(outputModelFileName);
return EXIT_SUCCESS;
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment