From 85fc4c8c973a840c76613cf918a9de555edf6025 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Traizet?= <cedric.traizet@c-s.fr> Date: Tue, 30 Jul 2019 16:45:20 +0200 Subject: [PATCH] ENH: partial specialization for the DoInit method --- .../app/otbVectorClassifier.cxx | 73 ++++++++++++++++++ .../app/otbVectorRegression.cxx | 63 ++++++++++++++++ .../include/otbVectorPrediction.h | 75 ++----------------- 3 files changed, 142 insertions(+), 69 deletions(-) diff --git a/Modules/Applications/AppClassification/app/otbVectorClassifier.cxx b/Modules/Applications/AppClassification/app/otbVectorClassifier.cxx index da553cfae0..62d6800210 100644 --- a/Modules/Applications/AppClassification/app/otbVectorClassifier.cxx +++ b/Modules/Applications/AppClassification/app/otbVectorClassifier.cxx @@ -27,6 +27,79 @@ namespace Wrapper using VectorClassifier = VectorPrediction<false, float, unsigned int>; +template<> +void +VectorClassifier +::DoInitSpecialization() +{ + SetName("VectorClassifier"); + SetDescription("Performs a classification of the input vector data according to a model file."); + + SetDocAuthors("OTB-Team"); + SetDocLongDescription("This application performs a vector data classification " + "based on a model file produced by the TrainVectorClassifier application." + "Features of the vector data output will contain the class labels decided by the classifier " + "(maximal class label = 65535). \n" + "There are two modes: \n" + "1) Update mode: add of the 'cfield' field containing the predicted class in the input file. \n" + "2) Write mode: copies the existing fields of the input file to the output file " + " and add the 'cfield' field containing the predicted class. \n" + "If you have declared the output file, the write mode applies. " + "Otherwise, the input file update mode will be applied."); + + SetDocLimitations("Shapefiles are supported, but the SQLite format is only supported in update mode."); + SetDocSeeAlso("TrainVectorClassifier"); + AddDocTag(Tags::Learning); + + AddParameter(ParameterType_InputVectorData, "in", "Name of the input vector data"); + SetParameterDescription("in","The input vector data file to classify."); + + AddParameter(ParameterType_InputFilename, "instat", "Statistics file"); + SetParameterDescription("instat", "A XML file containing mean and standard deviation to center" + "and reduce samples before classification, produced by ComputeImagesStatistics application."); + MandatoryOff("instat"); + + AddParameter(ParameterType_InputFilename, "model", "Model file"); + SetParameterDescription("model", "Model file produced by TrainVectorClassifier application."); + + AddParameter(ParameterType_String,"cfield","Output field"); + SetParameterDescription("cfield","Field containing the predicted class." + "Only geometries with this field available will be taken into account.\n" + "The field is added either in the input file (if 'out' off) or in the output file.\n" + "Caution, the 'cfield' must not exist in the input file if you are updating the file."); + SetParameterString("cfield","predicted"); + + AddParameter(ParameterType_ListView, "feat", "Field names to be calculated"); + SetParameterDescription("feat","List of field names in the input vector data used as features for training. " + "Put the same field names as the TrainVectorClassifier application."); + + AddParameter(ParameterType_Bool, "confmap", "Confidence map"); + SetParameterDescription( "confmap", "Confidence map of the produced classification. The confidence index depends on the model: \n\n" + "* LibSVM: difference between the two highest probabilities (needs a model with probability estimates, so that classes probabilities can be computed for each sample)\n" + "* Boost: sum of votes\n" + "* DecisionTree: (not supported)\n" + "* KNearestNeighbors: number of neighbors with the same label\n" + "* NeuralNetwork: difference between the two highest responses\n" + "* NormalBayes: (not supported)\n" + "* RandomForest: Confidence (proportion of votes for the majority class). Margin (normalized difference of the votes of the 2 majority classes) is not available for now.\n" + "* SVM: distance to margin (only works for 2-class models)\n"); + + AddParameter(ParameterType_OutputFilename, "out", "Output vector data file"); + MandatoryOff("out"); + SetParameterDescription("out","Output vector data file storing sample values (OGR format)." + "If not given, the input vector data file is updated."); + + // Doc example parameter settings + SetDocExampleParameterValue("in", "vectorData.shp"); + SetDocExampleParameterValue("instat", "meanVar.xml"); + SetDocExampleParameterValue("model", "svmModel.svm"); + SetDocExampleParameterValue("out", "vectorDataLabeledVector.shp"); + SetDocExampleParameterValue("feat", "perimeter area width"); + SetDocExampleParameterValue("cfield", "predicted"); + + SetOfficialDocLink(); +} + template<> void VectorClassifier diff --git a/Modules/Applications/AppClassification/app/otbVectorRegression.cxx b/Modules/Applications/AppClassification/app/otbVectorRegression.cxx index bba74a3152..02dd347096 100644 --- a/Modules/Applications/AppClassification/app/otbVectorRegression.cxx +++ b/Modules/Applications/AppClassification/app/otbVectorRegression.cxx @@ -27,6 +27,69 @@ namespace Wrapper using VectorRegression = VectorPrediction<true, float, float>; +template<> +void +VectorRegression +::DoInitSpecialization() +{ + SetName("VectorRegression"); + SetDescription("Performs regression on the input vector data according to a model file."); + + SetDocAuthors("OTB-Team"); + SetDocLongDescription("This application performs a vector data regression " + "based on a model file produced by the TrainVectorRegression application." + "Features of the vector data output will contain the values predicted by the classifier. \n" + "There are two modes: \n" + "1) Update mode: add of the 'cfield' field containing the predicted value in the input file. \n" + "2) Write mode: copies the existing fields of the input file to the output file " + " and add the 'cfield' field containing the predicted value. \n" + "If you have declared the output file, the write mode applies. " + "Otherwise, the input file update mode will be applied."); + + SetDocLimitations("Shapefiles are supported, but the SQLite format is only supported in update mode."); + SetDocSeeAlso("TrainVectorRegression"); + AddDocTag(Tags::Learning); + + AddParameter(ParameterType_InputVectorData, "in", "Name of the input vector data"); + SetParameterDescription("in","The input vector data file to classify."); + + AddParameter(ParameterType_InputFilename, "instat", "Statistics file"); + SetParameterDescription("instat", "A XML file containing mean and standard deviation to center" + "and reduce samples before classification, produced by ComputeImagesStatistics application."); + MandatoryOff("instat"); + + AddParameter(ParameterType_InputFilename, "model", "Model file"); + SetParameterDescription("model", "Model file produced by TrainVectorRegression application."); + + AddParameter(ParameterType_String,"cfield","Output field"); + SetParameterDescription("cfield","Field containing the predicted value." + "Only geometries with this field available will be taken into account.\n" + "The field is added either in the input file (if 'out' off) or in the output file.\n" + "Caution, the 'cfield' must not exist in the input file if you are updating the file."); + SetParameterString("cfield","predicted"); + + AddParameter(ParameterType_ListView, "feat", "Field names to be calculated"); + SetParameterDescription("feat","List of field names in the input vector data used as features for training. " + "Put the same field names as the TrainVectorRegression application."); + + AddParameter(ParameterType_OutputFilename, "out", "Output vector data file"); + MandatoryOff("out"); + + SetParameterDescription("out","Output vector data file storing sample values (OGR format)." + "If not given, the input vector data file is updated."); + MandatoryOff("out"); + + // Doc example parameter settings + SetDocExampleParameterValue("in", "vectorData.shp"); + SetDocExampleParameterValue("instat", "meanVar.xml"); + SetDocExampleParameterValue("model", "rfModel.rf"); + SetDocExampleParameterValue("out", "vectorDataLabeledVector.shp"); + SetDocExampleParameterValue("feat", "perimeter area width"); + SetDocExampleParameterValue("cfield", "predicted"); + + SetOfficialDocLink(); +} + template<> void VectorRegression diff --git a/Modules/Applications/AppClassification/include/otbVectorPrediction.h b/Modules/Applications/AppClassification/include/otbVectorPrediction.h index dfa537be48..247a2e4ace 100644 --- a/Modules/Applications/AppClassification/include/otbVectorPrediction.h +++ b/Modules/Applications/AppClassification/include/otbVectorPrediction.h @@ -92,74 +92,12 @@ private: void DoInit() override { - SetName("VectorClassifier"); - SetDescription("Performs a classification of the input vector data according to a model file."); - - SetDocAuthors("OTB-Team"); - SetDocLongDescription("This application performs a vector data classification " - "based on a model file produced by the TrainVectorClassifier application." - "Features of the vector data output will contain the class labels decided by the classifier " - "(maximal class label = 65535). \n" - "There are two modes: \n" - "1) Update mode: add of the 'cfield' field containing the predicted class in the input file. \n" - "2) Write mode: copies the existing fields of the input file to the output file " - " and add the 'cfield' field containing the predicted class. \n" - "If you have declared the output file, the write mode applies. " - "Otherwise, the input file update mode will be applied."); - - SetDocLimitations("Shapefiles are supported, but the SQLite format is only supported in update mode."); - SetDocSeeAlso("TrainVectorClassifier"); - AddDocTag(Tags::Learning); - - AddParameter(ParameterType_InputVectorData, "in", "Name of the input vector data"); - SetParameterDescription("in","The input vector data file to classify."); - - AddParameter(ParameterType_InputFilename, "instat", "Statistics file"); - SetParameterDescription("instat", "A XML file containing mean and standard deviation to center" - "and reduce samples before classification, produced by ComputeImagesStatistics application."); - MandatoryOff("instat"); - - AddParameter(ParameterType_InputFilename, "model", "Model file"); - SetParameterDescription("model", "Model file produced by TrainVectorClassifier application."); - - AddParameter(ParameterType_String,"cfield","Field class"); - SetParameterDescription("cfield","Field containing the predicted class." - "Only geometries with this field available will be taken into account.\n" - "The field is added either in the input file (if 'out' off) or in the output file.\n" - "Caution, the 'cfield' must not exist in the input file if you are updating the file."); - SetParameterString("cfield","predicted"); - - AddParameter(ParameterType_ListView, "feat", "Field names to be calculated"); - SetParameterDescription("feat","List of field names in the input vector data used as features for training. " - "Put the same field names as the TrainVectorClassifier application."); - - AddParameter(ParameterType_Bool, "confmap", "Confidence map"); - SetParameterDescription( "confmap", "Confidence map of the produced classification. The confidence index depends on the model: \n\n" - "* LibSVM: difference between the two highest probabilities (needs a model with probability estimates, so that classes probabilities can be computed for each sample)\n" - "* Boost: sum of votes\n" - "* DecisionTree: (not supported)\n" - "* KNearestNeighbors: number of neighbors with the same label\n" - "* NeuralNetwork: difference between the two highest responses\n" - "* NormalBayes: (not supported)\n" - "* RandomForest: Confidence (proportion of votes for the majority class). Margin (normalized difference of the votes of the 2 majority classes) is not available for now.\n" - "* SVM: distance to margin (only works for 2-class models)\n"); - - AddParameter(ParameterType_OutputFilename, "out", "Output vector data file containing class labels"); - SetParameterDescription("out","Output vector data file storing sample values (OGR format)." - "If not given, the input vector data file is updated."); - MandatoryOff("out"); - - // Doc example parameter settings - SetDocExampleParameterValue("in", "vectorData.shp"); - SetDocExampleParameterValue("instat", "meanVar.xml"); - SetDocExampleParameterValue("model", "svmModel.svm"); - SetDocExampleParameterValue("out", "vectorDataLabeledVector.shp"); - SetDocExampleParameterValue("feat", "perimeter area width"); - SetDocExampleParameterValue("cfield", "predicted"); - - SetOfficialDocLink(); + DoInitSpecialization(); + //TODO add assert to check that parameters has been correctly defined } + void DoInitSpecialization(); + void DoUpdateParameters() override { if ( HasValue("in") ) @@ -288,10 +226,9 @@ private: typename ConfidenceListSampleType::Pointer quality; - bool computeConfidenceMap(GetParameterInt("confmap") && m_Model->HasConfidenceIndex() - && !m_Model->GetRegressionMode()); + bool computeConfidenceMap(!m_Model->GetRegressionMode() && GetParameterInt("confmap") && m_Model->HasConfidenceIndex() ); - if (!m_Model->HasConfidenceIndex() && GetParameterInt("confmap")) + if (!m_Model->GetRegressionMode() && !m_Model->HasConfidenceIndex() && GetParameterInt("confmap")) { otbAppLogWARNING("Confidence map requested but the classifier doesn't support it!"); } -- GitLab