Commit f018e336 authored by Guillaume Pasero's avatar Guillaume Pasero

BUG: Mantis-1167: add progress report to TrainImagesClassifier

parent 7a27bd19
......@@ -288,6 +288,9 @@ void DoExecute()
SampleType meanMeasurementVector;
SampleType stddevMeasurementVector;
// Setup the DEM Handler
otb::Wrapper::ElevationParametersHandler::SetupDEMHandlerFromElevationParameters(this, "elev");
//--------------------------
// Load measurements from images
unsigned int nbBands = 0;
......@@ -301,6 +304,10 @@ void DoExecute()
//Iterate over all input images
for (unsigned int imgIndex = 0; imgIndex < imageList->Size(); ++imgIndex)
{
std::ostringstream oss1, oss2;
oss1 << "Reproject polygons for image " << (imgIndex+1) << " ...";
oss2 << "Extract samples from image " << (imgIndex+1) << " ...";
FloatVectorImageType::Pointer image = imageList->GetNthElement(imgIndex);
image->UpdateOutputInformation();
......@@ -310,16 +317,11 @@ void DoExecute()
}
// read the Vectordata
VectorDataType::Pointer vectorData = vectorDataList->GetNthElement(imgIndex);
vectorData->Update();
vdreproj->SetInputImage(image);
vdreproj->SetInput(vectorData);
vdreproj->SetInput(vectorDataList->GetNthElement(imgIndex));
vdreproj->SetUseOutputSpacingAndOriginFromImage(false);
// Setup the DEM Handler
otb::Wrapper::ElevationParametersHandler::SetupDEMHandlerFromElevationParameters(this, "elev");
AddProcess(vdreproj, oss1.str());
vdreproj->Update();
//Sample list generator
......@@ -340,15 +342,27 @@ void DoExecute()
sampleGenerator->SetPolygonEdgeInclusion(true);
}
AddProcess(sampleGenerator, oss2.str());
sampleGenerator->Update();
TargetListSampleType::Pointer trainLabels = sampleGenerator->GetTrainingListLabel();
ListSampleType::Pointer trainSamples = sampleGenerator->GetTrainingListSample();
TargetListSampleType::Pointer validLabels = sampleGenerator->GetValidationListLabel();
ListSampleType::Pointer validSamples = sampleGenerator->GetValidationListSample();
trainLabels->DisconnectPipeline();
trainSamples->DisconnectPipeline();
validLabels->DisconnectPipeline();
validSamples->DisconnectPipeline();
//Concatenate training and validation samples from the image
concatenateTrainingLabels->AddInput(sampleGenerator->GetTrainingListLabel());
concatenateTrainingSamples->AddInput(sampleGenerator->GetTrainingListSample());
concatenateValidationLabels->AddInput(sampleGenerator->GetValidationListLabel());
concatenateValidationSamples->AddInput(sampleGenerator->GetValidationListSample());
concatenateTrainingLabels->AddInput(trainLabels);
concatenateTrainingSamples->AddInput(trainSamples);
concatenateValidationLabels->AddInput(validLabels);
concatenateValidationSamples->AddInput(validSamples);
}
// Update
AddProcess(concatenateValidationLabels, "Concatenate samples ...");
concatenateTrainingSamples->Update();
concatenateTrainingLabels->Update();
concatenateValidationSamples->Update();
......@@ -384,6 +398,7 @@ void DoExecute()
trainingShiftScaleFilter->SetInput(concatenateTrainingSamples->GetOutput());
trainingShiftScaleFilter->SetShifts(meanMeasurementVector);
trainingShiftScaleFilter->SetScales(stddevMeasurementVector);
AddProcess(trainingShiftScaleFilter, "Normalize training samples ...");
trainingShiftScaleFilter->Update();
ListSampleType::Pointer validationListSample=ListSampleType::New();
......@@ -395,6 +410,7 @@ void DoExecute()
validationShiftScaleFilter->SetInput(concatenateValidationSamples->GetOutput());
validationShiftScaleFilter->SetShifts(meanMeasurementVector);
validationShiftScaleFilter->SetScales(stddevMeasurementVector);
AddProcess(validationShiftScaleFilter, "Normalize validation samples ...");
validationShiftScaleFilter->Update();
validationListSample = validationShiftScaleFilter->GetOutput();
}
......@@ -467,7 +483,6 @@ void DoExecute()
otbAppLogINFO("ValidationLabeledListSample size : " << performanceLabeledListSample->Size());
confMatCalc->SetReferenceLabels(performanceLabeledListSample);
confMatCalc->SetProducedLabels(predictedList);
confMatCalc->Compute();
otbAppLogINFO("training performances");
......
......@@ -18,6 +18,8 @@
#define __otbLearningApplicationBase_txx
#include "otbLearningApplicationBase.h"
// only need this filter as a dummy process object
#include "otbRGBAPixelConverter.h"
namespace otb
{
......@@ -73,6 +75,13 @@ LearningApplicationBase<TInputValue,TOutputValue>
typename TargetListSampleType::Pointer predictedList,
std::string modelPath)
{
// Setup fake reporter
RGBAPixelConverter<int,int>::Pointer dummyFilter =
RGBAPixelConverter<int,int>::New();
dummyFilter->SetProgress(0.0f);
this->AddProcess(dummyFilter,"Classify...");
dummyFilter->InvokeEvent(itk::StartEvent());
// load a machine learning model from file and predict the input sample list
ModelPointerType model = ModelFactoryType::CreateMachineLearningModel(modelPath,
ModelFactoryType::ReadMode);
......@@ -87,6 +96,10 @@ LearningApplicationBase<TInputValue,TOutputValue>
model->SetInputListSample(validationListSample);
model->SetTargetListSample(predictedList);
model->PredictAll();
// update reporter
dummyFilter->UpdateProgress(1.0f);
dummyFilter->InvokeEvent(itk::EndEvent());
}
template <class TInputValue, class TOutputValue>
......@@ -96,6 +109,13 @@ LearningApplicationBase<TInputValue,TOutputValue>
typename TargetListSampleType::Pointer trainingLabeledListSample,
std::string modelPath)
{
// Setup fake reporter
RGBAPixelConverter<int,int>::Pointer dummyFilter =
RGBAPixelConverter<int,int>::New();
dummyFilter->SetProgress(0.0f);
this->AddProcess(dummyFilter,"Training model...");
dummyFilter->InvokeEvent(itk::StartEvent());
// get the name of the chosen machine learning model
const std::string modelName = GetParameterString("classifier");
// call specific train function
......@@ -173,6 +193,9 @@ LearningApplicationBase<TInputValue,TOutputValue>
otbAppLogFATAL("Module OPENCV is not installed. You should consider turning OTB_USE_OPENCV on during cmake configuration.");
#endif
}
// update reporter
dummyFilter->UpdateProgress(1.0f);
dummyFilter->InvokeEvent(itk::EndEvent());
}
}
......
......@@ -235,6 +235,7 @@ VectorDataIntoImageProjectionFilter<TInputVectorData, TInputImage>
}
this->GraftOutput(m_VdProjFilter->GetOutput());
this->UpdateProgress(1.0f);
}
......
......@@ -326,6 +326,7 @@ ListSampleGenerator<TImage, TVectorData>
assert(trainingListSample->Size() == trainingListLabel->Size());
assert(validationListSample->Size() == validationListLabel->Size());
this->UpdateProgress(1.0f);
}
template <class TImage, class TVectorData>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment