diff --git a/Code/Learning/otbListSampleToBalancedListSampleFilter.h b/Code/Learning/otbListSampleToBalancedListSampleFilter.h
new file mode 100644
index 0000000000000000000000000000000000000000..bf0c8d507f3bf56cf6461d567b63b279baf63c97
--- /dev/null
+++ b/Code/Learning/otbListSampleToBalancedListSampleFilter.h
@@ -0,0 +1,159 @@
+/*=========================================================================
+
+  Program:   ORFEO Toolbox
+  Language:  C++
+  Date:      $Date$
+  Version:   $Revision$
+
+
+  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
+  See OTBCopyright.txt for details.
+
+
+     This software is distributed WITHOUT ANY WARRANTY; without even
+     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+     PURPOSE.  See the above copyright notices for more information.
+
+=========================================================================*/
+#ifndef __otbListSampleToBalancedListSampleFilter_h
+#define __otbListSampleToBalancedListSampleFilter_h
+
+#include "otbListSampleToListSampleFilter.h"
+#include "otbGaussianAdditiveNoiseSampleListFilter.h"
+#include "itkDataObjectDecorator.h"
+#include "otbMacro.h"
+#include <vector>
+
+namespace otb {
+namespace Statistics {
+
+/** \class ListSampleToBalancedListSampleFilter
+ *  \brief This class generates a balanced ListSample in order to have
+ *  a fair distribution of learning samples.
+ *
+ *  The maximum number of samples sharing the same label is first
+ *  computed. This maximum number multiplied by m_BalancingFactor
+ *  gives the number of samples each label is expected to reach. The
+ *  additional samples are noised copies of the input samples, produced
+ *  by a GaussianAdditiveNoiseSampleListFilter.
+ *
+ *  The first component of each label measurement vector is used as an
+ *  index, so labels are expected to be non-negative integer values.
+ *
+ *  The mean and variance of the noise are set via the methods
+ *  SetMean() and SetVariance().
+ *
+ * \sa ListSampleToListSampleFilter, GaussianAdditiveNoiseSampleListFilter
+ */
+template < class TInputSampleList,
+           class TLabelSampleList,
+           class TOutputSampleList = TInputSampleList >
+class ITK_EXPORT ListSampleToBalancedListSampleFilter :
+  public otb::Statistics::ListSampleToListSampleFilter<TInputSampleList,
+                                                       TOutputSampleList>
+{
+public:
+  /** Standard class typedefs */
+  typedef ListSampleToBalancedListSampleFilter          Self;
+  typedef otb::Statistics::ListSampleToListSampleFilter
+  <TInputSampleList,TOutputSampleList>                  Superclass;
+  typedef itk::SmartPointer< Self >                     Pointer;
+  typedef itk::SmartPointer<const Self>                 ConstPointer;
+
+  /** Run-time type information (and related methods). */
+  itkTypeMacro(ListSampleToBalancedListSampleFilter,otb::Statistics::ListSampleToListSampleFilter);
+
+  /** Method for creation through the object factory. */
+  itkNewMacro(Self);
+
+  /** InputSampleList typedefs */
+  typedef TInputSampleList                                    InputSampleListType;
+  typedef typename InputSampleListType::Pointer               InputSampleListPointer;
+  typedef typename InputSampleListType::ConstPointer          InputSampleListConstPointer;
+  typedef typename InputSampleListType::MeasurementVectorType InputMeasurementVectorType;
+  typedef typename InputMeasurementVectorType::ValueType      InputValueType;
+
+  /** LabelSampleList typedefs */
+  typedef TLabelSampleList                                    LabelSampleListType;
+  typedef typename LabelSampleListType::Pointer               LabelSampleListPointer;
+  typedef typename LabelSampleListType::ConstPointer          LabelSampleListConstPointer;
+  typedef typename LabelSampleListType::MeasurementVectorType LabelMeasurementVectorType;
+  typedef typename LabelMeasurementVectorType::ValueType      LabelValueType;
+  typedef itk::DataObjectDecorator< LabelSampleListType >     LabelSampleListObjectType;
+
+  /** OutputSampleList typedefs */
+  typedef TOutputSampleList                                    OutputSampleListType;
+  typedef typename OutputSampleListType::Pointer               OutputSampleListPointer;
+  typedef typename OutputSampleListType::ConstPointer          OutputSampleListConstPointer;
+  typedef typename OutputSampleListType::MeasurementVectorType OutputMeasurementVectorType;
+  typedef typename OutputMeasurementVectorType::ValueType      OutputValueType;
+
+  /** Input & Output sample list as data object */
+  typedef typename Superclass::InputSampleListObjectType       InputSampleListObjectType;
+  typedef typename Superclass::OutputSampleListObjectType      OutputSampleListObjectType;
+
+  /** Filter adding noise to a ListSample */
+  typedef otb::Statistics::GaussianAdditiveNoiseSampleListFilter
+  <InputSampleListType,OutputSampleListType>                   GaussianAdditiveNoiseType;
+  typedef typename GaussianAdditiveNoiseType::Pointer          GaussianAdditiveNoisePointerType;
+
+  /** Set the label sample list */
+  void SetInputLabel( const LabelSampleListType * label );
+  void SetInputLabel( const LabelSampleListObjectType * labelPtr );
+
+  /** Returns the label sample list */
+  const LabelSampleListType * GetLabelSampleList() const;
+
+  /** Returns the label sample list as a data object */
+  const LabelSampleListObjectType * GetInputLabel() const;
+
+  /** Set/Get the mean for the white gaussian noise to generate */
+  otbSetObjectMemberMacro(AddGaussianNoiseFilter,Mean,double);
+  otbGetObjectMemberConstMacro(AddGaussianNoiseFilter,Mean,double);
+
+  /** Set/Get the variance for the white gaussian noise to generate */
+  otbSetObjectMemberMacro(AddGaussianNoiseFilter,Variance,double);
+  otbGetObjectMemberConstMacro(AddGaussianNoiseFilter,Variance,double);
+
+  /** Set/Get the multiplicative factor : this value is used to
+   * determine the maximum number of samples in each label in order
+   * to reach a balanced output ListSample
+   */
+  itkSetMacro(BalancingFactor,unsigned int);
+  itkGetMacro(BalancingFactor,unsigned int);
+
+protected:
+  /** This method causes the filter to generate its output. */
+  virtual void GenerateData();
+
+  /** In order to respect the fair data principle, the number of samples for
+   * each label must be the same. This method computes, for each label, the
+   * number of noised copies to generate per sample.
+   */
+  void ComputeMaxSampleFrequency();
+
+  ListSampleToBalancedListSampleFilter();
+  virtual ~ListSampleToBalancedListSampleFilter() {}
+  void PrintSelf(std::ostream& os, itk::Indent indent) const;
+
+private:
+  ListSampleToBalancedListSampleFilter(const Self&); //purposely not implemented
+  void operator=(const Self&); //purposely not implemented
+
+  /** Holds the mean and variance of the noise (see SetMean()/SetVariance()) */
+  GaussianAdditiveNoisePointerType m_AddGaussianNoiseFilter;
+  /** Number of noised copies to generate per sample, indexed by label */
+  std::vector<unsigned int>        m_MultiplicativeCoefficient;
+  /** Multiplicative factor applied to the highest label frequency */
+  unsigned int                     m_BalancingFactor;
+
+}; // end of class ListSampleToBalancedListSampleFilter
+
+} // end of namespace Statistics
+} // end of namespace otb
+
+#ifndef OTB_MANUAL_INSTANTIATION
+#include "otbListSampleToBalancedListSampleFilter.txx"
+#endif
+
+#endif
diff --git a/Code/Learning/otbListSampleToBalancedListSampleFilter.txx b/Code/Learning/otbListSampleToBalancedListSampleFilter.txx
new file mode 100644
index 0000000000000000000000000000000000000000..45610b7ce9d85cc8ecb654043be22e5c08af9a6e
--- /dev/null
+++ b/Code/Learning/otbListSampleToBalancedListSampleFilter.txx
@@ -0,0 +1,282 @@
+/*=========================================================================
+
+  Program:   ORFEO Toolbox
+  Language:  C++
+  Date:      $Date$
+  Version:   $Revision$
+
+
+  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
+  See OTBCopyright.txt for details.
+
+
+     This software is distributed WITHOUT ANY WARRANTY; without even
+     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+     PURPOSE.  See the above copyright notices for more information.
+
+=========================================================================*/
+#ifndef __otbListSampleToBalancedListSampleFilter_txx
+#define __otbListSampleToBalancedListSampleFilter_txx
+
+#include "otbListSampleToBalancedListSampleFilter.h"
+#include "itkProgressReporter.h"
+#include "itkHistogram.h"
+#include "itkNumericTraits.h"
+
+namespace otb {
+namespace Statistics {
+
+// constructor
+template < class TInputSampleList, class TLabelSampleList, class TOutputSampleList >
+ListSampleToBalancedListSampleFilter<TInputSampleList,TLabelSampleList,TOutputSampleList>
+::ListSampleToBalancedListSampleFilter()
+{
+  // Two inputs are required : the sample list and the label list
+  this->SetNumberOfRequiredInputs(2);
+
+  // Holds the noise parameters exposed through SetMean()/SetVariance()
+  m_AddGaussianNoiseFilter = GaussianAdditiveNoiseType::New();
+  m_BalancingFactor  = 5;
+}
+
+// Method to set the label SampleList
+template < class TInputSampleList, class TLabelSampleList, class TOutputSampleList >
+void
+ListSampleToBalancedListSampleFilter<TInputSampleList,TLabelSampleList,TOutputSampleList>
+::SetInputLabel( const LabelSampleListType * label )
+{
+  // Wrap the list in a data object before handing it to the overload below
+  typename LabelSampleListObjectType::Pointer labelPtr = LabelSampleListObjectType::New();
+  labelPtr->Set(label);
+  this->SetInputLabel(labelPtr);
+}
+
+// Method to set the label SampleList as DataObject (stored as input 1)
+template < class TInputSampleList, class TLabelSampleList, class TOutputSampleList >
+void
+ListSampleToBalancedListSampleFilter<TInputSampleList,TLabelSampleList,TOutputSampleList>
+::SetInputLabel( const LabelSampleListObjectType * labelPtr )
+{
+  // Process object is not const-correct so the const_cast is required here
+  Superclass::ProcessObject::SetNthInput(1,
+                                         const_cast< LabelSampleListObjectType* >( labelPtr ) );
+}
+
+// Method to get the label SampleList as DataObject (0 when it is not set)
+template < class TInputSampleList, class TLabelSampleList, class TOutputSampleList >
+const typename ListSampleToBalancedListSampleFilter<TInputSampleList,TLabelSampleList,TOutputSampleList>
+::LabelSampleListObjectType *
+ListSampleToBalancedListSampleFilter<TInputSampleList,TLabelSampleList,TOutputSampleList>
+::GetInputLabel() const
+{
+  if (this->GetNumberOfInputs() < 2)
+    {
+    return 0;
+    }
+
+  return static_cast<const LabelSampleListObjectType* >
+    (Superclass::ProcessObject::GetInput(1) );
+}
+
+// Method to get the label SampleList (0 when it is not set)
+template < class TInputSampleList, class TLabelSampleList, class TOutputSampleList >
+const typename ListSampleToBalancedListSampleFilter<TInputSampleList,TLabelSampleList,TOutputSampleList>
+::LabelSampleListType *
+ListSampleToBalancedListSampleFilter<TInputSampleList,TLabelSampleList,TOutputSampleList>
+::GetLabelSampleList() const
+{
+  if (this->GetNumberOfInputs() < 2)
+    {
+    return 0;
+    }
+
+  typename LabelSampleListObjectType::ConstPointer dataObjectPointer = static_cast<const LabelSampleListObjectType * >
+    (Superclass::ProcessObject::GetInput(1) );
+
+  // The input slot may exist without any label list having been set
+  if (dataObjectPointer.IsNull())
+    {
+    return 0;
+    }
+
+  return dataObjectPointer->Get();
+}
+
+// Compute, for each label, how many noised copies of a sample are needed
+template < class TInputSampleList, class TLabelSampleList, class TOutputSampleList >
+void
+ListSampleToBalancedListSampleFilter<TInputSampleList,TLabelSampleList,TOutputSampleList>
+::ComputeMaxSampleFrequency()
+{
+  // Drop the coefficients of any previous execution : they are appended
+  // below, so the ones of a former run would otherwise stay in front
+  m_MultiplicativeCoefficient.clear();
+
+  // Iterate on the labelSampleList to get the max label
+  LabelValueType maxLabel = itk::NumericTraits<LabelValueType>::min();
+
+  // Retrieve the label sample list
+  typename LabelSampleListType::ConstPointer labelPtr = this->GetLabelSampleList();
+  typename LabelSampleListType::ConstIterator labelIt = labelPtr->Begin();
+
+  while(labelIt != labelPtr->End())
+    {
+    // Get the current label sample
+    LabelMeasurementVectorType currentInputMeasurement = labelIt.GetMeasurementVector();
+
+    if (currentInputMeasurement[0] > maxLabel)
+      maxLabel = currentInputMeasurement[0];
+
+    ++labelIt;
+    }
+
+  // Prepare histogram with dimension 1 : default template parameters.
+  // Its size is maxLabel + 1 so that every label value indexes a bin
+  typedef typename itk::Statistics::Histogram<unsigned int> HistogramType;
+  typename HistogramType::Pointer histogram = HistogramType::New();
+  typename HistogramType::SizeType size;
+  size.Fill(maxLabel +1);
+  histogram->Initialize(size);
+
+  labelIt = labelPtr->Begin();
+  while (labelIt != labelPtr->End())
+    {
+    // Get the current label sample
+    LabelMeasurementVectorType currentInputMeasurement = labelIt.GetMeasurementVector();
+    histogram->IncreaseFrequency(currentInputMeasurement[0], 1.);
+    ++labelIt;
+    }
+
+  // Iterate through the histogram to get the maximum
+  unsigned int maxvalue = 0;
+  typename HistogramType::Iterator iter = histogram->Begin();
+
+  while ( iter != histogram->End() )
+    {
+    if( static_cast<unsigned int>(iter.GetFrequency()) > maxvalue )
+      maxvalue = static_cast<unsigned int>(iter.GetFrequency());
+    ++iter;
+    }
+
+  // Number of sample per label to reach in order to have a balanced
+  // ListSample
+  unsigned int balancedFrequency = m_BalancingFactor * maxvalue;
+
+  // Guess how much noised samples must be added per sample to get
+  // a balanced ListSample : Computed using the
+  //  - Frequency of each label (stored in the histogram)
+  //  - The value maxvalue by m_BalancingFactor
+  // The std::vector below stores the multiplicative factor
+  iter = histogram->Begin();
+  while ( iter != histogram->End() )
+    {
+    // A label value without any sample has a null frequency : its
+    // coefficient is never looked up, store 0 rather than divide by 0
+    unsigned int coeff = 0;
+    if (iter.GetFrequency() > 0)
+      {
+      coeff = static_cast<unsigned int>(balancedFrequency/iter.GetFrequency());
+      }
+    m_MultiplicativeCoefficient.push_back(coeff);
+    ++iter;
+    }
+}
+
+template < class TInputSampleList, class TLabelSampleList, class TOutputSampleList >
+void
+ListSampleToBalancedListSampleFilter<TInputSampleList,TLabelSampleList,TOutputSampleList>
+::GenerateData()
+{
+  // Get how much each sample must be expanded
+  this->ComputeMaxSampleFrequency();
+
+  // Retrieve input and output pointers
+  typename InputSampleListObjectType::ConstPointer inputPtr = this->GetInput();
+  typename LabelSampleListObjectType::ConstPointer labelPtr = this->GetInputLabel();
+  typename OutputSampleListObjectType::Pointer outputPtr = this->GetOutput();
+
+  // Retrieve the ListSample
+  InputSampleListConstPointer inputSampleListPtr = inputPtr->Get();
+  LabelSampleListConstPointer labelSampleListPtr = labelPtr->Get();
+  OutputSampleListPointer outputSampleListPtr = const_cast<OutputSampleListType *>(outputPtr->Get());
+
+  // Clear any previous output
+  outputSampleListPtr->Clear();
+
+  typename InputSampleListType::ConstIterator inputIt = inputSampleListPtr->Begin();
+  typename LabelSampleListType::ConstIterator labelIt = labelSampleListPtr->Begin();
+
+  // Set-up progress reporting
+  itk::ProgressReporter progress(this,0,inputSampleListPtr->Size());
+
+  // Iterate on the InputSampleList
+  while(inputIt != inputSampleListPtr->End() && labelIt != labelSampleListPtr->End())
+    {
+    // Retrieve current input sample
+    InputMeasurementVectorType currentInputMeasurement = inputIt.GetMeasurementVector();
+    // Retrieve the current label
+    LabelMeasurementVectorType currentLabelMeasurement = labelIt.GetMeasurementVector();
+
+    // Build a temporary ListSample with the current
+    // measurement vector to generate noised versions of this
+    // measurement vector
+    InputSampleListPointer tempListSample = InputSampleListType::New();
+    tempListSample->PushBack(currentInputMeasurement);
+
+    // Get how many times we have to noise this sample
+    unsigned int iterations = m_MultiplicativeCoefficient[currentLabelMeasurement[0]];
+
+    // Noising filter : the mean and variance set on this filter are
+    // stored in m_AddGaussianNoiseFilter and must be forwarded, otherwise
+    // SetMean()/SetVariance() have no effect on the generated samples
+    GaussianAdditiveNoisePointerType noisingFilter = GaussianAdditiveNoiseType::New();
+    noisingFilter->SetMean(m_AddGaussianNoiseFilter->GetMean());
+    noisingFilter->SetVariance(m_AddGaussianNoiseFilter->GetVariance());
+    noisingFilter->SetInput(tempListSample);
+    noisingFilter->SetNumberOfIteration(iterations);
+    noisingFilter->Update();
+
+    // Build current output sample
+    OutputMeasurementVectorType currentOutputMeasurement;
+    currentOutputMeasurement.SetSize(currentInputMeasurement.GetSize());
+
+    // Cast the current sample in outputSampleValue
+    for(unsigned int idx = 0;idx < inputSampleListPtr->GetMeasurementVectorSize();++idx)
+      currentOutputMeasurement[idx] = static_cast<OutputValueType>(currentInputMeasurement[idx]);
+
+    // Add the current input casted sample to the output SampleList
+    outputSampleListPtr->PushBack(currentOutputMeasurement);
+
+    // Add the noised versions of the current sample to OutputSampleList
+    typename OutputSampleListType::ConstIterator tempIt = noisingFilter->GetOutput()->Get()->Begin();
+
+    while(tempIt != noisingFilter->GetOutput()->Get()->End())
+      {
+      // Get the noised sample of the current measurement vector
+      OutputMeasurementVectorType currentTempMeasurement = tempIt.GetMeasurementVector();
+      // Add to output SampleList
+      outputSampleListPtr->PushBack(currentTempMeasurement);
+      ++tempIt;
+      }
+
+    // Update progress
+    progress.CompletedPixel();
+
+    ++inputIt;
+    ++labelIt;
+    }
+}
+
+template < class TInputSampleList, class TLabelSampleList, class TOutputSampleList >
+void
+ListSampleToBalancedListSampleFilter<TInputSampleList,TLabelSampleList,TOutputSampleList>
+::PrintSelf(std::ostream& os, itk::Indent indent) const
+{
+  // Call superclass implementation
+  Superclass::PrintSelf(os,indent);
+}
+
+} // End namespace Statistics
+} // End namespace otb
+
+#endif
diff --git a/Testing/Code/Learning/CMakeLists.txt b/Testing/Code/Learning/CMakeLists.txt
index a2b7718c99da8e1173aa556c7a5885fd344fa8dd..31c62d6df081e43e7c22394753e6ec9edc2e16aa 100644
--- a/Testing/Code/Learning/CMakeLists.txt
+++ b/Testing/Code/Learning/CMakeLists.txt
@@ -603,6 +603,34 @@ ${TEMP}/leTvConcatenateSampleListFilterOutput.txt
  0
 -1
 )
