diff --git a/Modules/Learning/LearningBase/include/otbMachineLearningModel.h b/Modules/Learning/LearningBase/include/otbMachineLearningModel.h index c5fb59f28c9fa5931351921309b5ac35cbf46b48..552880cb9a83a4e43b7e29ae75689613633e5ef0 100644 --- a/Modules/Learning/LearningBase/include/otbMachineLearningModel.h +++ b/Modules/Learning/LearningBase/include/otbMachineLearningModel.h @@ -22,8 +22,8 @@ #define otbMachineLearningModel_h #include "itkObject.h" -#include "itkVariableLengthVector.h" #include "itkListSample.h" +#include "otbMachineLearningModelTraits.h" namespace otb { @@ -66,6 +66,7 @@ namespace otb * * \ingroup OTBLearningBase */ + template <class TInputValue, class TTargetValue, class TConfidenceValue = double > class ITK_EXPORT MachineLearningModel : public itk::Object @@ -81,22 +82,22 @@ public: /**\name Input related typedefs */ //@{ - typedef TInputValue InputValueType; - typedef itk::VariableLengthVector<InputValueType> InputSampleType; - typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType; + typedef typename MLMSampleTraits<TInputValue>::ValueType InputValueType; + typedef typename MLMSampleTraits<TInputValue>::SampleType InputSampleType; + typedef itk::Statistics::ListSample<InputSampleType> InputListSampleType; //@} /**\name Target related typedefs */ //@{ - typedef TTargetValue TargetValueType; - typedef itk::FixedArray<TargetValueType,1> TargetSampleType; - typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType; + typedef typename MLMTargetTraits<TTargetValue>::ValueType TargetValueType; + typedef typename MLMTargetTraits<TTargetValue>::SampleType TargetSampleType; + typedef itk::Statistics::ListSample<TargetSampleType> TargetListSampleType; //@} /**\name Confidence value typedef */ - typedef TConfidenceValue ConfidenceValueType; - typedef itk::FixedArray<ConfidenceValueType,1> ConfidenceSampleType; - typedef itk::Statistics::ListSample<ConfidenceSampleType> ConfidenceListSampleType; + typedef typename MLMTargetTraits<TConfidenceValue>::ValueType ConfidenceValueType; + typedef typename MLMTargetTraits<TConfidenceValue>::SampleType ConfidenceSampleType; + typedef itk::Statistics::ListSample<ConfidenceSampleType> ConfidenceListSampleType; /**\name Standard macros */ //@{ diff --git a/Modules/Learning/LearningBase/include/otbMachineLearningModelTraits.h b/Modules/Learning/LearningBase/include/otbMachineLearningModelTraits.h new file mode 100644 index 0000000000000000000000000000000000000000..e9bc4cec29b9b3b45891e0265d63009bff34b76e --- /dev/null +++ b/Modules/Learning/LearningBase/include/otbMachineLearningModelTraits.h @@ -0,0 +1,104 @@ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#ifndef otbMachineLearningModelTraits_h +#define otbMachineLearningModelTraits_h + + +#include "itkVariableLengthVector.h" +#include "itkFixedArray.h" +#include "itkIsNumber.h" +#include "itkMetaProgrammingLibrary.h" + +namespace otb +{ + +/** + * This is the struct defining the sample implementation for + * MachineLearningModel. It offers two type definitions: SampleType + * and ValueType. + * + * \tparam TInput : input sample type (can be either a scalar type or + * a VariableLenghtVector + * \tparam isNumber either TrueType or FalseType for partial + * specialization + + */ +template <typename TInput, typename isNumber> struct MLMSampleTraitsImpl; + + +/// \cond SPECIALIZATION_IMPLEMENTATION +// For Numbers +template <typename TInput> struct MLMSampleTraitsImpl<TInput, itk::mpl::TrueType> { + typedef TInput ValueType; + typedef itk::VariableLengthVector<TInput> SampleType; +}; + +// For Vectors +template <typename TInput> struct MLMSampleTraitsImpl<TInput, itk::mpl::FalseType> { + typedef typename TInput::ValueType ValueType; + typedef TInput SampleType; +}; +/// \endcond + +/** + * Simplified implementation of SampleTraits using MLMSampleTraitsImpl + */ +template <typename TInput> using MLMSampleTraits = MLMSampleTraitsImpl< TInput, typename itk::mpl::IsNumber<TInput>::Type >; + + +/** + * This is the struct defining the sample implementation for + * MachineLearningModel. It offers two type definitions: TargetType + * and ValueType. + * + * \tparam TInput : input sample type (can be either a scalar type or + * a VariableLenghtVector or a FixedArray + * \tparam isNumber either TrueType or FalseType for partial + * specialization + + */ +template <typename TInput, typename isNumber> struct MLMTargetTraitsImpl; + + +/// \cond SPECIALIZATION_IMPLEMENTATION +// For Numbers +template <typename TInput> struct MLMTargetTraitsImpl<TInput, itk::mpl::TrueType> { + typedef TInput ValueType; + typedef itk::FixedArray<TInput,1> SampleType; +}; + +// For Vectors +template <typename TInput> struct MLMTargetTraitsImpl<TInput, itk::mpl::FalseType> { + typedef typename TInput::ValueType ValueType; + typedef TInput SampleType; +}; +/// \endcond + +/** + * Simplified implementation of TargetTraits using MLMTargetTraitsImpl + */ +template <typename TInput> using MLMTargetTraits = MLMTargetTraitsImpl< TInput, typename itk::mpl::IsNumber<TInput>::Type >; + + +} // End namespace otb + +#endif diff --git a/Modules/Learning/LearningBase/test/CMakeLists.txt b/Modules/Learning/LearningBase/test/CMakeLists.txt index 74a67d44dddd939201c134f28be152281b85a7a2..d1d16c3e65801e606c6e6903538b65264a4483a6 100644 --- a/Modules/Learning/LearningBase/test/CMakeLists.txt +++ b/Modules/Learning/LearningBase/test/CMakeLists.txt @@ -29,6 +29,7 @@ otbDecisionTreeWithRealValues.cxx otbSEMClassifierNew.cxx otbDecisionTreeNew.cxx otbKMeansImageClassificationFilterNew.cxx +otbMachineLearningModelTemplates.cxx ) add_executable(otbLearningBaseTestDriver ${OTBLearningBaseTests}) diff --git a/Modules/Learning/LearningBase/test/otbMachineLearningModelTemplates.cxx b/Modules/Learning/LearningBase/test/otbMachineLearningModelTemplates.cxx new file mode 100644 index 0000000000000000000000000000000000000000..96d35e014394d6b0d6836a90ddcbf3a2303db4bf --- /dev/null +++ b/Modules/Learning/LearningBase/test/otbMachineLearningModelTemplates.cxx @@ -0,0 +1,40 @@ +/* + * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES) + * + * This file is part of Orfeo Toolbox + * + * https://www.orfeo-toolbox.org/ + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <otbMachineLearningModel.h> + +typedef otb::MachineLearningModel<float,short> MachineLearningModelType1; +typedef MachineLearningModelType1::InputValueType InputValueType1; +typedef MachineLearningModelType1::InputSampleType InputSampleType1; +typedef MachineLearningModelType1::InputListSampleType InputListSampleType1; +typedef MachineLearningModelType1::TargetValueType TargetValueType1; +typedef MachineLearningModelType1::TargetSampleType TargetSampleType1; +typedef MachineLearningModelType1::TargetListSampleType TargetListSampleType1; + +typedef otb::MachineLearningModel<float,itk::VariableLengthVector<double> > MachineLearningModelType2; +typedef MachineLearningModelType2::InputValueType InputValueType2; +typedef MachineLearningModelType2::InputSampleType InputSampleType2; +typedef MachineLearningModelType2::InputListSampleType InputListSampleType2; +typedef MachineLearningModelType2::TargetValueType TargetValueType2; +typedef MachineLearningModelType2::TargetSampleType TargetSampleType2; +typedef MachineLearningModelType2::TargetListSampleType TargetListSampleType2; + + +