Commit 4eea72f2 authored by Guillaume Pasero's avatar Guillaume Pasero

REFAC: change Classify() interface because of PredictBatch()

parent ae1db75a
......@@ -506,8 +506,6 @@ void DoExecute() ITK_OVERRIDE
// Performances estimation
//--------------------------
ListSampleType::Pointer performanceListSample;
TargetListSampleType::Pointer predictedList = TargetListSampleType::New();
predictedList->SetMeasurementVectorSize(1);
TargetListSampleType::Pointer performanceLabeledListSample;
//Test the input validation set size
......@@ -523,7 +521,8 @@ void DoExecute() ITK_OVERRIDE
performanceLabeledListSample = trainingLabeledListSample;
}
this->Classify(performanceListSample, predictedList, GetParameterString("io.out"));
TargetListSampleType::Pointer predictedList =
this->Classify(performanceListSample, GetParameterString("io.out"));
otbAppLogINFO("Training performances");
double mse=0.0;
......
......@@ -446,7 +446,6 @@ void DoExecute()
}
//Test the input validation set size
TargetListSampleType::Pointer predictedList = TargetListSampleType::New();
ListSampleType::Pointer performanceListSample;
TargetListSampleType::Pointer performanceLabeledListSample;
if(validationLabeledListSample->Size() != 0)
......@@ -461,7 +460,8 @@ void DoExecute()
performanceLabeledListSample = trainingLabeledListSample;
}
this->Classify(performanceListSample, predictedList, GetParameterString("io.out"));
TargetListSampleType::Pointer predictedList =
this->Classify(performanceListSample, GetParameterString("io.out"));
ConfusionMatrixCalculatorType::Pointer confMatCalc = ConfusionMatrixCalculatorType::New();
......
......@@ -115,9 +115,9 @@ protected:
std::string modelPath);
/** Generic method to load a model file and use it to classify a sample list*/
void Classify(typename ListSampleType::Pointer validationListSample,
typename TargetListSampleType::Pointer predictedList,
std::string modelPath);
typename TargetListSampleType::Pointer Classify(
typename ListSampleType::Pointer validationListSample,
std::string modelPath);
/** Init method that creates all the parameters for machine learning models */
void DoInit();
......
......@@ -85,10 +85,10 @@ LearningApplicationBase<TInputValue,TOutputValue>
}
template <class TInputValue, class TOutputValue>
void
typename LearningApplicationBase<TInputValue,TOutputValue>
::TargetListSampleType::Pointer
LearningApplicationBase<TInputValue,TOutputValue>
::Classify(typename ListSampleType::Pointer validationListSample,
typename TargetListSampleType::Pointer predictedList,
std::string modelPath)
{
// Setup fake reporter
......@@ -110,11 +110,13 @@ LearningApplicationBase<TInputValue,TOutputValue>
model->Load(modelPath);
model->SetRegressionMode(this->m_RegressionFlag);
predictedList = model->PredictBatch(validationListSample, NULL);
typename TargetListSampleType::Pointer predictedList = model->PredictBatch(validationListSample, NULL);
// update reporter
dummyFilter->UpdateProgress(1.0f);
dummyFilter->InvokeEvent(itk::EndEvent());
return predictedList;
}
template <class TInputValue, class TOutputValue>
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment