Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
otb
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Julien Cabieces
otb
Commits
7304f4b5
Commit
7304f4b5
authored
13 years ago
by
Jonathan Guinet
Browse files
Options
Downloads
Patches
Plain Diff
ENH: port to new framework
parent
cc10e722
No related branches found
No related tags found
No related merge requests found
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
Applications/Classification/otbKMeansClassification.cxx
+284
-228
284 additions, 228 deletions
Applications/Classification/otbKMeansClassification.cxx
with
284 additions
and
228 deletions
Applications/Classification/otbKMeansClassification.cxx
+
284
−
228
View file @
7304f4b5
/*=========================================================================
Program: ORFEO Toolbox
Language: C++
Date: $Date$
Version: $Revision$
Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
See OTBCopyright.txt for details.
This software is distributed WITHOUT ANY WARRANTY; without even
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
PURPOSE. See the above copyright notices for more information.
=========================================================================*/
#include
"otbWrapperApplication.h"
#include
"otbWrapperApplicationFactory.h"
#include
"otbVectorImage.h"
#include
"otbImage.h"
#include
"otbImageFileReader.h"
#include
"otbStreamingImageFileWriter.h"
#include
"otbImageFileWriter.h"
#include
"otbCommandLineArgumentParser.h"
#include
"itkEuclideanDistance.h"
#include
"itkImageRegionSplitter.h"
#include
"otbStreamingTraits.h"
...
...
@@ -13,283 +30,322 @@
#include
"itkWeightedCentroidKdTreeGenerator.h"
#include
"itkKdTreeBasedKmeansEstimator.h"
#include
"itkMersenneTwisterRandomVariateGenerator.h"
#include
"itkCastImageFilter.h"
#include
"otbMultiToMonoChannelExtractROI.h"
int
main
(
int
argc
,
char
*
argv
[])
namespace
otb
{
namespace
Wrapper
{
// Parse command line parameters
typedef
otb
::
CommandLineArgumentParser
ParserType
;
ParserType
::
Pointer
parser
=
ParserType
::
New
();
parser
->
SetProgramDescription
(
"Unsupervised KMeans image classification"
);
parser
->
AddInputImage
();
parser
->
AddOutputImage
();
parser
->
AddOption
(
"--ValidityMask"
,
"Validity mask"
,
"-vm"
,
1
,
true
);
parser
->
AddOption
(
"--MaxTrainingSetSize"
,
"Size of the training set"
,
"-ts"
,
1
,
true
);
parser
->
AddOption
(
"--TrainingSetProbability"
,
"Probability for a sample to be selected in the training set"
,
"-tp"
,
1
,
true
);
parser
->
AddOption
(
"--NumberOfClasses"
,
"Number of classes"
,
"-nc"
,
1
,
true
);
parser
->
AddOption
(
"--InitialCentroidProbability"
,
"Probability for a pixel to be selected as an initial class centroid"
,
"-cp"
,
1
,
true
);
parser
->
AddOption
(
"--StreamingNumberOfLines"
,
"Number of lines for each streaming block"
,
"-sl"
,
1
,
true
);
typedef
otb
::
CommandLineArgumentParseResult
ParserResultType
;
ParserResultType
::
Pointer
parseResult
=
ParserResultType
::
New
();
try
{
parser
->
ParseCommandLine
(
argc
,
argv
,
parseResult
);
}
catch
(
itk
::
ExceptionObject
&
err
)
{
std
::
string
descriptionException
=
err
.
GetDescription
();
if
(
descriptionException
.
find
(
"ParseCommandLine(): Help Parser"
)
!=
std
::
string
::
npos
)
{
return
EXIT_SUCCESS
;
}
if
(
descriptionException
.
find
(
"ParseCommandLine(): Version Parser"
)
!=
std
::
string
::
npos
)
{
return
EXIT_SUCCESS
;
}
return
EXIT_FAILURE
;
}
// initiating random number generation
itk
::
Statistics
::
MersenneTwisterRandomVariateGenerator
::
Pointer
randomGen
=
itk
::
Statistics
::
MersenneTwisterRandomVariateGenerator
::
New
();
typedef
otb
::
Image
<
FloatVectorImageType
::
InternalPixelType
,
2
>
ImageReaderType
;
std
::
string
infname
=
parseResult
->
GetInputImage
();
std
::
string
maskfname
=
parseResult
->
GetParameterString
(
"--ValidityMask"
,
0
);
std
::
string
outfname
=
parseResult
->
GetOutputImage
();
const
unsigned
int
nbsamples
=
parseResult
->
GetParameterUInt
(
"--MaxTrainingSetSize"
);
const
double
trainingProb
=
parseResult
->
GetParameterDouble
(
"--TrainingSetProbability"
);
const
double
initProb
=
parseResult
->
GetParameterDouble
(
"--InitialCentroidProbability"
);
const
unsigned
int
nbLinesForStreaming
=
parseResult
->
GetParameterUInt
(
"--StreamingNumberOfLines"
);
const
unsigned
int
nb_classes
=
parseResult
->
GetParameterUInt
(
"--NumberOfClasses"
);
typedef
UInt8ImageType
LabeledImageType
;
typedef
ImageReaderType
::
PixelType
PixelType
;
typedef
unsigned
short
PixelType
;
typedef
unsigned
short
LabeledPixelType
;
typedef
itk
::
FixedArray
<
PixelType
,
108
>
SampleType
;
typedef
itk
::
Statistics
::
ListSample
<
SampleType
>
ListSampleType
;
typedef
itk
::
Statistics
::
WeightedCentroidKdTreeGenerator
<
ListSampleType
>
TreeGeneratorType
;
typedef
TreeGeneratorType
::
KdTreeType
TreeType
;
typedef
itk
::
Statistics
::
KdTreeBasedKmeansEstimator
<
TreeType
>
EstimatorType
;
typedef
itk
::
CastImageFilter
<
FloatImageListType
,
FloatImageType
>
CastMaskFilterType
;
typedef
otb
::
MultiToMonoChannelExtractROI
<
FloatVectorImageType
::
InternalPixelType
,
LabeledImageType
::
InternalPixelType
>
ExtractorType
;
typedef
otb
::
VectorImage
<
PixelType
,
2
>
ImageType
;
typedef
otb
::
Image
<
LabeledPixelType
,
2
>
LabeledImageType
;
typedef
otb
::
ImageFileReader
<
ImageType
>
ImageReaderType
;
typedef
otb
::
ImageFileReader
<
LabeledImageType
>
LabeledImageReaderType
;
typedef
otb
::
StreamingImageFileWriter
<
LabeledImageType
>
WriterType
;
typedef
otb
::
StreamingTraits
<
FloatVectorImageType
>
StreamingTraitsType
;
typedef
itk
::
ImageRegionSplitter
<
2
>
SplitterType
;
typedef
ImageReaderType
::
RegionType
RegionType
;
typedef
itk
::
FixedArray
<
PixelType
,
108
>
SampleType
;
typedef
itk
::
Statistics
::
ListSample
<
SampleType
>
ListSampleType
;
typedef
itk
::
Statistics
::
WeightedCentroidKdTreeGenerator
<
ListSampleType
>
TreeGeneratorType
;
typedef
TreeGeneratorType
::
KdTreeType
TreeType
;
typedef
itk
::
Statistics
::
KdTreeBasedKmeansEstimator
<
TreeType
>
EstimatorType
;
typedef
itk
::
ImageRegionConstIterator
<
FloatVectorImageType
>
IteratorType
;
typedef
itk
::
ImageRegionConstIterator
<
LabeledImageType
>
LabeledIteratorType
;
typedef
otb
::
StreamingTraits
<
ImageType
>
StreamingTraitsType
;
typedef
itk
::
ImageRegionSplitter
<
2
>
SplitterType
;
typedef
ImageType
::
RegionType
RegionType
;
typedef
otb
::
KMeansImageClassificationFilter
<
FloatVectorImageType
,
LabeledImageType
,
108
>
ClassificationFilterType
;
typedef
itk
::
ImageRegionConstIterator
<
ImageType
>
IteratorType
;
typedef
itk
::
ImageRegionConstIterator
<
LabeledImageType
>
LabeledIteratorType
;
class
KMeansClassification
:
public
Application
{
public:
/** Standard class typedefs. */
typedef
KMeansClassification
Self
;
typedef
Application
Superclass
;
typedef
itk
::
SmartPointer
<
Self
>
Pointer
;
typedef
itk
::
SmartPointer
<
const
Self
>
ConstPointer
;
typedef
otb
::
KMeansImageClassificationFilter
<
ImageType
,
LabeledImageType
,
108
>
ClassificationFilterType
;
/** Standard macro */
itkNewMacro
(
Self
);
itkTypeMacro
(
KMeansClassification
,
otb
::
Application
);
ImageReaderType
::
Pointer
reader
=
ImageReaderType
::
New
();
LabeledImageReaderType
::
Pointer
maskReader
=
LabeledImageReaderType
::
New
();
private:
KMeansClassification
()
{
SetName
(
"KMeansClassification"
);
SetDescription
(
"Unsupervised KMeans image classification."
);
}
reader
->
SetFileName
(
infname
);
maskReader
->
SetFileName
(
maskfname
);
virtual
~
KMeansClassification
()
{
}
/*******************************************/
/* Sampling data */
/*******************************************/
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
"-- SAMPLING DATA --"
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
void
DoCreateParameters
()
{
// Update input images information
reader
->
GenerateOutputInformation
();
maskReader
->
GenerateOutputInformation
();
AddParameter
(
ParameterType_InputImage
,
"in"
,
"Input Image"
);
AddParameter
(
ParameterType_OutputImage
,
"out"
,
"Output Image"
);
AddParameter
(
ParameterType_InputImage
,
"vm"
,
"Validity Mask"
);
AddParameter
(
ParameterType_Int
,
"ts"
,
"Size of the training set"
);
SetParameterInt
(
"ts"
,
100
);
AddParameter
(
ParameterType_Float
,
"tp"
,
"Probability for a sample to be selected in the training set"
);
SetParameterFloat
(
"tp"
,
0.5
);
AddParameter
(
ParameterType_Int
,
"nc"
,
"Number of classes"
);
SetParameterInt
(
"nc"
,
3
);
AddParameter
(
ParameterType_Float
,
"cp"
,
"Probability for a pixel to be selected as an initial class centroid"
);
SetParameterFloat
(
"cp"
,
0.8
);
AddParameter
(
ParameterType_Int
,
"sl"
,
"Number of lines for each streaming block"
);
SetParameterInt
(
"sl"
,
1000
);
}
if
(
reader
->
GetOutput
()
->
GetLargestPossibleRegion
()
!=
maskReader
->
GetOutput
()
->
GetLargestPossibleRegion
()
)
void
DoUpdateParameters
()
{
std
::
cerr
<<
"Mask image and input image have different sizes."
<<
std
::
endl
;
return
EXIT_FAILURE
;
// Nothing to do here : all parameters are independent
}
RegionType
largestRegion
=
reader
->
GetOutput
()
->
GetLargestPossibleRegion
();
// Setting up local streaming capabilities
SplitterType
::
Pointer
splitter
=
SplitterType
::
New
();
unsigned
int
numberOfStreamDivisions
=
StreamingTraitsType
::
CalculateNumberOfStreamDivisions
(
reader
->
GetOutput
(),
largestRegion
,
splitter
,
otb
::
SET_BUFFER_NUMBER_OF_LINES
,
0
,
0
,
nbLinesForStreaming
);
std
::
cout
<<
"The images will be streamed into "
<<
numberOfStreamDivisions
<<
" parts."
<<
std
::
endl
;
// Training sample lists
ListSampleType
::
Pointer
sampleList
=
ListSampleType
::
New
();
EstimatorType
::
ParametersType
initialMeans
(
108
*
nb_classes
);
initialMeans
.
Fill
(
0
);
unsigned
int
init_means_index
=
0
;
// Sample dimension and max dimension
unsigned
int
maxDimension
=
SampleType
::
Dimension
;
unsigned
int
sampleSize
=
std
::
min
(
reader
->
GetOutput
()
->
GetNumberOfComponentsPerPixel
(),
maxDimension
);
unsigned
int
totalSamples
=
0
;
std
::
cout
<<
"Sample max possible dimension: "
<<
maxDimension
<<
std
::
endl
;
std
::
cout
<<
"The following sample size will be used: "
<<
sampleSize
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
// local streaming variables
unsigned
int
piece
=
0
;
RegionType
streamingRegion
;
while
((
totalSamples
<
nbsamples
)
&&
(
init_means_index
<
108
*
nb_classes
))
void
DoExecute
()
{
double
random
=
randomGen
->
GetVariateWithClosedRange
();
piece
=
static_cast
<
unsigned
int
>
(
random
*
numberOfStreamDivisions
);
GetLogger
()
->
Debug
(
"Entering DoExecute
\n
"
);
// initiating random number generation
itk
::
Statistics
::
MersenneTwisterRandomVariateGenerator
::
Pointer
randomGen
=
itk
::
Statistics
::
MersenneTwisterRandomVariateGenerator
::
New
();
m_InImage
=
GetParameterImage
(
"in"
);
m_Extractor
=
ExtractorType
::
New
();
m_Extractor
->
SetInput
(
GetParameterImage
(
"vm"
));
m_Extractor
->
SetChannel
(
1
);
m_Extractor
->
UpdateOutputInformation
();
LabeledImageType
::
Pointer
maskImage
=
m_Extractor
->
GetOutput
();
std
::
ostringstream
message
(
""
);
const
unsigned
int
nbsamples
=
GetParameterInt
(
"ts"
);
const
double
trainingProb
=
GetParameterFloat
(
"tp"
);
const
double
initProb
=
GetParameterFloat
(
"cp"
);
const
unsigned
int
nb_classes
=
GetParameterInt
(
"nc"
);
const
unsigned
int
nbLinesForStreaming
=
GetParameterInt
(
"sl"
);
/*******************************************/
/* Sampling data */
/*******************************************/
GetLogger
()
->
Info
(
"-- SAMPLING DATA --"
);
// Update input images information
m_InImage
->
UpdateOutputInformation
();
maskImage
->
UpdateOutputInformation
();
if
(
m_InImage
->
GetLargestPossibleRegion
()
!=
maskImage
->
GetLargestPossibleRegion
())
{
GetLogger
()
->
Error
(
"Mask image and input image have different sizes."
);
}
streamingRegion
=
splitter
->
GetSplit
(
piece
,
numberOfStreamDivisions
,
largestRegion
);
RegionType
largestRegion
=
m_InImage
->
GetLargestPossibleRegion
();
// Setting up local streaming capabilities
SplitterType
::
Pointer
splitter
=
SplitterType
::
New
();
unsigned
int
numberOfStreamDivisions
=
StreamingTraitsType
::
CalculateNumberOfStreamDivisions
(
m_InImage
,
largestRegion
,
splitter
,
otb
::
SET_BUFFER_NUMBER_OF_LINES
,
0
,
0
,
nbLinesForStreaming
);
message
.
clear
();
message
<<
"The images will be streamed into "
<<
numberOfStreamDivisions
<<
" parts."
;
GetLogger
()
->
Info
(
message
.
str
());
// Training sample lists
ListSampleType
::
Pointer
sampleList
=
ListSampleType
::
New
();
EstimatorType
::
ParametersType
initialMeans
(
108
*
nb_classes
);
initialMeans
.
Fill
(
0
);
unsigned
int
init_means_index
=
0
;
// Sample dimension and max dimension
unsigned
int
maxDimension
=
SampleType
::
Dimension
;
unsigned
int
sampleSize
=
std
::
min
(
m_InImage
->
GetNumberOfComponentsPerPixel
(),
maxDimension
);
unsigned
int
totalSamples
=
0
;
message
.
clear
();
message
<<
"Sample max possible dimension: "
<<
maxDimension
<<
std
::
endl
;
GetLogger
()
->
Info
(
message
.
str
());
message
.
clear
();
message
<<
"The following sample size will be used: "
<<
sampleSize
<<
std
::
endl
;
GetLogger
()
->
Info
(
message
.
str
());
// local streaming variables
unsigned
int
piece
=
0
;
RegionType
streamingRegion
;
while
((
totalSamples
<
nbsamples
)
&&
(
init_means_index
<
108
*
nb_classes
))
{
double
random
=
randomGen
->
GetVariateWithClosedRange
();
piece
=
static_cast
<
unsigned
int
>
(
random
*
numberOfStreamDivisions
);
std
::
cout
<<
"Process
ing
r
egion
: "
<<
streamingRegion
<<
std
::
endl
;
stream
ing
R
egion
=
splitter
->
GetSplit
(
piece
,
numberOfStreamDivisions
,
largestRegion
)
;
reader
->
GetOutput
()
->
SetRequestedRegion
(
streamingRegion
);
reader
->
GetOutput
()
->
PropagateRequestedRegion
()
;
reader
->
GetOutput
()
->
UpdateOutputData
(
);
message
.
clear
(
);
message
<<
"Processing region: "
<<
streamingRegion
<<
std
::
endl
;
GetLogger
()
->
Info
(
message
.
str
()
);
maskReader
->
GetOutput
()
->
SetRequestedRegion
(
streamingRegion
);
maskReader
->
GetOutput
()
->
PropagateRequestedRegion
();
maskReader
->
GetOutput
()
->
UpdateOutputData
();
m_InImage
->
SetRequestedRegion
(
streamingRegion
);
m_InImage
->
PropagateRequestedRegion
();
m_InImage
->
UpdateOutputData
();
IteratorType
it
(
reader
->
GetOutput
(),
streamingRegion
);
LabeledIteratorType
maskIt
(
maskReader
->
GetOutput
(),
streamingRegion
);
maskImage
->
SetRequestedRegion
(
streamingRegion
);
maskImage
->
PropagateRequestedRegion
();
maskImage
->
UpdateOutputData
();
it
.
GoToB
egin
(
);
maskIt
.
GoToB
egin
(
);
IteratorType
it
(
m_InImage
,
streamingR
egi
o
n
);
LabeledIteratorType
m_MaskIt
(
maskImage
,
streamingR
egi
o
n
);
unsigned
int
localNbSamples
=
0
;
it
.
GoToBegin
();
m_MaskIt
.
GoToBegin
();
// Loop on the image
while
(
!
it
.
IsAtEnd
()
&&!
maskIt
.
IsAtEnd
()
&&
(
totalSamples
<
nbsamples
)
&&
(
init_means_index
<
108
*
nb_classes
))
{
// If the current pixel is labeled
if
(
maskIt
.
Get
()
>
0
)
{
if
((
rand
()
<
trainingProb
*
RAND_MAX
))
{
SampleType
newSample
;
unsigned
int
localNbSamples
=
0
;
// build the sample
newSample
.
Fill
(
0
);
for
(
unsigned
int
i
=
0
;
i
<
sampleSize
;
++
i
)
// Loop on the image
while
(
!
it
.
IsAtEnd
()
&&
!
m_MaskIt
.
IsAtEnd
()
&&
(
totalSamples
<
nbsamples
)
&&
(
init_means_index
<
(
108
*
nb_classes
)))
{
// If the current pixel is labeled
if
(
m_MaskIt
.
Get
()
>
0
)
{
newSample
[
i
]
=
it
.
Get
()[
i
];
if
((
rand
()
<
trainingProb
*
RAND_MAX
))
{
SampleType
newSample
;
// build the sample
newSample
.
Fill
(
0
);
for
(
unsigned
int
i
=
0
;
i
<
sampleSize
;
++
i
)
{
newSample
[
i
]
=
it
.
Get
()[
i
];
}
// Update the the sample lists
sampleList
->
PushBack
(
newSample
);
++
totalSamples
;
++
localNbSamples
;
}
else
if
((
init_means_index
<
108
*
nb_classes
)
&&
(
rand
()
<
initProb
*
RAND_MAX
))
{
for
(
unsigned
int
i
=
0
;
i
<
sampleSize
;
++
i
)
{
initialMeans
[
init_means_index
+
i
]
=
it
.
Get
()[
i
];
}
init_means_index
+=
108
;
}
}
// Update the the sample lists
sampleList
->
PushBack
(
newSample
);
++
totalSamples
;
++
localNbSamples
;
++
it
;
++
m_MaskIt
;
}
else
if
((
init_means_index
<
108
*
nb_classes
)
&&
(
rand
()
<
initProb
*
RAND_MAX
))
message
.
clear
();
message
<<
localNbSamples
<<
" samples added to the training set."
<<
std
::
endl
;
GetLogger
()
->
Info
(
message
.
str
());
}
message
.
clear
();
message
<<
"The final training set contains "
<<
totalSamples
<<
" samples."
<<
std
::
endl
;
GetLogger
()
->
Info
(
message
.
str
());
message
.
clear
();
message
<<
"Data sampling completed."
<<
std
::
endl
;
GetLogger
()
->
Info
(
message
.
str
());
/*******************************************/
/* Learning */
/*******************************************/
message
.
clear
();
message
<<
"-- LEARNING --"
<<
std
::
endl
;
message
<<
"Initial centroids are: "
<<
std
::
endl
;
GetLogger
()
->
Info
(
message
.
str
());
message
.
clear
();
for
(
unsigned
int
i
=
0
;
i
<
nb_classes
;
++
i
)
{
message
<<
"Class "
<<
i
<<
": "
;
for
(
unsigned
int
j
=
0
;
j
<
sampleSize
;
++
j
)
{
for
(
unsigned
int
i
=
0
;
i
<
sampleSize
;
++
i
)
{
initialMeans
[
init_means_index
+
i
]
=
it
.
Get
()[
i
];
}
init_means_index
+=
108
;
message
<<
initialMeans
[
i
*
108
+
j
]
<<
"
\t
"
;
}
message
<<
std
::
endl
;
}
++
it
;
++
maskIt
;
}
std
::
cout
<<
localNbSamples
<<
" samples added to the training set."
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
}
message
<<
std
::
endl
;
std
::
cout
<<
"The final training set contains "
<<
totalSamples
<<
" samples."
<<
std
::
endl
;
message
.
clear
();
message
<<
"Starting optimization."
<<
std
::
endl
;
message
<<
std
::
endl
;
GetLogger
()
->
Info
(
message
.
str
());
EstimatorType
::
Pointer
estimator
=
EstimatorType
::
New
();
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
"Data sampling completed."
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
TreeGeneratorType
::
Pointer
treeGenerator
=
TreeGeneratorType
::
New
();
treeGenerator
->
SetSample
(
sampleList
);
treeGenerator
->
SetBucketSize
(
100
);
treeGenerator
->
Update
();
/*******************************************/
/* Learning */
/*******************************************/
estimator
->
SetParameters
(
initialMeans
);
estimator
->
SetKdTree
(
treeGenerator
->
GetOutput
());
estimator
->
SetMaximumIteration
(
100000000
);
estimator
->
SetCentroidPositionChangesThreshold
(
0.001
);
estimator
->
StartOptimization
();
std
::
cout
<<
"-- LEARNING --"
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
EstimatorType
::
ParametersType
estimatedMeans
=
estimator
->
GetParameters
();
message
.
clear
();
message
<<
"Optimization completed."
<<
std
::
endl
;
message
<<
std
::
endl
;
message
<<
"Estimated centroids are: "
<<
std
::
endl
;
std
::
cout
<<
"Initial centroids are: "
<<
std
::
endl
;
for
(
unsigned
int
i
=
0
;
i
<
nb_classes
;
++
i
)
{
message
<<
"Class "
<<
i
<<
": "
;
for
(
unsigned
int
j
=
0
;
j
<
sampleSize
;
++
j
)
{
message
<<
estimatedMeans
[
i
*
108
+
j
]
<<
"
\t
"
;
}
message
<<
std
::
endl
;
}
for
(
unsigned
int
i
=
0
;
i
<
nb_classes
;
++
i
)
{
std
::
cout
<<
"Class "
<<
i
<<
": "
;
for
(
unsigned
int
j
=
0
;
j
<
sampleSize
;
++
j
)
{
std
::
cout
<<
initialMeans
[
i
*
108
+
j
]
<<
"
\t
"
;
}
std
::
cout
<<
std
::
endl
;
}
std
::
cout
<<
std
::
endl
;
message
<<
std
::
endl
;
message
<<
"Learning completed."
<<
std
::
endl
;
message
<<
std
::
endl
;
GetLogger
()
->
Info
(
message
.
str
());
std
::
cout
<<
"Starting optimization."
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
EstimatorType
::
Pointer
estimator
=
EstimatorType
::
New
();
/*******************************************/
/* Classification */
/*******************************************/
message
.
clear
();
message
<<
"-- CLASSIFICATION --"
<<
std
::
endl
;
message
<<
std
::
endl
;
GetLogger
()
->
Info
(
message
.
str
());
TreeGeneratorType
::
Pointer
treeGenerator
=
TreeGeneratorType
::
New
();
treeGenerator
->
SetSample
(
sampleList
);
treeGenerator
->
SetBucketSize
(
100
);
treeGenerator
->
Update
();
m_Classifier
=
ClassificationFilterType
::
New
();
estimator
->
SetParameters
(
initialMeans
);
estimator
->
SetKdTree
(
treeGenerator
->
GetOutput
());
estimator
->
SetMaximumIteration
(
100000000
);
estimator
->
SetCentroidPositionChangesThreshold
(
0.001
);
estimator
->
StartOptimization
();
m_Classifier
->
SetInput
(
m_InImage
);
m_Classifier
->
SetInputMask
(
maskImage
);
EstimatorType
::
ParametersType
estimatedMeans
=
estimator
->
GetParameters
();
m_Classifier
->
SetCentroids
(
estimator
->
GetParameters
()
)
;
std
::
cout
<<
"Optimization completed."
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
"Estimated centroids are: "
<<
std
::
endl
;
SetParameterOutputImage
<
LabeledImageType
>
(
"out"
,
m_Classifier
->
GetOutput
());
for
(
unsigned
int
i
=
0
;
i
<
nb_classes
;
++
i
)
{
std
::
cout
<<
"Class "
<<
i
<<
": "
;
for
(
unsigned
int
j
=
0
;
j
<
sampleSize
;
++
j
)
{
std
::
cout
<<
estimatedMeans
[
i
*
108
+
j
]
<<
"
\t
"
;
}
std
::
cout
<<
std
::
endl
;
}
std
::
cout
<<
std
::
endl
;
std
::
cout
<<
"Learning completed."
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
ExtractorType
::
Pointer
m_Extractor
;
ClassificationFilterType
::
Pointer
m_Classifier
;
FloatVectorImageType
::
Pointer
m_InImage
;
};
/*******************************************/
/* Classification */
/*******************************************/
}
}
std
::
cout
<<
"-- CLASSIFICATION --"
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
ClassificationFilterType
::
Pointer
classifier
=
ClassificationFilterType
::
New
();
classifier
->
SetInput
(
reader
->
GetOutput
());
classifier
->
SetInputMask
(
maskReader
->
GetOutput
());
classifier
->
SetCentroids
(
estimator
->
GetParameters
());
OTB_APPLICATION_EXPORT
(
otb
::
Wrapper
::
KMeansClassification
)
WriterType
::
Pointer
writer
=
WriterType
::
New
();
writer
->
SetFileName
(
outfname
);
writer
->
SetInput
(
classifier
->
GetOutput
());
writer
->
SetNumberOfDivisionsStrippedStreaming
(
numberOfStreamDivisions
);
writer
->
Update
();
std
::
cout
<<
"Classification completed."
<<
std
::
endl
;
std
::
cout
<<
std
::
endl
;
return
EXIT_SUCCESS
;
}
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment