Commit ea7f1bcf authored by Cyrille Valladeau's avatar Cyrille Valladeau

ENH: correct ImageSVMClassifier appli

parent d3175280
......@@ -4,4 +4,4 @@ OTB_CREATE_APPLICATION(NAME EstimateImagesStatistics
OTB_CREATE_APPLICATION(NAME ImageSVMClassifier
SOURCES otbImageSVMClassifier.cxx
LINK_LIBRARIES OTBIO;OTBCommon;OTBBasicFilters)
LINK_LIBRARIES OTBIO;OTBCommon;OTBBasicFilters;OTBFeatureExtraction;OTBLearning;OTBApplicationEngine)
......@@ -82,13 +82,14 @@ private:
AddParameter(ParameterType_InputImage, "mask", "Input Mask to classify");
SetParameterDescription( "mask", "A mask associated with the new image to classify");
MandatoryOff("mask");
AddParameter(ParameterType_Filename, "imstat", "Image statistics file.");
SetParameterDescription("imstat", "a XML file containing mean and standard deviation of input images used to train svm model.");
MandatoryOff("instat");
MandatoryOff("imstat");
AddParameter(ParameterType_Filename, "svmmodel", "SVM Model.");
SetParameterDescription("svmmodel", "An estimated svm model previously computed");
AddParameter(ParameterType_Filename, "svm", "SVM Model.");
SetParameterDescription("svm", "An estimated svm model previously computed");
AddParameter(ParameterType_OutputImage, "out", "Output Image");
SetParameterDescription( "out", "Output labeled image");
......@@ -109,21 +110,20 @@ private:
inImage->UpdateOutputInformation();
// Load svm model
ModelPointerType modelSVM = ModelType::New();
modelSVM->LoadModel(GetParameterString("svmmodel").c_str());
m_ModelSVM = ModelType::New();
m_ModelSVM->LoadModel(GetParameterString("svm").c_str());
// Normalize input image (optional)
StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
MeasurementType meanMeasurementVector;
MeasurementType stddevMeasurementVector;
RescalerType::Pointer rescaler = RescalerType::New();
m_Rescaler = RescalerType::New();
// Classify
ClassificationFilterType::Pointer classificationFilter = ClassificationFilterType::New();
classificationFilter->SetModel(modelSVM);
m_ClassificationFilter = ClassificationFilterType::New();
m_ClassificationFilter->SetModel(m_ModelSVM);
// Normalize input image if asked
if( HasValue("imstat") )
{
......@@ -135,18 +135,19 @@ private:
otbAppLogDEBUG( "mean used: " << meanMeasurementVector );
otbAppLogDEBUG( "standard deviation used: " << stddevMeasurementVector );
// Rescale vector image
rescaler->SetScale(stddevMeasurementVector);
rescaler->SetShift(meanMeasurementVector);
rescaler->SetInput(inImage);
m_Rescaler->SetScale(stddevMeasurementVector);
m_Rescaler->SetShift(meanMeasurementVector);
m_Rescaler->SetInput(inImage);
classificationFilter->SetInput(rescaler->GetOutput());
m_ClassificationFilter->SetInput(m_Rescaler->GetOutput());
}
else
{
otbAppLogDEBUG("Input image normalization deactivated.");
classificationFilter->SetInput(inImage);
m_ClassificationFilter->SetInput(inImage);
}
if( HasValue("mask") )
{
otbAppLogDEBUG("Use input mask.");
......@@ -157,17 +158,24 @@ private:
extract->SetChannel(0);
extract->UpdateOutputInformation();
classificationFilter->SetInputMask(extract->GetOutput());
m_ClassificationFilter->SetInputMask(extract->GetOutput());
}
std::cout<<"-------------3-----------------"<<std::endl;
m_FinalCast = CastImageFilterType::New();
m_FinalCast->SetInput( m_ClassificationFilter->GetOutput() );
CastImageFilterType::Pointer finalCast = CastImageFilterType::New();
finalCast->SetInput( classificationFilter->GetOutput() );
SetParameterOutputImage("out", m_FinalCast->GetOutput());
SetParameterOutputImage("out", finalCast->GetOutput());
//SetParameterOuutputImage<UInt8ImageType>("out", m_ClassificationFilter->GetOutput());
std::cout<<"---------------4---------------"<<std::endl;
}
//itk::LightObject::Pointer m_FilterRef;
ClassificationFilterType::Pointer m_ClassificationFilter;
ModelPointerType m_ModelSVM;
RescalerType::Pointer m_Rescaler;
CastImageFilterType::Pointer m_FinalCast;
};
......
......@@ -435,6 +435,7 @@ void Application::SetParameterStringList(std::string parameter, std::vector<std:
void Application::SetParameterOutputImage(std::string parameter, FloatVectorImageType* value)
{
std::cout<<"Application::SetParameterOutputImage 1"<<std::endl;
Parameter* param = GetParameterByKey(parameter);
if (dynamic_cast<OutputImageParameter*>(param))
......@@ -444,6 +445,21 @@ void Application::SetParameterOutputImage(std::string parameter, FloatVectorImag
}
}
template <class TImageType>
void Application::SetParameterOuutputImage(std::string parameter, TImageType* value)
{ std::cout<<"Application::SetParameterOutputImage"<<std::endl;
Parameter* param = GetParameterByKey(parameter);
if (dynamic_cast<OutputImageParameter*>(param))
{
std::cout<<"Application::SetParameterOutputImage 0000"<<std::endl;
OutputImageParameter* paramDown = dynamic_cast<OutputImageParameter*>(param);
paramDown->SetValue(value);
std::cout<<"Application::SetParameterOutputImage plop"<<std::endl;
}
}
void Application::SetParameterOutputImagePixelType(std::string parameter, ImagePixelType pixelType)
{
Parameter* param = GetParameterByKey(parameter);
......
......@@ -214,6 +214,15 @@ public:
*/
void SetParameterOutputImage(std::string parameter, FloatVectorImageType* value);
/* Set an output image value
*
* Can be called for types :
* \li ParameterType_OutputImage
*/
template <class TImageType>
void SetParameterOuutputImage(std::string parameter, TImageType* value);
/* Set the pixel type in which the image will be saved
*
* Can be called for types :
......
......@@ -75,6 +75,7 @@ template <class TInputImageType>
void
OutputImageParameter::SwitchImageWrite()
{
std::cout<<"OutputImageParameter::SwitchImageWrite start"<<std::endl;
switch(m_PixelType )
{
case ImagePixelType_int8:
......@@ -84,6 +85,7 @@ OutputImageParameter::SwitchImageWrite()
}
case ImagePixelType_uint8:
{
std::cout<<"OutputImageParameter::SwitchImageWrite UNIN8"<<std::endl;
otbRescaleAndWriteMacro(TInputImageType, UInt8ImageType, m_UInt8Writer);
break;
}
......@@ -118,6 +120,7 @@ OutputImageParameter::SwitchImageWrite()
break;
}
}
std::cout<<"OutputImageParameter::SwitchImageWrite end"<<std::endl;
}
......@@ -176,14 +179,16 @@ OutputImageParameter::SwitchVectorImageWrite()
void
OutputImageParameter::Write()
{
std::cout<<"OutputImageParameter::Write"<<std::endl;
m_Image->UpdateOutputInformation();
std::cout<<"OutputImageParameter::Write1"<<std::endl;
if (dynamic_cast<Int8ImageType*>(m_Image.GetPointer()))
{
SwitchImageWrite<Int8ImageType>();
}
else if (dynamic_cast<UInt8ImageType*>(m_Image.GetPointer()))
{
std::cout<<"OutputImageParameter::Write UNIN8"<<std::endl;
SwitchImageWrite<UInt8ImageType>();
}
else if (dynamic_cast<Int16ImageType*>(m_Image.GetPointer()))
......@@ -246,6 +251,8 @@ OutputImageParameter::Write()
{
itkExceptionMacro("Unknown image type");
}
std::cout<<"OutputImageParameter::Write2"<<std::endl;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment