Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
otb
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
David Youssefi
otb
Commits
70d4ebe2
Commit
70d4ebe2
authored
13 years ago
by
Jonathan Guinet
Browse files
Options
Downloads
Plain Diff
MRG
parents
e50bc0fc
7304f4b5
No related branches found
Branches containing commit
No related tags found
Tags containing commit
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
Applications/Classification/otbKMeansClassification.cxx
+351
-0
351 additions, 0 deletions
Applications/Classification/otbKMeansClassification.cxx
with
351 additions
and
0 deletions
Applications/Classification/otbKMeansClassification.cxx
0 → 100644
+
351
−
0
View file @
70d4ebe2
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include
"otbWrapperApplication.h"
#include
"otbWrapperApplicationFactory.h"
#include
"otbVectorImage.h"
#include
"otbImage.h"
#include
"itkEuclideanDistance.h"
#include
"itkImageRegionSplitter.h"
#include
"otbStreamingTraits.h"
#include
"otbKMeansImageClassificationFilter.h"
#include
"itkImageRegionConstIterator.h"
#include
"itkListSample.h"
#include
"itkWeightedCentroidKdTreeGenerator.h"
#include
"itkKdTreeBasedKmeansEstimator.h"
#include
"itkMersenneTwisterRandomVariateGenerator.h"
#include
"itkCastImageFilter.h"
#include
"otbMultiToMonoChannelExtractROI.h"
namespace
otb
{
namespace
Wrapper
{
typedef
otb
::
Image
<
FloatVectorImageType
::
InternalPixelType
,
2
>
ImageReaderType
;
typedef
UInt8ImageType
LabeledImageType
;
typedef
ImageReaderType
::
PixelType
PixelType
;
typedef
itk
::
FixedArray
<
PixelType
,
108
>
SampleType
;
typedef
itk
::
Statistics
::
ListSample
<
SampleType
>
ListSampleType
;
typedef
itk
::
Statistics
::
WeightedCentroidKdTreeGenerator
<
ListSampleType
>
TreeGeneratorType
;
typedef
TreeGeneratorType
::
KdTreeType
TreeType
;
typedef
itk
::
Statistics
::
KdTreeBasedKmeansEstimator
<
TreeType
>
EstimatorType
;
typedef
itk
::
CastImageFilter
<
FloatImageListType
,
FloatImageType
>
CastMaskFilterType
;
typedef
otb
::
MultiToMonoChannelExtractROI
<
FloatVectorImageType
::
InternalPixelType
,
LabeledImageType
::
InternalPixelType
>
ExtractorType
;
typedef
otb
::
StreamingTraits
<
FloatVectorImageType
>
StreamingTraitsType
;
typedef
itk
::
ImageRegionSplitter
<
2
>
SplitterType
;
typedef
ImageReaderType
::
RegionType
RegionType
;
typedef
itk
::
ImageRegionConstIterator
<
FloatVectorImageType
>
IteratorType
;
typedef
itk
::
ImageRegionConstIterator
<
LabeledImageType
>
LabeledIteratorType
;
typedef
otb
::
KMeansImageClassificationFilter
<
FloatVectorImageType
,
LabeledImageType
,
108
>
ClassificationFilterType
;
class
KMeansClassification
:
public
Application
{
public:
/** Standard class typedefs. */
typedef
KMeansClassification
Self
;
typedef
Application
Superclass
;
typedef
itk
::
SmartPointer
<
Self
>
Pointer
;
typedef
itk
::
SmartPointer
<
const
Self
>
ConstPointer
;
/** Standard macro */
itkNewMacro
(
Self
);
itkTypeMacro
(
KMeansClassification
,
otb
::
Application
);
private:
KMeansClassification
()
{
SetName
(
"KMeansClassification"
);
SetDescription
(
"Unsupervised KMeans image classification."
);
}
virtual
~
KMeansClassification
()
{
}
void
DoCreateParameters
()
{
AddParameter
(
ParameterType_InputImage
,
"in"
,
"Input Image"
);
AddParameter
(
ParameterType_OutputImage
,
"out"
,
"Output Image"
);
AddParameter
(
ParameterType_InputImage
,
"vm"
,
"Validity Mask"
);
AddParameter
(
ParameterType_Int
,
"ts"
,
"Size of the training set"
);
SetParameterInt
(
"ts"
,
100
);
AddParameter
(
ParameterType_Float
,
"tp"
,
"Probability for a sample to be selected in the training set"
);
SetParameterFloat
(
"tp"
,
0.5
);
AddParameter
(
ParameterType_Int
,
"nc"
,
"Number of classes"
);
SetParameterInt
(
"nc"
,
3
);
AddParameter
(
ParameterType_Float
,
"cp"
,
"Probability for a pixel to be selected as an initial class centroid"
);
SetParameterFloat
(
"cp"
,
0.8
);
AddParameter
(
ParameterType_Int
,
"sl"
,
"Number of lines for each streaming block"
);
SetParameterInt
(
"sl"
,
1000
);
}
void
DoUpdateParameters
()
{
// Nothing to do here : all parameters are independent
}
void
DoExecute
()
{
GetLogger
()
->
Debug
(
"Entering DoExecute
\n
"
);
// initiating random number generation
itk
::
Statistics
::
MersenneTwisterRandomVariateGenerator
::
Pointer
randomGen
=
itk
::
Statistics
::
MersenneTwisterRandomVariateGenerator
::
New
();
m_InImage
=
GetParameterImage
(
"in"
);
m_Extractor
=
ExtractorType
::
New
();
m_Extractor
->
SetInput
(
GetParameterImage
(
"vm"
));
m_Extractor
->
SetChannel
(
1
);
m_Extractor
->
UpdateOutputInformation
();
LabeledImageType
::
Pointer
maskImage
=
m_Extractor
->
GetOutput
();
std
::
ostringstream
message
(
""
);
const
unsigned
int
nbsamples
=
GetParameterInt
(
"ts"
);
const
double
trainingProb
=
GetParameterFloat
(
"tp"
);
const
double
initProb
=
GetParameterFloat
(
"cp"
);
const
unsigned
int
nb_classes
=
GetParameterInt
(
"nc"
);
const
unsigned
int
nbLinesForStreaming
=
GetParameterInt
(
"sl"
);
/*******************************************/
/* Sampling data */
/*******************************************/
GetLogger
()
->
Info
(
"-- SAMPLING DATA --"
);
// Update input images information
m_InImage
->
UpdateOutputInformation
();
maskImage
->
UpdateOutputInformation
();
if
(
m_InImage
->
GetLargestPossibleRegion
()
!=
maskImage
->
GetLargestPossibleRegion
())
{
GetLogger
()
->
Error
(
"Mask image and input image have different sizes."
);
}
RegionType
largestRegion
=
m_InImage
->
GetLargestPossibleRegion
();
// Setting up local streaming capabilities
SplitterType
::
Pointer
splitter
=
SplitterType
::
New
();
unsigned
int
numberOfStreamDivisions
=
StreamingTraitsType
::
CalculateNumberOfStreamDivisions
(
m_InImage
,
largestRegion
,
splitter
,
otb
::
SET_BUFFER_NUMBER_OF_LINES
,
0
,
0
,
nbLinesForStreaming
);
message
.
clear
();
message
<<
"The images will be streamed into "
<<
numberOfStreamDivisions
<<
" parts."
;
GetLogger
()
->
Info
(
message
.
str
());
// Training sample lists
ListSampleType
::
Pointer
sampleList
=
ListSampleType
::
New
();
EstimatorType
::
ParametersType
initialMeans
(
108
*
nb_classes
);
initialMeans
.
Fill
(
0
);
unsigned
int
init_means_index
=
0
;
// Sample dimension and max dimension
unsigned
int
maxDimension
=
SampleType
::
Dimension
;
unsigned
int
sampleSize
=
std
::
min
(
m_InImage
->
GetNumberOfComponentsPerPixel
(),
maxDimension
);
unsigned
int
totalSamples
=
0
;
message
.
clear
();
message
<<
"Sample max possible dimension: "
<<
maxDimension
<<
std
::
endl
;
GetLogger
()
->
Info
(
message
.
str
());
message
.
clear
();
message
<<
"The following sample size will be used: "
<<
sampleSize
<<
std
::
endl
;
GetLogger
()
->
Info
(
message
.
str
());
// local streaming variables
unsigned
int
piece
=
0
;
RegionType
streamingRegion
;
while
((
totalSamples
<
nbsamples
)
&&
(
init_means_index
<
108
*
nb_classes
))
{
double
random
=
randomGen
->
GetVariateWithClosedRange
();
piece
=
static_cast
<
unsigned
int
>
(
random
*
numberOfStreamDivisions
);
streamingRegion
=
splitter
->
GetSplit
(
piece
,
numberOfStreamDivisions
,
largestRegion
);
message
.
clear
();
message
<<
"Processing region: "
<<
streamingRegion
<<
std
::
endl
;
GetLogger
()
->
Info
(
message
.
str
());
m_InImage
->
SetRequestedRegion
(
streamingRegion
);
m_InImage
->
PropagateRequestedRegion
();
m_InImage
->
UpdateOutputData
();
maskImage
->
SetRequestedRegion
(
streamingRegion
);
maskImage
->
PropagateRequestedRegion
();
maskImage
->
UpdateOutputData
();
IteratorType
it
(
m_InImage
,
streamingRegion
);
LabeledIteratorType
m_MaskIt
(
maskImage
,
streamingRegion
);
it
.
GoToBegin
();
m_MaskIt
.
GoToBegin
();
unsigned
int
localNbSamples
=
0
;
// Loop on the image
while
(
!
it
.
IsAtEnd
()
&&
!
m_MaskIt
.
IsAtEnd
()
&&
(
totalSamples
<
nbsamples
)
&&
(
init_means_index
<
(
108
*
nb_classes
)))
{
// If the current pixel is labeled
if
(
m_MaskIt
.
Get
()
>
0
)
{
if
((
rand
()
<
trainingProb
*
RAND_MAX
))
{
SampleType
newSample
;
// build the sample
newSample
.
Fill
(
0
);
for
(
unsigned
int
i
=
0
;
i
<
sampleSize
;
++
i
)
{
newSample
[
i
]
=
it
.
Get
()[
i
];
}
// Update the the sample lists
sampleList
->
PushBack
(
newSample
);
++
totalSamples
;
++
localNbSamples
;
}
else
if
((
init_means_index
<
108
*
nb_classes
)
&&
(
rand
()
<
initProb
*
RAND_MAX
))
{
for
(
unsigned
int
i
=
0
;
i
<
sampleSize
;
++
i
)
{
initialMeans
[
init_means_index
+
i
]
=
it
.
Get
()[
i
];
}
init_means_index
+=
108
;
}
}
++
it
;
++
m_MaskIt
;
}
message
.
clear
();
message
<<
localNbSamples
<<
" samples added to the training set."
<<
std
::
endl
;
GetLogger
()
->
Info
(
message
.
str
());
}
message
.
clear
();
message
<<
"The final training set contains "
<<
totalSamples
<<
" samples."
<<
std
::
endl
;
GetLogger
()
->
Info
(
message
.
str
());
message
.
clear
();
message
<<
"Data sampling completed."
<<
std
::
endl
;
GetLogger
()
->
Info
(
message
.
str
());
/*******************************************/
/* Learning */
/*******************************************/
message
.
clear
();
message
<<
"-- LEARNING --"
<<
std
::
endl
;
message
<<
"Initial centroids are: "
<<
std
::
endl
;
GetLogger
()
->
Info
(
message
.
str
());
message
.
clear
();
for
(
unsigned
int
i
=
0
;
i
<
nb_classes
;
++
i
)
{
message
<<
"Class "
<<
i
<<
": "
;
for
(
unsigned
int
j
=
0
;
j
<
sampleSize
;
++
j
)
{
message
<<
initialMeans
[
i
*
108
+
j
]
<<
"
\t
"
;
}
message
<<
std
::
endl
;
}
message
<<
std
::
endl
;
message
.
clear
();
message
<<
"Starting optimization."
<<
std
::
endl
;
message
<<
std
::
endl
;
GetLogger
()
->
Info
(
message
.
str
());
EstimatorType
::
Pointer
estimator
=
EstimatorType
::
New
();
TreeGeneratorType
::
Pointer
treeGenerator
=
TreeGeneratorType
::
New
();
treeGenerator
->
SetSample
(
sampleList
);
treeGenerator
->
SetBucketSize
(
100
);
treeGenerator
->
Update
();
estimator
->
SetParameters
(
initialMeans
);
estimator
->
SetKdTree
(
treeGenerator
->
GetOutput
());
estimator
->
SetMaximumIteration
(
100000000
);
estimator
->
SetCentroidPositionChangesThreshold
(
0.001
);
estimator
->
StartOptimization
();
EstimatorType
::
ParametersType
estimatedMeans
=
estimator
->
GetParameters
();
message
.
clear
();
message
<<
"Optimization completed."
<<
std
::
endl
;
message
<<
std
::
endl
;
message
<<
"Estimated centroids are: "
<<
std
::
endl
;
for
(
unsigned
int
i
=
0
;
i
<
nb_classes
;
++
i
)
{
message
<<
"Class "
<<
i
<<
": "
;
for
(
unsigned
int
j
=
0
;
j
<
sampleSize
;
++
j
)
{
message
<<
estimatedMeans
[
i
*
108
+
j
]
<<
"
\t
"
;
}
message
<<
std
::
endl
;
}
message
<<
std
::
endl
;
message
<<
"Learning completed."
<<
std
::
endl
;
message
<<
std
::
endl
;
GetLogger
()
->
Info
(
message
.
str
());
/*******************************************/
/* Classification */
/*******************************************/
message
.
clear
();
message
<<
"-- CLASSIFICATION --"
<<
std
::
endl
;
message
<<
std
::
endl
;
GetLogger
()
->
Info
(
message
.
str
());
m_Classifier
=
ClassificationFilterType
::
New
();
m_Classifier
->
SetInput
(
m_InImage
);
m_Classifier
->
SetInputMask
(
maskImage
);
m_Classifier
->
SetCentroids
(
estimator
->
GetParameters
());
SetParameterOutputImage
<
LabeledImageType
>
(
"out"
,
m_Classifier
->
GetOutput
());
}
ExtractorType
::
Pointer
m_Extractor
;
ClassificationFilterType
::
Pointer
m_Classifier
;
FloatVectorImageType
::
Pointer
m_InImage
;
};
}
}
OTB_APPLICATION_EXPORT
(
otb
::
Wrapper
::
KMeansClassification
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment