Skip to content
Snippets Groups Projects
Commit ea7f1bcf authored by Cyrille Valladeau's avatar Cyrille Valladeau
Browse files

ENH: correct ImageSVMClassifier appli

parent d3175280
No related branches found
No related tags found
No related merge requests found
...@@ -4,4 +4,4 @@ OTB_CREATE_APPLICATION(NAME EstimateImagesStatistics ...@@ -4,4 +4,4 @@ OTB_CREATE_APPLICATION(NAME EstimateImagesStatistics
OTB_CREATE_APPLICATION(NAME ImageSVMClassifier OTB_CREATE_APPLICATION(NAME ImageSVMClassifier
SOURCES otbImageSVMClassifier.cxx SOURCES otbImageSVMClassifier.cxx
LINK_LIBRARIES OTBIO;OTBCommon;OTBBasicFilters) LINK_LIBRARIES OTBIO;OTBCommon;OTBBasicFilters;OTBFeatureExtraction;OTBLearning;OTBApplicationEngine)
...@@ -82,13 +82,14 @@ private: ...@@ -82,13 +82,14 @@ private:
AddParameter(ParameterType_InputImage, "mask", "Input Mask to classify"); AddParameter(ParameterType_InputImage, "mask", "Input Mask to classify");
SetParameterDescription( "mask", "A mask associated with the new image to classify"); SetParameterDescription( "mask", "A mask associated with the new image to classify");
MandatoryOff("mask");
AddParameter(ParameterType_Filename, "imstat", "Image statistics file."); AddParameter(ParameterType_Filename, "imstat", "Image statistics file.");
SetParameterDescription("imstat", "a XML file containing mean and standard deviation of input images used to train svm model."); SetParameterDescription("imstat", "a XML file containing mean and standard deviation of input images used to train svm model.");
MandatoryOff("instat"); MandatoryOff("imstat");
AddParameter(ParameterType_Filename, "svmmodel", "SVM Model."); AddParameter(ParameterType_Filename, "svm", "SVM Model.");
SetParameterDescription("svmmodel", "An estimated svm model previously computed"); SetParameterDescription("svm", "An estimated svm model previously computed");
AddParameter(ParameterType_OutputImage, "out", "Output Image"); AddParameter(ParameterType_OutputImage, "out", "Output Image");
SetParameterDescription( "out", "Output labeled image"); SetParameterDescription( "out", "Output labeled image");
...@@ -109,21 +110,20 @@ private: ...@@ -109,21 +110,20 @@ private:
inImage->UpdateOutputInformation(); inImage->UpdateOutputInformation();
// Load svm model // Load svm model
ModelPointerType modelSVM = ModelType::New(); m_ModelSVM = ModelType::New();
modelSVM->LoadModel(GetParameterString("svmmodel").c_str()); m_ModelSVM->LoadModel(GetParameterString("svm").c_str());
// Normalize input image (optional) // Normalize input image (optional)
StatisticsReader::Pointer statisticsReader = StatisticsReader::New(); StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
MeasurementType meanMeasurementVector; MeasurementType meanMeasurementVector;
MeasurementType stddevMeasurementVector; MeasurementType stddevMeasurementVector;
RescalerType::Pointer rescaler = RescalerType::New(); m_Rescaler = RescalerType::New();
// Classify // Classify
ClassificationFilterType::Pointer classificationFilter = ClassificationFilterType::New(); m_ClassificationFilter = ClassificationFilterType::New();
classificationFilter->SetModel(modelSVM); m_ClassificationFilter->SetModel(m_ModelSVM);
// Normalize input image if asked // Normalize input image if asked
if( HasValue("imstat") ) if( HasValue("imstat") )
{ {
...@@ -135,18 +135,19 @@ private: ...@@ -135,18 +135,19 @@ private:
otbAppLogDEBUG( "mean used: " << meanMeasurementVector ); otbAppLogDEBUG( "mean used: " << meanMeasurementVector );
otbAppLogDEBUG( "standard deviation used: " << stddevMeasurementVector ); otbAppLogDEBUG( "standard deviation used: " << stddevMeasurementVector );
// Rescale vector image // Rescale vector image
rescaler->SetScale(stddevMeasurementVector); m_Rescaler->SetScale(stddevMeasurementVector);
rescaler->SetShift(meanMeasurementVector); m_Rescaler->SetShift(meanMeasurementVector);
rescaler->SetInput(inImage); m_Rescaler->SetInput(inImage);
classificationFilter->SetInput(rescaler->GetOutput()); m_ClassificationFilter->SetInput(m_Rescaler->GetOutput());
} }
else else
{ {
otbAppLogDEBUG("Input image normalization deactivated."); otbAppLogDEBUG("Input image normalization deactivated.");
classificationFilter->SetInput(inImage); m_ClassificationFilter->SetInput(inImage);
} }
if( HasValue("mask") ) if( HasValue("mask") )
{ {
otbAppLogDEBUG("Use input mask."); otbAppLogDEBUG("Use input mask.");
...@@ -157,17 +158,24 @@ private: ...@@ -157,17 +158,24 @@ private:
extract->SetChannel(0); extract->SetChannel(0);
extract->UpdateOutputInformation(); extract->UpdateOutputInformation();
classificationFilter->SetInputMask(extract->GetOutput()); m_ClassificationFilter->SetInputMask(extract->GetOutput());
} }
std::cout<<"-------------3-----------------"<<std::endl;
m_FinalCast = CastImageFilterType::New();
m_FinalCast->SetInput( m_ClassificationFilter->GetOutput() );
CastImageFilterType::Pointer finalCast = CastImageFilterType::New(); SetParameterOutputImage("out", m_FinalCast->GetOutput());
finalCast->SetInput( classificationFilter->GetOutput() );
SetParameterOutputImage("out", finalCast->GetOutput()); //SetParameterOuutputImage<UInt8ImageType>("out", m_ClassificationFilter->GetOutput());
std::cout<<"---------------4---------------"<<std::endl;
} }
//itk::LightObject::Pointer m_FilterRef; ClassificationFilterType::Pointer m_ClassificationFilter;
ModelPointerType m_ModelSVM;
RescalerType::Pointer m_Rescaler;
CastImageFilterType::Pointer m_FinalCast;
}; };
......
...@@ -435,6 +435,7 @@ void Application::SetParameterStringList(std::string parameter, std::vector<std: ...@@ -435,6 +435,7 @@ void Application::SetParameterStringList(std::string parameter, std::vector<std:
void Application::SetParameterOutputImage(std::string parameter, FloatVectorImageType* value) void Application::SetParameterOutputImage(std::string parameter, FloatVectorImageType* value)
{ {
std::cout<<"Application::SetParameterOutputImage 1"<<std::endl;
Parameter* param = GetParameterByKey(parameter); Parameter* param = GetParameterByKey(parameter);
if (dynamic_cast<OutputImageParameter*>(param)) if (dynamic_cast<OutputImageParameter*>(param))
...@@ -444,6 +445,21 @@ void Application::SetParameterOutputImage(std::string parameter, FloatVectorImag ...@@ -444,6 +445,21 @@ void Application::SetParameterOutputImage(std::string parameter, FloatVectorImag
} }
} }
template <class TImageType>
void Application::SetParameterOuutputImage(std::string parameter, TImageType* value)
{ std::cout<<"Application::SetParameterOutputImage"<<std::endl;
Parameter* param = GetParameterByKey(parameter);
if (dynamic_cast<OutputImageParameter*>(param))
{
std::cout<<"Application::SetParameterOutputImage 0000"<<std::endl;
OutputImageParameter* paramDown = dynamic_cast<OutputImageParameter*>(param);
paramDown->SetValue(value);
std::cout<<"Application::SetParameterOutputImage plop"<<std::endl;
}
}
void Application::SetParameterOutputImagePixelType(std::string parameter, ImagePixelType pixelType) void Application::SetParameterOutputImagePixelType(std::string parameter, ImagePixelType pixelType)
{ {
Parameter* param = GetParameterByKey(parameter); Parameter* param = GetParameterByKey(parameter);
......
...@@ -214,6 +214,15 @@ public: ...@@ -214,6 +214,15 @@ public:
*/ */
void SetParameterOutputImage(std::string parameter, FloatVectorImageType* value); void SetParameterOutputImage(std::string parameter, FloatVectorImageType* value);
/* Set an output image value
*
* Can be called for types :
* \li ParameterType_OutputImage
*/
template <class TImageType>
void SetParameterOuutputImage(std::string parameter, TImageType* value);
/* Set the pixel type in which the image will be saved /* Set the pixel type in which the image will be saved
* *
* Can be called for types : * Can be called for types :
......
...@@ -75,6 +75,7 @@ template <class TInputImageType> ...@@ -75,6 +75,7 @@ template <class TInputImageType>
void void
OutputImageParameter::SwitchImageWrite() OutputImageParameter::SwitchImageWrite()
{ {
std::cout<<"OutputImageParameter::SwitchImageWrite start"<<std::endl;
switch(m_PixelType ) switch(m_PixelType )
{ {
case ImagePixelType_int8: case ImagePixelType_int8:
...@@ -84,6 +85,7 @@ OutputImageParameter::SwitchImageWrite() ...@@ -84,6 +85,7 @@ OutputImageParameter::SwitchImageWrite()
} }
case ImagePixelType_uint8: case ImagePixelType_uint8:
{ {
std::cout<<"OutputImageParameter::SwitchImageWrite UNIN8"<<std::endl;
otbRescaleAndWriteMacro(TInputImageType, UInt8ImageType, m_UInt8Writer); otbRescaleAndWriteMacro(TInputImageType, UInt8ImageType, m_UInt8Writer);
break; break;
} }
...@@ -118,6 +120,7 @@ OutputImageParameter::SwitchImageWrite() ...@@ -118,6 +120,7 @@ OutputImageParameter::SwitchImageWrite()
break; break;
} }
} }
std::cout<<"OutputImageParameter::SwitchImageWrite end"<<std::endl;
} }
...@@ -176,14 +179,16 @@ OutputImageParameter::SwitchVectorImageWrite() ...@@ -176,14 +179,16 @@ OutputImageParameter::SwitchVectorImageWrite()
void void
OutputImageParameter::Write() OutputImageParameter::Write()
{ {
std::cout<<"OutputImageParameter::Write"<<std::endl;
m_Image->UpdateOutputInformation(); m_Image->UpdateOutputInformation();
std::cout<<"OutputImageParameter::Write1"<<std::endl;
if (dynamic_cast<Int8ImageType*>(m_Image.GetPointer())) if (dynamic_cast<Int8ImageType*>(m_Image.GetPointer()))
{ {
SwitchImageWrite<Int8ImageType>(); SwitchImageWrite<Int8ImageType>();
} }
else if (dynamic_cast<UInt8ImageType*>(m_Image.GetPointer())) else if (dynamic_cast<UInt8ImageType*>(m_Image.GetPointer()))
{ {
std::cout<<"OutputImageParameter::Write UNIN8"<<std::endl;
SwitchImageWrite<UInt8ImageType>(); SwitchImageWrite<UInt8ImageType>();
} }
else if (dynamic_cast<Int16ImageType*>(m_Image.GetPointer())) else if (dynamic_cast<Int16ImageType*>(m_Image.GetPointer()))
...@@ -246,6 +251,8 @@ OutputImageParameter::Write() ...@@ -246,6 +251,8 @@ OutputImageParameter::Write()
{ {
itkExceptionMacro("Unknown image type"); itkExceptionMacro("Unknown image type");
} }
std::cout<<"OutputImageParameter::Write2"<<std::endl;
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment