From ea7f1bcf2402a33958a3ab954cc62f1df334733a Mon Sep 17 00:00:00 2001 From: Cyrille Valladeau <cyrille.valladeau@c-s.fr> Date: Thu, 6 Oct 2011 14:51:16 +0200 Subject: [PATCH] ENH: correct ImageSVMClassifier appli --- Applications/Classification/CMakeLists.txt | 2 +- .../Classification/otbImageSVMClassifier.cxx | 48 +++++++++++-------- .../otbWrapperApplication.cxx | 16 +++++++ .../ApplicationEngine/otbWrapperApplication.h | 9 ++++ .../otbWrapperOutputImageParameter.cxx | 9 +++- 5 files changed, 62 insertions(+), 22 deletions(-) diff --git a/Applications/Classification/CMakeLists.txt b/Applications/Classification/CMakeLists.txt index b03faab27a..e5203b9ab8 100644 --- a/Applications/Classification/CMakeLists.txt +++ b/Applications/Classification/CMakeLists.txt @@ -4,4 +4,4 @@ OTB_CREATE_APPLICATION(NAME EstimateImagesStatistics OTB_CREATE_APPLICATION(NAME ImageSVMClassifier SOURCES otbImageSVMClassifier.cxx - LINK_LIBRARIES OTBIO;OTBCommon;OTBBasicFilters) + LINK_LIBRARIES OTBIO;OTBCommon;OTBBasicFilters;OTBFeatureExtraction;OTBLearning;OTBApplicationEngine) diff --git a/Applications/Classification/otbImageSVMClassifier.cxx b/Applications/Classification/otbImageSVMClassifier.cxx index a4a7da2775..d1a2ca34c7 100644 --- a/Applications/Classification/otbImageSVMClassifier.cxx +++ b/Applications/Classification/otbImageSVMClassifier.cxx @@ -82,13 +82,14 @@ private: AddParameter(ParameterType_InputImage, "mask", "Input Mask to classify"); SetParameterDescription( "mask", "A mask associated with the new image to classify"); + MandatoryOff("mask"); AddParameter(ParameterType_Filename, "imstat", "Image statistics file."); SetParameterDescription("imstat", "a XML file containing mean and standard deviation of input images used to train svm model."); - MandatoryOff("instat"); + MandatoryOff("imstat"); - AddParameter(ParameterType_Filename, "svmmodel", "SVM Model."); - SetParameterDescription("svmmodel", "An estimated svm model previously computed"); + AddParameter(ParameterType_Filename, "svm", "SVM Model."); + SetParameterDescription("svm", "An estimated svm model previously computed"); AddParameter(ParameterType_OutputImage, "out", "Output Image"); SetParameterDescription( "out", "Output labeled image"); @@ -109,21 +110,20 @@ private: inImage->UpdateOutputInformation(); // Load svm model - ModelPointerType modelSVM = ModelType::New(); - modelSVM->LoadModel(GetParameterString("svmmodel").c_str()); + m_ModelSVM = ModelType::New(); + m_ModelSVM->LoadModel(GetParameterString("svm").c_str()); // Normalize input image (optional) StatisticsReader::Pointer statisticsReader = StatisticsReader::New(); MeasurementType meanMeasurementVector; MeasurementType stddevMeasurementVector; - RescalerType::Pointer rescaler = RescalerType::New(); + m_Rescaler = RescalerType::New(); // Classify - ClassificationFilterType::Pointer classificationFilter = ClassificationFilterType::New(); - classificationFilter->SetModel(modelSVM); - - + m_ClassificationFilter = ClassificationFilterType::New(); + m_ClassificationFilter->SetModel(m_ModelSVM); + // Normalize input image if asked if( HasValue("imstat") ) { @@ -135,18 +135,19 @@ private: otbAppLogDEBUG( "mean used: " << meanMeasurementVector ); otbAppLogDEBUG( "standard deviation used: " << stddevMeasurementVector ); // Rescale vector image - rescaler->SetScale(stddevMeasurementVector); - rescaler->SetShift(meanMeasurementVector); - rescaler->SetInput(inImage); + m_Rescaler->SetScale(stddevMeasurementVector); + m_Rescaler->SetShift(meanMeasurementVector); + m_Rescaler->SetInput(inImage); - classificationFilter->SetInput(rescaler->GetOutput()); + m_ClassificationFilter->SetInput(m_Rescaler->GetOutput()); } else { otbAppLogDEBUG("Input image normalization deactivated."); - classificationFilter->SetInput(inImage); + m_ClassificationFilter->SetInput(inImage); } + if( HasValue("mask") ) { otbAppLogDEBUG("Use input mask."); @@ -157,17 +158,24 @@ private: extract->SetChannel(0); extract->UpdateOutputInformation(); - classificationFilter->SetInputMask(extract->GetOutput()); + m_ClassificationFilter->SetInputMask(extract->GetOutput()); } + + std::cout<<"-------------3-----------------"<<std::endl; + m_FinalCast = CastImageFilterType::New(); + m_FinalCast->SetInput( m_ClassificationFilter->GetOutput() ); - CastImageFilterType::Pointer finalCast = CastImageFilterType::New(); - finalCast->SetInput( classificationFilter->GetOutput() ); + SetParameterOutputImage("out", m_FinalCast->GetOutput()); - SetParameterOutputImage("out", finalCast->GetOutput()); + //SetParameterOuutputImage<UInt8ImageType>("out", m_ClassificationFilter->GetOutput()); + std::cout<<"---------------4---------------"<<std::endl; } - //itk::LightObject::Pointer m_FilterRef; + ClassificationFilterType::Pointer m_ClassificationFilter; + ModelPointerType m_ModelSVM; + RescalerType::Pointer m_Rescaler; + CastImageFilterType::Pointer m_FinalCast; }; diff --git a/Code/ApplicationEngine/otbWrapperApplication.cxx b/Code/ApplicationEngine/otbWrapperApplication.cxx index b2c52b1e5f..bd0a2c15a6 100644 --- a/Code/ApplicationEngine/otbWrapperApplication.cxx +++ b/Code/ApplicationEngine/otbWrapperApplication.cxx @@ -435,6 +435,7 @@ void Application::SetParameterStringList(std::string parameter, std::vector<std: void Application::SetParameterOutputImage(std::string parameter, FloatVectorImageType* value) { + std::cout<<"Application::SetParameterOutputImage 1"<<std::endl; Parameter* param = GetParameterByKey(parameter); if (dynamic_cast<OutputImageParameter*>(param)) @@ -444,6 +445,21 @@ void Application::SetParameterOutputImage(std::string parameter, FloatVectorImag } } +template <class TImageType> +void Application::SetParameterOuutputImage(std::string parameter, TImageType* value) +{ std::cout<<"Application::SetParameterOutputImage"<<std::endl; + Parameter* param = GetParameterByKey(parameter); + + if (dynamic_cast<OutputImageParameter*>(param)) + { + std::cout<<"Application::SetParameterOutputImage 0000"<<std::endl; + OutputImageParameter* paramDown = dynamic_cast<OutputImageParameter*>(param); + paramDown->SetValue(value); + std::cout<<"Application::SetParameterOutputImage plop"<<std::endl; + } +} + + void Application::SetParameterOutputImagePixelType(std::string parameter, ImagePixelType pixelType) { Parameter* param = GetParameterByKey(parameter); diff --git a/Code/ApplicationEngine/otbWrapperApplication.h b/Code/ApplicationEngine/otbWrapperApplication.h index 045f8aca71..11a2ce79f8 100644 --- a/Code/ApplicationEngine/otbWrapperApplication.h +++ b/Code/ApplicationEngine/otbWrapperApplication.h @@ -214,6 +214,15 @@ public: */ void SetParameterOutputImage(std::string parameter, FloatVectorImageType* value); + /* Set an output image value + * + * Can be called for types : + * \li ParameterType_OutputImage + */ + template <class TImageType> + void SetParameterOuutputImage(std::string parameter, TImageType* value); + + /* Set the pixel type in which the image will be saved * * Can be called for types : diff --git a/Code/ApplicationEngine/otbWrapperOutputImageParameter.cxx b/Code/ApplicationEngine/otbWrapperOutputImageParameter.cxx index 0e24687fee..51187fc4cf 100644 --- a/Code/ApplicationEngine/otbWrapperOutputImageParameter.cxx +++ b/Code/ApplicationEngine/otbWrapperOutputImageParameter.cxx @@ -75,6 +75,7 @@ template <class TInputImageType> void OutputImageParameter::SwitchImageWrite() { + std::cout<<"OutputImageParameter::SwitchImageWrite start"<<std::endl; switch(m_PixelType ) { case ImagePixelType_int8: @@ -84,6 +85,7 @@ OutputImageParameter::SwitchImageWrite() } case ImagePixelType_uint8: { + std::cout<<"OutputImageParameter::SwitchImageWrite UNIN8"<<std::endl; otbRescaleAndWriteMacro(TInputImageType, UInt8ImageType, m_UInt8Writer); break; } @@ -118,6 +120,7 @@ OutputImageParameter::SwitchImageWrite() break; } } + std::cout<<"OutputImageParameter::SwitchImageWrite end"<<std::endl; } @@ -176,14 +179,16 @@ OutputImageParameter::SwitchVectorImageWrite() void OutputImageParameter::Write() { + std::cout<<"OutputImageParameter::Write"<<std::endl; m_Image->UpdateOutputInformation(); - + std::cout<<"OutputImageParameter::Write1"<<std::endl; if (dynamic_cast<Int8ImageType*>(m_Image.GetPointer())) { SwitchImageWrite<Int8ImageType>(); } else if (dynamic_cast<UInt8ImageType*>(m_Image.GetPointer())) { + std::cout<<"OutputImageParameter::Write UNIN8"<<std::endl; SwitchImageWrite<UInt8ImageType>(); } else if (dynamic_cast<Int16ImageType*>(m_Image.GetPointer())) @@ -246,6 +251,8 @@ OutputImageParameter::Write() { itkExceptionMacro("Unknown image type"); } + + std::cout<<"OutputImageParameter::Write2"<<std::endl; } -- GitLab