+#ListSampleToBalancedListSampleFilterNew tests ----------
+ADD_TEST(leTuListSampleToBalancedListSampleFilterNew ${LEARNING_TESTS4}
+otbListSampleToBalancedListSampleFilterNew)
+
+ADD_TEST(leTvListSampleToBalancedListSampleFilter ${LEARNING_TESTS4}
+--compare-ascii ${NOTOL}
+${BASELINE_FILES}/leTvListSampleToBalancedListSampleFilterOutput.txt
+ ${TEMP}/leTvListSampleToBalancedListSampleFilterOutput.txt
+otbListSampleToBalancedListSampleFilter
+${TEMP}/leTvListSampleToBalancedListSampleFilterOutput.txt
+ 2
+-1 -3 0 # The third element is the label of the SampleList
+ 1 2 1
+-2 -5 0
+-1 -3 1
+ 0 -1 1
+-3 1 1
+-5 2 1
+ 2 1 1
+ 2 8 1
+ 1 -4 0
+-1 5 4
+ 2 5 1
+ 0 -5 0
+ 1 -1 2
+)
+
+
 # Testing srcs
 SET(BasicLearning_SRCS1
 otbLearningTests1.cxx
@@ -671,6 +699,7 @@ otbSVMValidation.cxx
 otbShiftScaleSampleListFilter.cxx
 otbGaussianAdditiveNoiseSampleListFilter.cxx
 otbConcatenateSampleListFilter.cxx
+otbListSampleToBalancedListSampleFilter.cxx
 )
 
 OTB_ADD_EXECUTABLE(otbLearningTests1 "${BasicLearning_SRCS1}" "OTBLearning;OTBIO;OTBTesting")
