diff --git a/Modules/Applications/AppClassification/app/otbTrainRegression.cxx b/Modules/Applications/AppClassification/app/otbTrainRegression.cxx index 8c024a9088f631ddfb896b39914db1dbc9692817..0e9fc02dfea39cb3503079848110c7e639a2248d 100644 --- a/Modules/Applications/AppClassification/app/otbTrainRegression.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainRegression.cxx @@ -506,8 +506,6 @@ void DoExecute() ITK_OVERRIDE // Performances estimation //-------------------------- ListSampleType::Pointer performanceListSample; - TargetListSampleType::Pointer predictedList = TargetListSampleType::New(); - predictedList->SetMeasurementVectorSize(1); TargetListSampleType::Pointer performanceLabeledListSample; //Test the input validation set size @@ -523,7 +521,8 @@ void DoExecute() ITK_OVERRIDE performanceLabeledListSample = trainingLabeledListSample; } - this->Classify(performanceListSample, predictedList, GetParameterString("io.out")); + TargetListSampleType::Pointer predictedList = + this->Classify(performanceListSample, GetParameterString("io.out")); otbAppLogINFO("Training performances"); double mse=0.0; diff --git a/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx b/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx index f9acb6fbe0cafd499c68c234d3a976fa223cd47a..991ea469e4830d10c24821e09f7cd3923f442e02 100644 --- a/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx @@ -446,7 +446,6 @@ void DoExecute() } //Test the input validation set size - TargetListSampleType::Pointer predictedList = TargetListSampleType::New(); ListSampleType::Pointer performanceListSample; TargetListSampleType::Pointer performanceLabeledListSample; if(validationLabeledListSample->Size() != 0) @@ -461,7 +460,8 @@ void DoExecute() performanceLabeledListSample = trainingLabeledListSample; } - this->Classify(performanceListSample, predictedList, GetParameterString("io.out")); + TargetListSampleType::Pointer predictedList = + this->Classify(performanceListSample, GetParameterString("io.out")); ConfusionMatrixCalculatorType::Pointer confMatCalc = ConfusionMatrixCalculatorType::New(); diff --git a/Modules/Applications/AppClassification/include/otbLearningApplicationBase.h b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.h index 79767a1c02509c581fe62e8fe5cb25eba89e19ba..c1e6e37ba76dfd2c5bffa87de40d669de16764aa 100644 --- a/Modules/Applications/AppClassification/include/otbLearningApplicationBase.h +++ b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.h @@ -115,9 +115,9 @@ protected: std::string modelPath); /** Generic method to load a model file and use it to classify a sample list*/ - void Classify(typename ListSampleType::Pointer validationListSample, - typename TargetListSampleType::Pointer predictedList, - std::string modelPath); + typename TargetListSampleType::Pointer Classify( + typename ListSampleType::Pointer validationListSample, + std::string modelPath); /** Init method that creates all the parameters for machine learning models */ void DoInit(); diff --git a/Modules/Applications/AppClassification/include/otbLearningApplicationBase.txx b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.txx index cd0e243dc4d1eb43412d19cf01b0090e98ae694d..097ed44c959d5aa167c3e04e62ef40afd7d4a19d 100644 --- a/Modules/Applications/AppClassification/include/otbLearningApplicationBase.txx +++ b/Modules/Applications/AppClassification/include/otbLearningApplicationBase.txx @@ -85,10 +85,10 @@ LearningApplicationBase<TInputValue,TOutputValue> } template <class TInputValue, class TOutputValue> -void +typename LearningApplicationBase<TInputValue,TOutputValue> +::TargetListSampleType::Pointer LearningApplicationBase<TInputValue,TOutputValue> ::Classify(typename ListSampleType::Pointer validationListSample, - typename TargetListSampleType::Pointer predictedList, std::string modelPath) { // Setup fake reporter @@ -110,11 +110,13 @@ LearningApplicationBase<TInputValue,TOutputValue> model->Load(modelPath); model->SetRegressionMode(this->m_RegressionFlag); - predictedList = model->PredictBatch(validationListSample, NULL); + typename TargetListSampleType::Pointer predictedList = model->PredictBatch(validationListSample, NULL); // update reporter dummyFilter->UpdateProgress(1.0f); dummyFilter->InvokeEvent(itk::EndEvent()); + + return predictedList; } template <class TInputValue, class TOutputValue>