diff --git a/Code/UtilitiesAdapters/OpenCV/otbMachineLearningModelFactory.txx b/Code/UtilitiesAdapters/OpenCV/otbMachineLearningModelFactory.txx index 1d8bf19db380cb80d00e598e892f9b2aa971db2d..003bf6cd1dc1a55317151cf0f8abf4f29996c255 100644 --- a/Code/UtilitiesAdapters/OpenCV/otbMachineLearningModelFactory.txx +++ b/Code/UtilitiesAdapters/OpenCV/otbMachineLearningModelFactory.txx @@ -25,6 +25,7 @@ #include "otbLibSVMMachineLearningModelFactory.h" #include "otbBoostMachineLearningModelFactory.h" #include "otbNeuralNetworkMachineLearningModelFactory.h" +#include "otbNormalBayesMachineLearningModelFactory.h" namespace otb @@ -99,6 +100,7 @@ MachineLearningModelFactory<TInputValue,TOutputValue> itk::ObjectFactoryBase::RegisterFactory(SVMMachineLearningModelFactory<TInputValue,TOutputValue>::New()); itk::ObjectFactoryBase::RegisterFactory(BoostMachineLearningModelFactory<TInputValue,TOutputValue>::New()); itk::ObjectFactoryBase::RegisterFactory(NeuralNetworkMachineLearningModelFactory<TInputValue,TOutputValue>::New()); + itk::ObjectFactoryBase::RegisterFactory(NormalBayesMachineLearningModelFactory<TInputValue,TOutputValue>::New()); firstTime = false; } diff --git a/Code/UtilitiesAdapters/OpenCV/otbNormalBayesMachineLearningModel.h b/Code/UtilitiesAdapters/OpenCV/otbNormalBayesMachineLearningModel.h new file mode 100644 index 0000000000000000000000000000000000000000..2d080ee033327b5003bf91a1fd8d324255f18ac9 --- /dev/null +++ b/Code/UtilitiesAdapters/OpenCV/otbNormalBayesMachineLearningModel.h @@ -0,0 +1,99 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef __otbNormalBayesMachineLearningModel_h +#define __otbNormalBayesMachineLearningModel_h + +#include "itkLightObject.h" +#include "itkVariableLengthVector.h" +#include "itkFixedArray.h" +#include "itkListSample.h" +#include "otbMachineLearningModel.h" + + +class CvNormalBayesClassifier; + +namespace otb +{ +template <class TInputValue, class TTargetValue> +class ITK_EXPORT NormalBayesMachineLearningModel + : public MachineLearningModel <TInputValue, TTargetValue> +{ +public: + /** Standard class typedefs. */ + typedef NormalBayesMachineLearningModel Self; + typedef MachineLearningModel<TInputValue, TTargetValue> Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + + // Input related typedefs + typedef TInputValue InputValueType; + typedef itk::VariableLengthVector<InputValueType> InputSampleType; + typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType; + + // Target related typedefs + typedef TTargetValue TargetValueType; + typedef itk::FixedArray<TargetValueType,1> TargetSampleType; + typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType; + + /** Run-time type information (and related methods). */ + itkNewMacro(Self); + itkTypeMacro(NormalBayesMachineLearningModel, itk::MachineLearningModel); + + /** Train the machine learning model */ + virtual void Train(); + + /** Predict values using the model */ + virtual TargetSampleType Predict(const InputSampleType & input) const; + + /** Save the model to file */ + virtual void Save(const std::string & filename, const std::string & name=""); + + /** Load the model from file */ + virtual void Load(const std::string & filename, const std::string & name=""); + + /** Determine the file type. Returns true if this ImageIO can read the + * file specified. */ + virtual bool CanReadFile(const std::string &); + + /** Determine the file type. Returns true if this ImageIO can write the + * file specified. */ + virtual bool CanWriteFile(const std::string &); + +protected: + /** Constructor */ + NormalBayesMachineLearningModel(); + + /** Destructor */ + virtual ~NormalBayesMachineLearningModel(); + + /** PrintSelf method */ + void PrintSelf(std::ostream& os, itk::Indent indent) const; + +private: + NormalBayesMachineLearningModel(const Self &); //purposely not implemented + void operator =(const Self&); //purposely not implemented + + CvNormalBayesClassifier * m_NormalBayesModel; +}; +} // end namespace otb + +#ifndef OTB_MANUAL_INSTANTIATION +#include "otbNormalBayesMachineLearningModel.txx" +#endif + +#endif diff --git a/Code/UtilitiesAdapters/OpenCV/otbNormalBayesMachineLearningModel.txx b/Code/UtilitiesAdapters/OpenCV/otbNormalBayesMachineLearningModel.txx new file mode 100644 index 0000000000000000000000000000000000000000..102e9d9ab3fb105fe8c0c8d4e21870be18811883 --- /dev/null +++ b/Code/UtilitiesAdapters/OpenCV/otbNormalBayesMachineLearningModel.txx @@ -0,0 +1,152 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef __otbNormalBayesMachineLearningModel_txx +#define __otbNormalBayesMachineLearningModel_txx + +#include "otbNormalBayesMachineLearningModel.h" +#include "otbOpenCVUtils.h" + +#include <opencv2/opencv.hpp> + +namespace otb +{ + +template <class TInputValue, class TOutputValue> +NormalBayesMachineLearningModel<TInputValue,TOutputValue> +::NormalBayesMachineLearningModel() +{ + m_NormalBayesModel = new CvNormalBayesClassifier; +} + + +template <class TInputValue, class TOutputValue> +NormalBayesMachineLearningModel<TInputValue,TOutputValue> +::~NormalBayesMachineLearningModel() +{ + delete m_NormalBayesModel; +} + +/** Train the machine learning model */ +template <class TInputValue, class TOutputValue> +void +NormalBayesMachineLearningModel<TInputValue,TOutputValue> +::Train() +{ + //convert listsample to opencv matrix + cv::Mat samples; + otb::ListSampleToMat<InputListSampleType>(this->GetInputListSample(), samples); + + cv::Mat labels; + otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(),labels); + + m_NormalBayesModel->train(samples,labels,cv::Mat(),cv::Mat(),false); +} + +template <class TInputValue, class TOutputValue> +typename NormalBayesMachineLearningModel<TInputValue,TOutputValue> +::TargetSampleType +NormalBayesMachineLearningModel<TInputValue,TOutputValue> +::Predict(const InputSampleType & input) const +{ + //convert listsample to Mat + cv::Mat sample; + + otb::SampleToMat<InputSampleType>(input,sample); + + cv::Mat missing = cv::Mat(1,input.Size(), CV_8U ); + missing.setTo(0); + double result = m_NormalBayesModel->predict(sample); + + TargetSampleType target; + + target[0] = static_cast<TOutputValue>(result); + + return target; +} + +template <class TInputValue, class TOutputValue> +void +NormalBayesMachineLearningModel<TInputValue,TOutputValue> +::Save(const std::string & filename, const std::string & name) +{ + if (name == "") + m_NormalBayesModel->save(filename.c_str(), 0); + else + m_NormalBayesModel->save(filename.c_str(), name.c_str()); +} + +template <class TInputValue, class TOutputValue> +void +NormalBayesMachineLearningModel<TInputValue,TOutputValue> +::Load(const std::string & filename, const std::string & name) +{ + if (name == "") + m_NormalBayesModel->load(filename.c_str(), 0); + else + m_NormalBayesModel->load(filename.c_str(), name.c_str()); +} + +template <class TInputValue, class TOutputValue> +bool +NormalBayesMachineLearningModel<TInputValue,TOutputValue> +::CanReadFile(const std::string & file) +{ + std::ifstream ifs; + ifs.open(file.c_str()); + + if(!ifs) + { + std::cerr<<"Could not read file "<<file<<std::endl; + return false; + } + + while (!ifs.eof()) + { + std::string line; + std::getline(ifs, line); + + if (line.find(CV_TYPE_NAME_ML_NBAYES) != std::string::npos) + { + std::cout<<"Reading a "<<CV_TYPE_NAME_ML_NBAYES<<" model !!!"<<std::endl; + return true; + } + } + ifs.close(); + return false; +} + +template <class TInputValue, class TOutputValue> +bool +NormalBayesMachineLearningModel<TInputValue,TOutputValue> +::CanWriteFile(const std::string & file) +{ + return false; +} + +template <class TInputValue, class TOutputValue> +void +NormalBayesMachineLearningModel<TInputValue,TOutputValue> +::PrintSelf(std::ostream& os, itk::Indent indent) const +{ + // Call superclass implementation + Superclass::PrintSelf(os,indent); +} + +} //end namespace otb + +#endif diff --git a/Code/UtilitiesAdapters/OpenCV/otbNormalBayesMachineLearningModelFactory.h b/Code/UtilitiesAdapters/OpenCV/otbNormalBayesMachineLearningModelFactory.h new file mode 100644 index 0000000000000000000000000000000000000000..4783350efd17b7c33099dd77e97937224ff05bd0 --- /dev/null +++ b/Code/UtilitiesAdapters/OpenCV/otbNormalBayesMachineLearningModelFactory.h @@ -0,0 +1,72 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#ifndef __otbNormalBayesMachineLearningModelFactory_h +#define __otbNormalBayesMachineLearningModelFactory_h + +#include "itkObjectFactoryBase.h" +#include "itkImageIOBase.h" + +namespace otb +{ +/** \class NormalBayesMachineLearningModelFactory + * \brief Creation d'un instance d'un objet SVMMachineLearningModel utilisant les object factory. + */ +template <class TInputValue, class TTargetValue> +class ITK_EXPORT NormalBayesMachineLearningModelFactory : public itk::ObjectFactoryBase +{ +public: + /** Standard class typedefs. */ + typedef NormalBayesMachineLearningModelFactory Self; + typedef itk::ObjectFactoryBase Superclass; + typedef itk::SmartPointer<Self> Pointer; + typedef itk::SmartPointer<const Self> ConstPointer; + + /** Class methods used to interface with the registered factories. */ + virtual const char* GetITKSourceVersion(void) const; + virtual const char* GetDescription(void) const; + + /** Method for class instantiation. */ + itkFactorylessNewMacro(Self); + + /** Run-time type information (and related methods). */ + itkTypeMacro(NormalBayesMachineLearningModelFactory, itk::ObjectFactoryBase); + + /** Register one factory of this type */ + static void RegisterOneFactory(void) + { + NormalBayesMachineLearningModelFactory::Pointer Factory = NormalBayesMachineLearningModelFactory::New(); + itk::ObjectFactoryBase::RegisterFactory(Factory); + } + +protected: + NormalBayesMachineLearningModelFactory(); + virtual ~NormalBayesMachineLearningModelFactory(); + +private: + NormalBayesMachineLearningModelFactory(const Self &); //purposely not implemented + void operator =(const Self&); //purposely not implemented + +}; + +} // end namespace otb + +#ifndef OTB_MANUAL_INSTANTIATION +#include "otbNormalBayesMachineLearningModelFactory.txx" +#endif + +#endif diff --git a/Code/UtilitiesAdapters/OpenCV/otbNormalBayesMachineLearningModelFactory.txx b/Code/UtilitiesAdapters/OpenCV/otbNormalBayesMachineLearningModelFactory.txx new file mode 100644 index 0000000000000000000000000000000000000000..5060481e59da6c0215f9d023e4a97eae7ed919ae --- /dev/null +++ b/Code/UtilitiesAdapters/OpenCV/otbNormalBayesMachineLearningModelFactory.txx @@ -0,0 +1,64 @@ +/*========================================================================= + + Program: ORFEO Toolbox + Language: C++ + Date: $Date$ + Version: $Revision$ + + + Copyright (c) Centre National d'Etudes Spatiales. All rights reserved. + See OTBCopyright.txt for details. + + + This software is distributed WITHOUT ANY WARRANTY; without even + the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR + PURPOSE. See the above copyright notices for more information. + +=========================================================================*/ +#include "otbNormalBayesMachineLearningModelFactory.h" + +#include "itkCreateObjectFunction.h" +#include "otbNormalBayesMachineLearningModel.h" +#include "itkVersion.h" + +namespace otb +{ + +template <class TInputValue, class TOutputValue> +NormalBayesMachineLearningModelFactory<TInputValue,TOutputValue> +::NormalBayesMachineLearningModelFactory() +{ + + static std::string classOverride = std::string("otbMachineLearningModel"); + static std::string subclass = std::string("otbNormalBayesMachineLearningModel"); + + this->RegisterOverride(classOverride.c_str(), + subclass.c_str(), + "Normal Bayes ML Model", + 1, + itk::CreateObjectFunction<NormalBayesMachineLearningModel<TInputValue,TOutputValue> >::New()); +} + +template <class TInputValue, class TOutputValue> +NormalBayesMachineLearningModelFactory<TInputValue,TOutputValue> +::~NormalBayesMachineLearningModelFactory() +{ +} + +template <class TInputValue, class TOutputValue> +const char* +NormalBayesMachineLearningModelFactory<TInputValue,TOutputValue> +::GetITKSourceVersion(void) const +{ + return ITK_SOURCE_VERSION; +} + +template <class TInputValue, class TOutputValue> +const char* +NormalBayesMachineLearningModelFactory<TInputValue,TOutputValue> +::GetDescription() const +{ + return "Normal Bayes machine learning model factory"; +} + +} // end namespace otb diff --git a/Testing/Code/Learning/CMakeLists.txt b/Testing/Code/Learning/CMakeLists.txt index 686221a969aeea30f384aac13bc625b099e52e69..023b8ccc19db8385b80a09393c7e1b143e5a110c 100644 --- a/Testing/Code/Learning/CMakeLists.txt +++ b/Testing/Code/Learning/CMakeLists.txt @@ -749,6 +749,15 @@ IF(OTB_USE_OPENCV) ${INPUTDATA}/letter.scale ${TEMP}/ANNMachineLearningModel.txt ) + + ADD_TEST(leTvNormalBayesMachineLearningModel ${LEARNING_TESTS6} + #--compare-ascii ${NOTOL} + #${BASELINE_FILES}/NormalBayesMachineLearningModel.txt + #${TEMP}/NormalBayesMachineLearningModel.txt + otbNormalBayesMachineLearningModel + ${INPUTDATA}/letter.scale + ${TEMP}/NormalBayesMachineLearningModel.txt + ) ADD_TEST(leTuImageClassificationFilterNew ${LEARNING_TESTS6} otbImageClassificationFilterNew) @@ -799,6 +808,11 @@ IF(OTB_USE_OPENCV) otbNeuralNetworkMachineLearningModelCanRead ${INPUTDATA}/NeuralNetworkMachineLearningModel.txt ) + + ADD_TEST(leTuNormalBayesMachineLearningModelCanRead ${LEARNING_TESTS6} + otbNormalBayesMachineLearningModelCanRead + ${INPUTDATA}/NormalBayesMachineLearningModel.txt + ) ENDIF(OTB_USE_OPENCV) diff --git a/Testing/Code/Learning/otbLearningTests6.cxx b/Testing/Code/Learning/otbLearningTests6.cxx index 749102badb033dbfaf038d1701dc7f0096b6e5ae..5594144620a399cecb655168ca6e61ad774423b7 100644 --- a/Testing/Code/Learning/otbLearningTests6.cxx +++ b/Testing/Code/Learning/otbLearningTests6.cxx @@ -36,6 +36,8 @@ void RegisterTests() REGISTER_TEST(otbBoostMachineLearningModel); REGISTER_TEST(otbANNMachineLearningModelNew); REGISTER_TEST(otbANNMachineLearningModel); + REGISTER_TEST(otbNormalBayesMachineLearningModelNew); + REGISTER_TEST(otbNormalBayesMachineLearningModel); REGISTER_TEST(otbImageClassificationFilterNew); REGISTER_TEST(otbImageClassificationFilter); REGISTER_TEST(otbLibSVMMachineLearningModelCanRead); @@ -43,4 +45,5 @@ void RegisterTests() REGISTER_TEST(otbRandomForestsMachineLearningModelCanRead); REGISTER_TEST(otbBoostMachineLearningModelCanRead); REGISTER_TEST(otbNeuralNetworkMachineLearningModelCanRead); + REGISTER_TEST(otbNormalBayesMachineLearningModelCanRead); } diff --git a/Testing/Code/Learning/otbMachineLearningModelCanRead.cxx b/Testing/Code/Learning/otbMachineLearningModelCanRead.cxx index 5cb62adee181bdc7885239442e0520b3a6730266..578fab35924458a8bf67a4d4bd2f7750298e8c35 100644 --- a/Testing/Code/Learning/otbMachineLearningModelCanRead.cxx +++ b/Testing/Code/Learning/otbMachineLearningModelCanRead.cxx @@ -22,6 +22,7 @@ #include "otbRandomForestsMachineLearningModel.h" #include "otbBoostMachineLearningModel.h" #include "otbNeuralNetworkMachineLearningModel.h" +#include "otbNormalBayesMachineLearningModel.h" #include <iostream> typedef otb::MachineLearningModel<float,short> MachineLearningModelType; @@ -162,4 +163,29 @@ int otbNeuralNetworkMachineLearningModelCanRead(int argc, char* argv[]) return EXIT_SUCCESS; } +int otbNormalBayesMachineLearningModelCanRead(int argc, char* argv[]) +{ + if (argc != 2) + { + std::cerr << "Usage: " << argv[0] + << "<model>" << std::endl; + std::cerr << "Called here with " << argc << " arguments\n"; + for (int i = 1; i < argc; ++i) + { + std::cerr << " - " << argv[i] << "\n"; + } + return EXIT_FAILURE; + } + std::string filename(argv[1]); + typedef otb::NormalBayesMachineLearningModel<InputValueType, TargetValueType> NormalBayesType; + NormalBayesType::Pointer classifier = NormalBayesType::New(); + bool lCanRead = classifier->CanReadFile(filename); + if (lCanRead == false) + { + std::cerr << "Erreur otb::NormalBayesMachineLearningModel : impossible to open the file " << filename << "." << std::endl; + return EXIT_FAILURE; + } + + return EXIT_SUCCESS; +} diff --git a/Testing/Code/Learning/otbTrainMachineLearningModel.cxx b/Testing/Code/Learning/otbTrainMachineLearningModel.cxx index ba2e3f29687e16ef8a40e3da5187aaf99cf8c880..3f11e9991ecc082de62d0a0bf102ada754047d98 100644 --- a/Testing/Code/Learning/otbTrainMachineLearningModel.cxx +++ b/Testing/Code/Learning/otbTrainMachineLearningModel.cxx @@ -27,6 +27,7 @@ #include "otbRandomForestsMachineLearningModel.h" #include "otbBoostMachineLearningModel.h" #include "otbNeuralNetworkMachineLearningModel.h" +#include "otbNormalBayesMachineLearningModel.h" #include "otbConfusionMatrixCalculator.h" @@ -478,6 +479,58 @@ int otbANNMachineLearningModel(int argc, char * argv[]) return EXIT_SUCCESS; } +int otbNormalBayesMachineLearningModelNew(int argc, char * argv[]) +{ + typedef otb::NormalBayesMachineLearningModel<InputValueType,TargetValueType> NormalBayesType; + NormalBayesType::Pointer classifier = NormalBayesType::New(); + return EXIT_SUCCESS; +} + +int otbNormalBayesMachineLearningModel(int argc, char * argv[]) +{ + if (argc != 3 ) + { + std::cout<<"Wrong number of arguments "<<std::endl; + std::cout<<"Usage : sample file, output file "<<std::endl; + return EXIT_FAILURE; + } + + typedef otb::NormalBayesMachineLearningModel<InputValueType, TargetValueType> NormalBayesType; + + InputListSampleType::Pointer samples = InputListSampleType::New(); + TargetListSampleType::Pointer labels = TargetListSampleType::New(); + TargetListSampleType::Pointer predicted = TargetListSampleType::New(); + + if(!ReadDataFile(argv[1],samples,labels)) + { + std::cout<<"Failed to read samples file "<<argv[1]<<std::endl; + return EXIT_FAILURE; + } + + NormalBayesType::Pointer classifier = NormalBayesType::New(); + classifier->SetInputListSample(samples); + classifier->SetTargetListSample(labels); + classifier->Train(); + + classifier->SetTargetListSample(predicted); + classifier->PredictAll(); + + classifier->Save(argv[2]); + + ConfusionMatrixCalculatorType::Pointer cmCalculator = ConfusionMatrixCalculatorType::New(); + + cmCalculator->SetProducedLabels(predicted); + cmCalculator->SetReferenceLabels(labels); + cmCalculator->Compute(); + + std::cout<<"Confusion matrix: "<<std::endl; + std::cout<<cmCalculator->GetConfusionMatrix()<<std::endl; + std::cout<<"Kappa: "<<cmCalculator->GetKappaIndex()<<std::endl; + std::cout<<"Overall Accuracy: "<<cmCalculator->GetOverallAccuracy()<<std::endl; + + return EXIT_SUCCESS; +} +