diff --git a/Testing/Code/Learning/otbLearningTests4.cxx b/Testing/Code/Learning/otbLearningTests4.cxx
index 9e6e85eb00e723db0d84cf3d9a0d7295eb9eb8f4..0ecc981927cfdfdefec21f21990246363dadcf18 100644
--- a/Testing/Code/Learning/otbLearningTests4.cxx
+++ b/Testing/Code/Learning/otbLearningTests4.cxx
@@ -43,4 +43,6 @@ void RegisterTests()
   REGISTER_TEST(otbGaussianAdditiveNoiseSampleListFilter);
   REGISTER_TEST(otbConcatenateSampleListFilterNew);
   REGISTER_TEST(otbConcatenateSampleListFilter);
+  REGISTER_TEST(otbListSampleToBalancedListSampleFilterNew);
+  REGISTER_TEST(otbListSampleToBalancedListSampleFilter);
 }
diff --git a/Testing/Code/Learning/otbListSampleToBalancedListSampleFilter.cxx b/Testing/Code/Learning/otbListSampleToBalancedListSampleFilter.cxx
new file mode 100644
index 0000000000000000000000000000000000000000..4f680a1810be99e27713c11b93fdc6ec86b15df6
--- /dev/null
+++ b/Testing/Code/Learning/otbListSampleToBalancedListSampleFilter.cxx
@@ -0,0 +1,120 @@
+/*=========================================================================
+
+  Program:   ORFEO Toolbox
+  Language:  C++
+  Date:      $Date$
+  Version:   $Revision$
+
+
+  Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
+  See OTBCopyright.txt for details.
+
+
+     This software is distributed WITHOUT ANY WARRANTY; without even
+     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+     PURPOSE.  See the above copyright notices for more information.
+
+=========================================================================*/
+
+#if defined(_MSC_VER)
+#pragma warning ( disable : 4786 )
+#endif
+
+#include "itkListSample.h"
+#include "otbListSampleToBalancedListSampleFilter.h"
+#include <cstdlib>
+#include <fstream>
+#include <iostream>
+
+typedef itk::VariableLengthVector<double> DoubleSampleType;
+typedef itk::Statistics::ListSample<DoubleSampleType> DoubleSampleListType;
+
+typedef itk::VariableLengthVector<unsigned int> IntegerSampleType;
+typedef itk::Statistics::ListSample<IntegerSampleType> IntegerSampleListType;
+
+typedef itk::VariableLengthVector<float> FloatSampleType;
+typedef itk::Statistics::ListSample<FloatSampleType> FloatSampleListType;
+
+typedef otb::Statistics::ListSampleToBalancedListSampleFilter
+<FloatSampleListType,IntegerSampleListType,DoubleSampleListType> BalancingFilterType;
+
+
+int otbListSampleToBalancedListSampleFilterNew(int argc, char * argv[])
+{
+  BalancingFilterType::Pointer filter = BalancingFilterType::New();
+  return EXIT_SUCCESS;
+}
+
+int otbListSampleToBalancedListSampleFilter(int argc, char * argv[])
+{
+  // Expected arguments : outfname sampleSize followed, for each sample,
+  // by its sampleSize values and its label
+  if (argc < 3)
+    {
+    std::cerr << "Usage: " << argv[0] << " outfname sampleSize [values... label]..." << std::endl;
+    return EXIT_FAILURE;
+    }
+
+  // Compute the number of samples
+  const char * outfname = argv[1];
+  unsigned int sampleSize = atoi(argv[2]);
+  unsigned int nbSamples = (argc-3)/(sampleSize+1); // +1 cause the
+                                                    // label is added
+                                                    // in the commandline
+
+  IntegerSampleListType::Pointer labelSampleList = IntegerSampleListType::New();
+  labelSampleList->SetMeasurementVectorSize(1);
+
+  FloatSampleListType::Pointer inputSampleList = FloatSampleListType::New();
+  inputSampleList->SetMeasurementVectorSize(sampleSize);
+
+  BalancingFilterType::Pointer filter = BalancingFilterType::New();
+  filter->SetInput(inputSampleList);
+  filter->SetInputLabel(labelSampleList);
+
+  // Input Sample
+  FloatSampleType sample(sampleSize);
+  IntegerSampleType label(1);
+
+  unsigned int index = 3;
+
+  std::ofstream ofs(outfname);
+
+  ofs<<"Sample size: "<<sampleSize<<std::endl;
+  ofs<<"Nb samples : "<<nbSamples<<std::endl;
+
+  // InputSampleList and LabelSampleList
+  for(unsigned int sampleId = 0; sampleId<nbSamples;++sampleId)
+    {
+    for(unsigned int i = 0; i<sampleSize;++i)
+      {
+      sample[i]=atof(argv[index]);
+      ++index;
+      }
+    // The label is an unsigned integer : parse it as an integer
+    label[0]= atoi(argv[index++]);
+
+    //std::cout<<sample<<std::endl;
+    //std::cout<<label<<std::endl;
+    ofs<<sample<<std::endl;
+    ofs<<label<<std::endl;
+    inputSampleList->PushBack(sample);
+    labelSampleList->PushBack(label);
+    }
+
+  filter->Update();
+
+  DoubleSampleListType::ConstIterator outIt = filter->GetOutputSampleList()->Begin();
+
+  ofs<<"Output samples: "<<std::endl;
+
+  while(outIt != filter->GetOutputSampleList()->End())
+    {
+    ofs<<outIt.GetMeasurementVector()<<std::endl;
+    ++outIt;
+    }
+
+  ofs.close();
+
+  return EXIT_SUCCESS;
+}