diff --git a/Applications/Classification/otbTrainMachineLearningImagesClassifier.cxx b/Applications/Classification/otbTrainMachineLearningImagesClassifier.cxx index c03515a87da75a26f2b8c214a7e7be59be25b150..b8d92219545de0c8d5be490d69e8bd39fdbd3784 100644 --- a/Applications/Classification/otbTrainMachineLearningImagesClassifier.cxx +++ b/Applications/Classification/otbTrainMachineLearningImagesClassifier.cxx @@ -213,8 +213,8 @@ private: AddChoice("classifier.svm.m.csvc", "C support vector classification"); AddChoice("classifier.svm.m.nusvc", "Nu support vector classification"); AddChoice("classifier.svm.m.oneclass", "Distribution estimation (One Class SVM)"); - AddChoice("classifier.svm.m.epssvr", "Epsilon Support Vector Regression"); - AddChoice("classifier.svm.m.nusvr", "Nu Support Vector Regression"); + //AddChoice("classifier.svm.m.epssvr", "Epsilon Support Vector Regression"); + //AddChoice("classifier.svm.m.nusvr", "Nu Support Vector Regression"); SetParameterString("classifier.svm.m", "csvc"); SetParameterDescription("classifier.svm.m", "Type of SVM formulation."); AddParameter(ParameterType_Choice, "classifier.svm.k", "SVM Kernel Type"); @@ -230,9 +230,9 @@ private: AddParameter(ParameterType_Float, "classifier.svm.nu", "Parameter nu of a SVM optimization problem (NU_SVC / ONE_CLASS / NU_SVR)."); SetParameterFloat("classifier.svm.nu", 0.0); SetParameterDescription("classifier.svm.nu", "Parameter nu of a SVM optimization problem."); - AddParameter(ParameterType_Float, "classifier.svm.p", "Parameter epsilon of a SVM optimization problem (EPS_SVR)."); - SetParameterFloat("classifier.svm.p", 0.0); - SetParameterDescription("classifier.svm.p", "Parameter epsilon of a SVM optimization problem (EPS_SVR)."); + //AddParameter(ParameterType_Float, "classifier.svm.p", "Parameter epsilon of a SVM optimization problem (EPS_SVR)."); + //SetParameterFloat("classifier.svm.p", 0.0); + //SetParameterDescription("classifier.svm.p", "Parameter epsilon of a SVM optimization problem (EPS_SVR)."); AddParameter(ParameterType_Float, "classifier.svm.coef0", "Parameter coef0 of a kernel function (POLY / SIGMOID)."); SetParameterFloat("classifier.svm.coef0", 0.0); SetParameterDescription("classifier.svm.coef0", "Parameter coef0 of a kernel function (POLY / SIGMOID)."); @@ -423,14 +423,14 @@ private: SVMClassifier->SetSVMType(CvSVM::ONE_CLASS); std::cout<<"CvSVM::ONE_CLASS = "<<CvSVM::ONE_CLASS<<std::endl; break; - case 3: // EPS_SVR + /*case 3: // EPS_SVR SVMClassifier->SetSVMType(CvSVM::EPS_SVR); std::cout<<"CvSVM::EPS_SVR = "<<CvSVM::EPS_SVR<<std::endl; break; case 4: // NU_SVR SVMClassifier->SetSVMType(CvSVM::NU_SVR); std::cout<<"CvSVM::NU_SVR = "<<CvSVM::NU_SVR<<std::endl; - break; + break;*/ default: // DEFAULT = C_SVC SVMClassifier->SetSVMType(CvSVM::C_SVC); std::cout<<"CvSVM::C_SVC = "<<CvSVM::C_SVC<<std::endl; @@ -438,7 +438,7 @@ private: } SVMClassifier->SetC(GetParameterFloat("classifier.svm.c")); SVMClassifier->SetNu(GetParameterFloat("classifier.svm.nu")); - SVMClassifier->SetP(GetParameterFloat("classifier.svm.p")); + //SVMClassifier->SetP(GetParameterFloat("classifier.svm.p")); SVMClassifier->SetCoef0(GetParameterFloat("classifier.svm.coef0")); SVMClassifier->SetGamma(GetParameterFloat("classifier.svm.gamma")); SVMClassifier->SetDegree(GetParameterFloat("classifier.svm.degree")); diff --git a/Code/Learning/otbImageClassificationFilter.h b/Code/Learning/otbImageClassificationFilter.h new file mode 100644 index 0000000000000000000000000000000000000000..4ac0f63f16aafc89108845f62fb7f6e0ce2430b2 --- /dev/null +++ b/Code/Learning/otbImageClassificationFilter.h @@ -0,0 +1,118 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef __otbImageClassificationFilter_h +#define __otbImageClassificationFilter_h + +#include "itkImageToImageFilter.h" +#include "otbMachineLearningModel.h" + +namespace otb +{ +/** \class ImageClassificationFilter + * \brief This filter performs the classification of a VectorImage using a Model. + * + * This filter is streamed and threaded, allowing to classify huge images + * while fully using several core. + * + * \sa Classifier + * \ingroup Streamed + * \ingroup Threaded + */ +template <class TInputImage, class TOutputImage, class TMaskImage = TOutputImage> +class ITK_EXPORT ImageClassificationFilter + : public itk::ImageToImageFilter<TInputImage, TOutputImage> +{ +public: + /** Standard typedefs */ + typedef ImageClassificationFilter Self; + typedef itk::ImageToImageFilter<TInputImage, TOutputImage> Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + + /** Type macro */ + itkNewMacro(Self); + + /** Creation through object factory macro */ + itkTypeMacro(ImageClassificationFilter, ImageToImageFilter); + + typedef TInputImage InputImageType; + typedef typename InputImageType::ConstPointer InputImageConstPointerType; + typedef typename InputImageType::InternalPixelType ValueType; + + typedef TMaskImage MaskImageType; + typedef typename MaskImageType::ConstPointer MaskImageConstPointerType; + typedef typename MaskImageType::Pointer MaskImagePointerType; + + typedef TOutputImage OutputImageType; + typedef typename OutputImageType::Pointer OutputImagePointerType; + typedef typename OutputImageType::RegionType OutputImageRegionType; + typedef typename OutputImageType::PixelType LabelType; + + typedef MachineLearningModel<ValueType, LabelType> ModelType; + typedef typename ModelType::Pointer ModelPointerType; + + /** Set/Get the model */ + itkSetObjectMacro(Model, ModelType); + itkGetObjectMacro(Model, ModelType); + + /** Set/Get the default label */ + itkSetMacro(DefaultLabel, LabelType); + itkGetMacro(DefaultLabel, LabelType); + + /** + * If set, only pixels within the mask will be classified. + * All pixels with a value greater than 0 in the mask, will be classified. + * \param mask The input mask. + */ + void SetInputMask(const MaskImageType * mask); + + /** + * Get the input mask. + * \return The mask. + */ + const MaskImageType * GetInputMask(void); + +protected: + /** Constructor */ + ImageClassificationFilter(); + /** Destructor */ + virtual ~ImageClassificationFilter() {} + + /** Threaded generate data */ + virtual void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId); + /** Before threaded generate data */ + virtual void BeforeThreadedGenerateData(); + /**PrintSelf method */ + virtual void PrintSelf(std::ostream& os, itk::Indent indent) const; + +private: + ImageClassificationFilter(const Self &); //purposely not implemented + void operator =(const Self&); //purposely not implemented + + /** The model used for classification */ + ModelPointerType m_Model; + /** Default label for invalid pixels (when using a mask) */ + LabelType m_DefaultLabel; + +}; +} // End namespace otb +#ifndef OTB_MANUAL_INSTANTIATION +#include "otbImageClassificationFilter.txx" +#endif + +#endif diff --git a/Code/Learning/otbImageClassificationFilter.txx b/Code/Learning/otbImageClassificationFilter.txx new file mode 100644 index 0000000000000000000000000000000000000000..26364cf0c330c34ef7e01cbb5fabe1cabfaefb10 --- /dev/null +++ b/Code/Learning/otbImageClassificationFilter.txx @@ -0,0 +1,137 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef __otbImageClassificationFilter_txx +#define __otbImageClassificationFilter_txx + +#include "otbImageClassificationFilter.h" +#include "itkImageRegionIterator.h" +#include "itkProgressReporter.h" + +namespace otb +{ +/** + * Constructor + */ +template <class TInputImage, class TOutputImage, class TMaskImage> +ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> +::ImageClassificationFilter() +{ + this->SetNumberOfInputs(2); + this->SetNumberOfRequiredInputs(1); + m_DefaultLabel = itk::NumericTraits<LabelType>::ZeroValue(); +} + +template <class TInputImage, class TOutputImage, class TMaskImage> +void +ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> +::SetInputMask(const MaskImageType * mask) +{ + this->itk::ProcessObject::SetNthInput(1, const_cast<MaskImageType *>(mask)); +} + +template <class TInputImage, class TOutputImage, class TMaskImage> +const typename ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> +::MaskImageType * +ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> +::GetInputMask() +{ + if (this->GetNumberOfInputs() < 2) + { + return 0; + } + return static_cast<const MaskImageType *>(this->itk::ProcessObject::GetInput(1)); +} + +template <class TInputImage, class TOutputImage, class TMaskImage> +void +ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> +::BeforeThreadedGenerateData() +{ + if (!m_Model) + { + itkGenericExceptionMacro(<< "No model for classification"); + } +} + +template <class TInputImage, class TOutputImage, class TMaskImage> +void +ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> +::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId) +{ + // Get the input pointers + InputImageConstPointerType inputPtr = this->GetInput(); + MaskImageConstPointerType inputMaskPtr = this->GetInputMask(); + OutputImagePointerType outputPtr = this->GetOutput(); + + // Progress reporting + itk::ProgressReporter progress(this, threadId, outputRegionForThread.GetNumberOfPixels()); + + // Define iterators + typedef itk::ImageRegionConstIterator<InputImageType> InputIteratorType; + typedef itk::ImageRegionConstIterator<MaskImageType> MaskIteratorType; + typedef itk::ImageRegionIterator<OutputImageType> OutputIteratorType; + + InputIteratorType inIt(inputPtr, outputRegionForThread); + OutputIteratorType outIt(outputPtr, outputRegionForThread); + + // Eventually iterate on masks + MaskIteratorType maskIt; + if (inputMaskPtr) + { + maskIt = MaskIteratorType(inputMaskPtr, outputRegionForThread); + maskIt.GoToBegin(); + } + + bool validPoint = true; + + // Walk the part of the image + for (inIt.GoToBegin(), outIt.GoToBegin(); !inIt.IsAtEnd() && !outIt.IsAtEnd(); ++inIt, ++outIt) + { + // Check pixel validity + if (inputMaskPtr) + { + validPoint = maskIt.Get() > 0; + ++maskIt; + } + // If point is valid + if (validPoint) + { + // Classifify + outIt.Set(m_Model->Predict(inIt.Get())[0]); + } + else + { + // else, set default value + outIt.Set(m_DefaultLabel); + } + progress.CompletedPixel(); + } + +} +/** + * PrintSelf Method + */ +template <class TInputImage, class TOutputImage, class TMaskImage> +void +ImageClassificationFilter<TInputImage, TOutputImage, TMaskImage> +::PrintSelf(std::ostream& os, itk::Indent indent) const +{ + Superclass::PrintSelf(os, indent); +} +} // End namespace otb +#endif diff --git a/Testing/Code/Learning/CMakeLists.txt b/Testing/Code/Learning/CMakeLists.txt index 4e401ce909b98d2ad07c97618837f81879a8226d..b95d0657f8e55f8d293fe40dd05bf2cc1976775a 100644 --- a/Testing/Code/Learning/CMakeLists.txt +++ b/Testing/Code/Learning/CMakeLists.txt @@ -724,6 +724,48 @@ IF(OTB_USE_OPENCV) ${INPUTDATA}/letter.scale ${TEMP}/libsvm_model.txt ) + + ADD_TEST(leTuImageClassificationFilterNew ${LEARNING_TESTS6} + otbImageClassificationFilterNew) + + ADD_TEST(leTvImageClassificationFilterLibSVM ${LEARNING_TESTS6} + --compare-image ${NOTOL} + ${BASELINE}/leSVMImageClassificationFilterOutput.tif + ${TEMP}/leImageClassificationFilterLibSVMOutput.tif + otbImageClassificationFilter + ${INPUTDATA}/ROI_QB_MUL_4.tif + ${INPUTDATA}/svm_model_image + ${TEMP}/leImageClassificationFilterLibSVMOutput.tif + ) + + ADD_TEST(leTvImageClassificationFilterSVM ${LEARNING_TESTS6} + --compare-image ${NOTOL} + ${BASELINE}/leImageClassificationFilterSVMOutput.tif + ${TEMP}/leImageClassificationFilterSVMOutput.tif + otbImageClassificationFilter + ${INPUTDATA}/ROI_QB_MUL_4.tif + ${INPUTDATA}/ROI_QB_MUL_4_svmModel.txt + ${TEMP}/leImageClassificationFilterSVMOutput.tif + ) + + ADD_TEST(leTuLibSVMMachineLearningModelCanRead ${LEARNING_TESTS6} + otbLibSVMMachineLearningModelCanRead + ${INPUTDATA}/svm_model_image + ) + + ADD_TEST(leTuSVMMachineLearningModelCanRead ${LEARNING_TESTS6} + otbSVMMachineLearningModelCanRead + ${INPUTDATA}/svm_model_image + ) + + ADD_TEST(leTuRandomForestsMachineLearningModelCanRead ${LEARNING_TESTS6} + otbRandomForestsMachineLearningModelCanRead + ${TEMP}/RandomForestsMachineLearningModel.txt + ) + SET_TESTS_PROPERTIES(leTuRandomForestsMachineLearningModelCanRead + PROPERTIES DEPENDS leTvRandomForestsMachineLearningModel) + + ENDIF(OTB_USE_OPENCV) # Testing srcs @@ -811,6 +853,8 @@ IF(OTB_USE_OPENCV) SET(BasicLearning_SRCS6 otbLearningTests6.cxx otbTrainMachineLearningModel.cxx + otbImageClassificationFilter.cxx + otbMachineLearningModelCanRead.cxx ) ENDIF(OTB_USE_OPENCV) diff --git a/Testing/Code/Learning/otbImageClassificationFilter.cxx b/Testing/Code/Learning/otbImageClassificationFilter.cxx new file mode 100644 index 0000000000000000000000000000000000000000..0558b9f694c2b4564a9b733224b2b0ebf62e5ad6 --- /dev/null +++ b/Testing/Code/Learning/otbImageClassificationFilter.cxx @@ -0,0 +1,73 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + + =========================================================================*/ +#include "otbImageClassificationFilter.h" +#include "otbVectorImage.h" +#include "otbImage.h" +#include "otbImageFileReader.h" +#include "otbImageFileWriter.h" +#include "otbMachineLearningModelFactory.h" + +const unsigned int Dimension = 2; +typedef double PixelType; +typedef unsigned short LabeledPixelType; + +typedef otb::VectorImage<PixelType, Dimension> ImageType; +typedef otb::Image<LabeledPixelType, Dimension> LabeledImageType; +typedef otb::ImageClassificationFilter<ImageType, LabeledImageType> ClassificationFilterType; +typedef ClassificationFilterType::ModelType ModelType; +typedef ClassificationFilterType::ValueType ValueType; +typedef ClassificationFilterType::LabelType LabelType; +typedef otb::MachineLearningModelFactory<ValueType, LabelType> MachineLearningModelFactoryType; +typedef otb::ImageFileReader<ImageType> ReaderType; +typedef otb::ImageFileWriter<LabeledImageType> WriterType; + +int otbImageClassificationFilterNew(int argc, char * argv[]) +{ + ClassificationFilterType::Pointer filter = ClassificationFilterType::New(); + return EXIT_SUCCESS; +} + +int otbImageClassificationFilter(int argc, char * argv[]) +{ + const char * infname = argv[1]; + const char * modelfname = argv[2]; + const char * outfname = argv[3]; + + // Instantiating object + ClassificationFilterType::Pointer filter = ClassificationFilterType::New(); + + ReaderType::Pointer reader = ReaderType::New(); + reader->SetFileName(infname); + + ModelType::Pointer model; + + model = MachineLearningModelFactoryType::CreateMachineLearningModel(modelfname, + MachineLearningModelFactoryType::ReadMode); + + model->Load(modelfname); + + filter->SetModel(model); + filter->SetInput(reader->GetOutput()); + + WriterType::Pointer writer = WriterType::New(); + writer->SetInput(filter->GetOutput()); + writer->SetFileName(outfname); + writer->Update(); + + return EXIT_SUCCESS; +} diff --git a/Testing/Code/Learning/otbLearningTests6.cxx b/Testing/Code/Learning/otbLearningTests6.cxx index 7643f2c0b0f29f1c55d81e82ef642a05d65dd9d6..48d7f42b8cbd069bbbbf8959300dc69467374e1e 100644 --- a/Testing/Code/Learning/otbLearningTests6.cxx +++ b/Testing/Code/Learning/otbLearningTests6.cxx @@ -32,4 +32,9 @@ void RegisterTests() REGISTER_TEST(otbKNearestNeighborsMachineLearningModel); REGISTER_TEST(otbRandomForestsMachineLearningModelNew); REGISTER_TEST(otbRandomForestsMachineLearningModel); + REGISTER_TEST(otbImageClassificationFilterNew); + REGISTER_TEST(otbImageClassificationFilter); + REGISTER_TEST(otbLibSVMMachineLearningModelCanRead); + REGISTER_TEST(otbSVMMachineLearningModelCanRead); + REGISTER_TEST(otbRandomForestsMachineLearningModelCanRead); } diff --git a/Testing/Code/Learning/otbMachineLearningModelCanRead.cxx b/Testing/Code/Learning/otbMachineLearningModelCanRead.cxx new file mode 100644 index 0000000000000000000000000000000000000000..daf4f76c4c2c4ea05f914c689201ff41f7b34254 --- /dev/null +++ b/Testing/Code/Learning/otbMachineLearningModelCanRead.cxx @@ -0,0 +1,109 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ + +#include "otbMachineLearningModel.h" +#include "otbLibSVMMachineLearningModel.h" +#include "otbSVMMachineLearningModel.h" +#include "otbRandomForestsMachineLearningModel.h" +#include <iostream> + +typedef otb::MachineLearningModel<float,short> MachineLearningModelType; +typedef MachineLearningModelType::InputValueType InputValueType; +typedef MachineLearningModelType::InputSampleType InputSampleType; +typedef MachineLearningModelType::InputListSampleType InputListSampleType; +typedef MachineLearningModelType::TargetValueType TargetValueType; +typedef MachineLearningModelType::TargetSampleType TargetSampleType; +typedef MachineLearningModelType::TargetListSampleType TargetListSampleType; + +int otbLibSVMMachineLearningModelCanRead(int argc, char* argv[]) +{ + if (argc != 2) + { + std::cerr << "Usage: " << argv[0] + << "<model>" << std::endl; + std::cerr << "Called here with " << argc << " arguments\n"; + for (int i = 1; i < argc; ++i) + { + std::cerr << " - " << argv[i] << "\n"; + } + return EXIT_FAILURE; + } + std::string filename(argv[1]); + typedef otb::LibSVMMachineLearningModel<InputValueType, TargetValueType> SVMType; + SVMType::Pointer classifier = SVMType::New(); + bool lCanRead = classifier->CanReadFile(filename); + if (lCanRead == false) + { + std::cerr << "Erreur otb::LibSVMMachineLearningModel : impossible to open the file " << filename << "." << std::endl; + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; +} + +int otbSVMMachineLearningModelCanRead(int argc, char* argv[]) +{ + if (argc != 2) + { + std::cerr << "Usage: " << argv[0] + << "<model>" << std::endl; + std::cerr << "Called here with " << argc << " arguments\n"; + for (int i = 1; i < argc; ++i) + { + std::cerr << " - " << argv[i] << "\n"; + } + return EXIT_FAILURE; + } + std::string filename(argv[1]); + typedef otb::SVMMachineLearningModel<InputValueType, TargetValueType> SVMType; + SVMType::Pointer classifier = SVMType::New(); + bool lCanRead = classifier->CanReadFile(filename); + if (lCanRead == false) + { + std::cerr << "Erreur otb::SVMMachineLearningModel : impossible to open the file " << filename << "." << std::endl; + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; +} + +int otbRandomForestsMachineLearningModelCanRead(int argc, char* argv[]) +{ + if (argc != 2) + { + std::cerr << "Usage: " << argv[0] + << "<model>" << std::endl; + std::cerr << "Called here with " << argc << " arguments\n"; + for (int i = 1; i < argc; ++i) + { + std::cerr << " - " << argv[i] << "\n"; + } + return EXIT_FAILURE; + } + std::string filename(argv[1]); + typedef otb::RandomForestsMachineLearningModel<InputValueType, TargetValueType> RFType; + RFType::Pointer classifier = RFType::New(); + bool lCanRead = classifier->CanReadFile(filename); + if (lCanRead == false) + { + std::cerr << "Erreur otb::RandomForestsMachineLearningModel : impossible to open the file " << filename << "." << std::endl; + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; +}