diff --git a/Applications/Classification/otbTrainSVMImagesClassifier.cxx b/Applications/Classification/otbTrainSVMImagesClassifier.cxx index b4deedff4fda70b6275733eb6da775e30ae9294f..ef2d5ae41ade02e058c4a1c62f909cd1d0f951d9 100644 --- a/Applications/Classification/otbTrainSVMImagesClassifier.cxx +++ b/Applications/Classification/otbTrainSVMImagesClassifier.cxx @@ -1,5 +1,4 @@ /*========================================================================= - Program: ORFEO Toolbox Language: C++ Date: $Date$ @@ -170,29 +169,9 @@ private: SetParameterDescription("io.out", "Output SVM model"); SetParameterRole("io.out", Role_Output); - //Group SVM - AddParameter(ParameterType_Group,"svm","SVM classifier parameters"); - SetParameterDescription("svm","This group of parameters allows to set SVM classifier parameters."); - AddParameter(ParameterType_Choice, "svm.k", "SVM Kernel Type"); - AddChoice("svm.k.linear", "Linear"); - AddChoice("svm.k.rbf", "Neareast Neighbor"); - AddChoice("svm.k.poly", "Polynomial"); - AddChoice("svm.k.sigmoid", "Sigmoid"); - SetParameterString("svm.k", "linear"); - SetParameterDescription("svm.k", "SVM Kernel Type."); - AddParameter(ParameterType_Float, "svm.m", "Margin for SVM learning"); - SetParameterFloat("svm.m", 1.0); - SetParameterDescription("svm.m", "Margin for SVM learning.(1 by default)."); - AddParameter(ParameterType_Empty, "svm.opt", "parameters optimization"); - MandatoryOff("svm.opt"); - SetParameterDescription("svm.opt", "SVM parameters optimization"); - //Group Sample list AddParameter(ParameterType_Group,"sample","Training and validation samples parameters"); - SetParameterDescription("svm","This group of parameters allows to set training and validation sample lists parameters."); - AddParameter(ParameterType_Int, "sample.b", "Balance and grow the training set"); - SetParameterDescription("sample.b", "Balance and grow the training set."); - MandatoryOff("sample.b"); + SetParameterDescription("sample","This group of parameters allows to set training and validation sample lists parameters."); AddParameter(ParameterType_Int, "sample.mt", "Maximum training sample size"); //MandatoryOff("mt"); @@ -202,6 +181,11 @@ private: // MandatoryOff("mv"); SetDefaultParameterInt("sample.mv", -1); SetParameterDescription("sample.mv", "Maximum size of the validation sample (default = -1)"); + + // AddParameter(ParameterType_Int, "sample.b", "Balance and grow the training set"); + // SetParameterDescription("sample.b", "Balance and grow the training set."); + // MandatoryOff("sample.b"); + AddParameter(ParameterType_Float, "sample.vtr", "training and validation sample ratio"); SetParameterDescription("sample.vtr", "Ratio between training and validation sample (0.0 = all training, 1.0 = all validation) default = 0.5."); @@ -210,6 +194,23 @@ private: AddParameter(ParameterType_Filename, "sample.vfn", "Name of the discrimination field"); SetParameterDescription("sample.vfn", "Name of the field using to discriminate class in the vector data files."); SetParameterString("sample.vfn", "Class"); + + //Group SVM + AddParameter(ParameterType_Group,"svm","SVM classifier parameters"); + SetParameterDescription("svm","This group of parameters allows to set SVM classifier parameters."); + AddParameter(ParameterType_Choice, "svm.k", "SVM Kernel Type"); + AddChoice("svm.k.linear", "Linear"); + AddChoice("svm.k.rbf", "Neareast Neighbor"); + AddChoice("svm.k.poly", "Polynomial"); + AddChoice("svm.k.sigmoid", "Sigmoid"); + SetParameterString("svm.k", "linear"); + SetParameterDescription("svm.k", "SVM Kernel Type."); + AddParameter(ParameterType_Float, "svm.m", "Margin for SVM learning"); + SetParameterFloat("svm.m", 1.0); + SetParameterDescription("svm.m", "Margin for SVM learning.(1 by default)."); + AddParameter(ParameterType_Empty, "svm.opt", "parameters optimization"); + MandatoryOff("svm.opt"); + SetParameterDescription("svm.opt", "SVM parameters optimization"); } void DoUpdateParameters() @@ -333,26 +334,26 @@ private: LabelListSampleType::Pointer labelListSample; //-------------------------- // Balancing training sample (if needed) - if (IsParameterEnabled("sample.b")) - { - // Balance the list sample. - otbAppLogINFO("Number of training samples before balancing: " << concatenateTrainingSamples->GetOutputSampleList()->Size()) - BalancingListSampleFilterType::Pointer balancingFilter = BalancingListSampleFilterType::New(); - balancingFilter->SetInput(trainingShiftScaleFilter->GetOutput()/*GetOutputSampleList()*/); - balancingFilter->SetInputLabel(concatenateTrainingLabels->GetOutput()/*GetOutputSampleList()*/); - balancingFilter->SetBalancingFactor(GetParameterInt("sample.b")); - balancingFilter->Update(); - listSample = balancingFilter->GetOutputSampleList(); - labelListSample = balancingFilter->GetOutputLabelSampleList(); - otbAppLogINFO("Number of samples after balancing: " << balancingFilter->GetOutputSampleList()->Size()); - - } - else - { + // if (IsParameterEnabled("sample.b")) + // { + // // Balance the list sample. + // otbAppLogINFO("Number of training samples before balancing: " << concatenateTrainingSamples->GetOutputSampleList()->Size()) + // BalancingListSampleFilterType::Pointer balancingFilter = BalancingListSampleFilterType::New(); + // balancingFilter->SetInput(trainingShiftScaleFilter->GetOutput()/*GetOutputSampleList()*/); + // balancingFilter->SetInputLabel(concatenateTrainingLabels->GetOutput()/*GetOutputSampleList()*/); + // balancingFilter->SetBalancingFactor(GetParameterInt("sample.b")); + // balancingFilter->Update(); + // listSample = balancingFilter->GetOutputSampleList(); + // labelListSample = balancingFilter->GetOutputLabelSampleList(); + // otbAppLogINFO("Number of samples after balancing: " << balancingFilter->GetOutputSampleList()->Size()); + + // } + // else + // { listSample = trainingShiftScaleFilter->GetOutputSampleList(); labelListSample = concatenateTrainingLabels->GetOutputSampleList(); otbAppLogINFO("Number of training samples: " << concatenateTrainingSamples->GetOutputSampleList()->Size()); - } + // } //-------------------------- // Split the data set into training/validation set ListSampleType::Pointer trainingListSample = listSample; diff --git a/Testing/Applications/Classification/CMakeLists.txt b/Testing/Applications/Classification/CMakeLists.txt index 702338eccbbc41e7e3b5a55335b043a4b7a2ee39..a51f1e9e7b7d4a5fcf1a81205289df69512f36e7 100644 --- a/Testing/Applications/Classification/CMakeLists.txt +++ b/Testing/Applications/Classification/CMakeLists.txt @@ -16,7 +16,7 @@ OTB_TEST_APPLICATION(NAME apTvClTrainSVMImagesClassifierQB1 OPTIONS --io.il ${INPUTDATA}/Classification/QB_1_ortho.tif --io.vd ${INPUTDATA}/Classification/VectorData_QB1.shp --io.imstat ${TEMP}/apTvClEstimateImageStatisticsQB1.xml - --sample.b 2 + ##--sample.b 2 --svm.opt true --io.out ${TEMP}/clsvmModelQB1.svm VALID --compare-ascii ${NOTOL} @@ -30,7 +30,7 @@ OTB_TEST_APPLICATION(NAME apTvClTrainSVMImagesClassifierQB1_allOpt OPTIONS --io.il ${INPUTDATA}/Classification/QB_1_ortho.tif --io.vd ${INPUTDATA}/Classification/VectorData_QB1.shp --io.imstat ${TEMP}/apTvClEstimateImageStatisticsQB1.xml - --sample.b 2 + ##--sample.b 2 --sample.mv 100 --sample.mt 100 --sample.vtr 0.5 @@ -116,7 +116,7 @@ OTB_TEST_APPLICATION(NAME apTvClTrainSVMImagesClassifierQB123 ${INPUTDATA}/Classification/VectorData_QB2.shp ${INPUTDATA}/Classification/VectorData_QB3.shp --io.imstat ${TEMP}/apTvClEstimateImageStatisticsQB123.xml - --sample.b 2 + #--sample.b 2 --svm.opt true --io.out ${TEMP}/clsvmModelQB123.svm VALID --compare-ascii ${NOTOL} @@ -221,7 +221,7 @@ OTB_TEST_APPLICATION(NAME apTvClTrainSVMImagesClassifierQB456 ${INPUTDATA}/Classification/VectorData_QB5.shp ${INPUTDATA}/Classification/VectorData_QB6.shp --io.imstat ${TEMP}/apTvClEstimateImageStatisticsQB456.xml - --sample.b 2 + #--sample.b 2 --svm.opt true --io.out ${TEMP}/clsvmModelQB456.svm VALID --compare-ascii ${NOTOL}