diff --git a/Code/Learning/otbListSampleGenerator.h b/Code/Learning/otbListSampleGenerator.h index 840f351caf3375917bb11c6fb4461f955f465eff..474186622a12141833acee5105f48729aba9e80b 100644 --- a/Code/Learning/otbListSampleGenerator.h +++ b/Code/Learning/otbListSampleGenerator.h @@ -117,11 +117,21 @@ public: itkGetStringMacro(ClassKey); itkSetStringMacro(ClassKey); + itkGetConstMacro(ClassMinSize, double); + itkGetObjectMacro(TrainingListSample, ListSampleType); itkGetObjectMacro(TrainingListLabel, ListLabelType); itkGetObjectMacro(ValidationListSample, ListSampleType); itkGetObjectMacro(ValidationListLabel, ListLabelType); + std::map<ClassLabelType, double> GetClassesSize() const + { + return m_ClassesSize; + } + + + void GenerateClassStatistics(); + protected: ListSampleGenerator(); virtual ~ListSampleGenerator() {} @@ -134,7 +144,6 @@ private: ListSampleGenerator(const Self&); //purposely not implemented void operator=(const Self&); //purposely not implemented - void GenerateClassStatistics(); void ComputeClassSelectionProbability(); long int m_MaxTrainingSize; // number of training samples (-1 = no limit) @@ -144,12 +153,13 @@ private: unsigned short m_NumberOfClasses; std::string m_ClassKey; + double m_ClassMinSize; ListSamplePointerType m_TrainingListSample; ListLabelPointerType m_TrainingListLabel; ListSamplePointerType m_ValidationListSample; ListLabelPointerType m_ValidationListLabel; - + std::map<ClassLabelType, double> m_ClassesSize; std::map<ClassLabelType, double> m_ClassesProbTraining; diff --git a/Code/Learning/otbListSampleGenerator.txx b/Code/Learning/otbListSampleGenerator.txx index 2c8ed69526cd138e45b3030e80e27bb477c94d87..d7848bc1163b01be7816f101f9d00bad108f4bda 100644 --- a/Code/Learning/otbListSampleGenerator.txx +++ b/Code/Learning/otbListSampleGenerator.txx @@ -32,7 +32,8 @@ ListSampleGenerator<TImage, TVectorData> m_MaxValidationSize(-1), m_ValidationTrainingProportion(0.0), m_NumberOfClasses(0), - m_ClassKey("Class") + m_ClassKey("Class"), + m_ClassMinSize(-1) { this->SetNumberOfRequiredInputs(2); this->SetNumberOfRequiredOutputs(1); @@ -214,8 +215,20 @@ ListSampleGenerator<TImage,TVectorData> ++itVector; } - m_NumberOfClasses = m_ClassesSize.size(); + std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesSize.begin(); + double minSize = itmap->second; + ++itmap; + while(itmap != m_ClassesSize.end()) + { + if (minSize > itmap->second) + { + minSize = itmap->second; + } + ++itmap; + } + m_ClassMinSize = minSize; + m_NumberOfClasses = m_ClassesSize.size(); } template < class TImage, class TVectorData > @@ -226,7 +239,7 @@ ListSampleGenerator<TImage,TVectorData> m_ClassesProbTraining.clear(); m_ClassesProbValidation.clear(); - //Go throught the classes size to find the smallest one + //Go through the classes size to find the smallest one double minSizeTraining = -1; for (std::map<ClassLabelType, double>::const_iterator itmap = m_ClassesSize.begin(); itmap != m_ClassesSize.end(); ++itmap) {