From 0d8f47f8746f6a176b7e130206b8c53127d4c268 Mon Sep 17 00:00:00 2001 From: Jonathan Guinet <jonathan.guinet@c-s.fr> Date: Fri, 4 Nov 2011 12:38:27 +0100 Subject: [PATCH] ENH: Train SVM Application change. --- .../otbTrainSVMImagesClassifier.cxx | 36 ++++++++----------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/Applications/Classification/otbTrainSVMImagesClassifier.cxx b/Applications/Classification/otbTrainSVMImagesClassifier.cxx index 67c36d0cfe..99e6f9d781 100644 --- a/Applications/Classification/otbTrainSVMImagesClassifier.cxx +++ b/Applications/Classification/otbTrainSVMImagesClassifier.cxx @@ -165,16 +165,13 @@ private: AddParameter(ParameterType_Filename, "imstat", "XML image statistics file"); MandatoryOff("imstat"); SetParameterDescription("imstat", "filename of an XML file containing mean and standard deviation of input images."); - AddParameter(ParameterType_Filename, "out", "Output SVM model"); - SetParameterDescription("out", "Output SVM model"); AddParameter(ParameterType_Float, "m", "Margin for SVM learning"); - MandatoryOff("m"); - SetParameterDescription("m", "Margin for SVM learning."); + SetParameterFloat("m", 1.0); + SetParameterDescription("m", "Margin for SVM learning.(1 by default)."); AddParameter(ParameterType_Int, "b", "Balance and grow the training set"); SetParameterDescription("b", "Balance and grow the training set."); MandatoryOff("b"); AddParameter(ParameterType_Choice, "k", "SVM Kernel Type"); - MandatoryOff("k"); AddChoice("k.linear", "Linear"); AddChoice("k.rbf", "Neareast Neighbor"); AddChoice("k.poly", "Polynomial"); @@ -182,25 +179,26 @@ private: SetParameterString("k", "linear"); SetParameterDescription("k", "SVM Kernel Type."); AddParameter(ParameterType_Int, "mt", "Maximum training sample size"); - MandatoryOff("mt"); + //MandatoryOff("mt"); SetDefaultParameterInt("mt", -1); SetParameterDescription("mt", "Maximum size of the training sample (default = -1)."); AddParameter(ParameterType_Int, "mv", "Maximum validation sample size"); - MandatoryOff("mv"); + // MandatoryOff("mv"); SetDefaultParameterInt("mv", -1); SetParameterDescription("mv", "Maximum size of the validation sample (default = -1)"); AddParameter(ParameterType_Float, "vtr", "training and validation sample ratio"); SetParameterDescription("vtr", "Ratio between training and validation sample (0.0 = all training, 1.0 = all validation) default = 0.5."); - MandatoryOff("vtr"); SetParameterFloat("vtr", 0.5); AddParameter(ParameterType_Empty, "opt", "parameters optimization"); MandatoryOff("opt"); SetParameterDescription("opt", "SVM parameters optimization"); AddParameter(ParameterType_Filename, "vfn", "Name of the discrimination field"); - MandatoryOff("vfn"); SetParameterDescription("vfn", "Name of the field using to discriminate class in the vector data files."); SetParameterString("vfn", "Class"); + AddParameter(ParameterType_Filename, "out", "Output SVM model"); + SetParameterDescription("out", "Output SVM model"); + SetParameterRole("out",Role_Output); } @@ -252,7 +250,7 @@ private: vdreproj->SetUseOutputSpacingAndOriginFromImage(false); // Configure DEM directory - if (HasUserValue("dem")) + if (IsParameterEnabled("dem")) { vdreproj->SetDEMDirectory(GetParameterString("dem")); } @@ -274,14 +272,11 @@ private: sampleGenerator->SetInput(image); sampleGenerator->SetInputVectorData(vdreproj->GetOutput()); - if (HasUserValue("vfn")) - { - sampleGenerator->SetClassKey(GetParameterString("vfn")); - } - + sampleGenerator->SetClassKey(GetParameterString("vfn")); sampleGenerator->SetMaxTrainingSize(GetParameterInt("mt")); sampleGenerator->SetMaxValidationSize(GetParameterInt("mv")); sampleGenerator->SetValidationTrainingProportion(GetParameterFloat("vtr")); + sampleGenerator->Update(); //Concatenate training and validation samples from the image @@ -296,7 +291,7 @@ private: concatenateValidationSamples->Update(); concatenateValidationLabels->Update(); - if (HasValue("imstat")) + if (IsParameterEnabled("imstat")) { StatisticsReader::Pointer statisticsReader = StatisticsReader::New(); statisticsReader->SetFileName(GetParameterString("imstat")); @@ -328,7 +323,7 @@ private: LabelListSampleType::Pointer labelListSample; //-------------------------- // Balancing training sample (if needed) - if (HasUserValue("b")) + if (IsParameterEnabled("b")) { // Balance the list sample. otbAppLogINFO("Number of training samples before balancing: " << concatenateTrainingSamples->GetOutputSampleList()->Size()) @@ -371,10 +366,9 @@ private: svmestimator->SetParametersOptimization(true); } - if (HasUserValue("m")) - { - svmestimator->SetC(GetParameterFloat("m")); - } + + svmestimator->SetC(GetParameterFloat("m")); + switch (GetParameterInt("k")) { -- GitLab