otbKMeansClassification.cxx 16.9 KB
Newer Older
Jonathan Guinet's avatar
Jonathan Guinet committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
/*=========================================================================

 Program:   ORFEO Toolbox
 Language:  C++
 Date:      $Date$
 Version:   $Revision$


 Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
 See OTBCopyright.txt for details.


 This software is distributed WITHOUT ANY WARRANTY; without even
 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
 PURPOSE.  See the above copyright notices for more information.

 =========================================================================*/

#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"

#include <limits>

#include "otbVectorImage.h"
#include "otbStreamingTraits.h"
#include "itkImageRegionConstIterator.h"
#include "itkListSample.h"
#include "itkWeightedCentroidKdTreeGenerator.h"
#include "itkKdTreeBasedKmeansEstimator.h"
#include "otbStreamingShrinkImageFilter.h"
#include "otbChangeLabelImageFilter.h"
#include "otbRAMDrivenStrippedStreamingManager.h"

#include "itkLabelToRGBImageFilter.h"
#include "otbReliefColormapFunctor.h"
#include "itkScalarToRGBColormapImageFilter.h"

37

Jonathan Guinet's avatar
Jonathan Guinet committed
38
39
namespace otb
{
40
41


Jonathan Guinet's avatar
Jonathan Guinet committed
42
namespace Wrapper
43
44
{

45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
namespace Functor
{
template <class TSample, class TLabel> class KMeansFunctor
{
public:
  /** operator */
  TLabel operator ()(const TSample& sample) const
  {
    typename CentroidMapType::const_iterator it = m_CentroidsMap.begin();

    if (it == m_CentroidsMap.end())
      {
      return 0;
      }

    TLabel resp = it->first;
    double minDist = m_Distance->Evaluate(sample, it->second);
    ++it;

    while (it != m_CentroidsMap.end())
      {
      double dist = m_Distance->Evaluate(sample, it->second);

      if (dist < minDist)
        {
        resp = it->first;
        minDist = dist;
        }
      ++it;
      }
    return resp;
  }

  /** Add a new centroid */
  void AddCentroid(const TLabel& label, const TSample& centroid)
  {
    m_CentroidsMap[label] = centroid;
  }

  /** Constructor */
  KMeansFunctor() : m_CentroidsMap(), m_Distance()
  {
    m_Distance = DistanceType::New();
  }

  bool operator !=(const KMeansFunctor& other) const
  {
    return m_CentroidsMap != other.m_CentroidsMap;
  }

private:
  typedef std::map<TLabel, TSample>                   CentroidMapType;
97
  typedef itk::Statistics::EuclideanDistanceMetric<TSample> DistanceType;
98
99
100
101
102
103
104

  CentroidMapType m_CentroidsMap;
  typename DistanceType::Pointer m_Distance;
};
}


105
typedef FloatImageType::PixelType PixelType;
106
typedef UInt16ImageType   LabeledImageType;
107

108
typedef UInt16VectorImageType       VectorImageType;
109
110
111
112
113
typedef VectorImageType::PixelType  VectorPixelType;
typedef UInt8RGBImageType           RGBImageType;
typedef RGBImageType::PixelType     RGBPixelType;


114
115
116
117
typedef LabeledImageType::PixelType LabelType;


typedef FloatVectorImageType::PixelType                               SampleType;
Jonathan Guinet's avatar
Jonathan Guinet committed
118
119
120
121
typedef itk::Statistics::ListSample<SampleType> ListSampleType;
typedef itk::Statistics::WeightedCentroidKdTreeGenerator<ListSampleType> TreeGeneratorType;
typedef TreeGeneratorType::KdTreeType TreeType;
typedef itk::Statistics::KdTreeBasedKmeansEstimator<TreeType> EstimatorType;
122
typedef RAMDrivenStrippedStreamingManager<FloatVectorImageType> RAMDrivenStrippedStreamingManagerType;
123
124


Jonathan Guinet's avatar
Jonathan Guinet committed
125
typedef itk::ImageRegionConstIterator<FloatVectorImageType> IteratorType;
126
typedef itk::ImageRegionConstIterator<UInt8ImageType> LabeledIteratorType;
127
128
129
130

typedef otb::StreamingShrinkImageFilter<FloatVectorImageType,
     FloatVectorImageType>              ImageSamplingFilterType;

131
132
typedef otb::StreamingShrinkImageFilter<UInt8ImageType,
    UInt8ImageType>              MaskSamplingFilterType;
133
134
135
136
typedef Functor::KMeansFunctor<SampleType, LabelType> KMeansFunctorType;
typedef itk::UnaryFunctorImageFilter<FloatVectorImageType,
    LabeledImageType, KMeansFunctorType>     KMeansFilterType;

137

138
139
140
141
142
143
144
145
146
147
148
149
150
151
// Manual label LUT
 typedef otb::ChangeLabelImageFilter
 <LabeledImageType, VectorImageType>    ChangeLabelFilterType;

 // Continuous LUT mapping
  typedef itk::ScalarToRGBColormapImageFilter<LabeledImageType, RGBImageType>      ColorMapFilterType;


  typedef otb::Functor::ReliefColormapFunctor
   <LabelType, RGBPixelType>           ReliefColorMapFunctorType;

  typedef otb::ImageMetadataInterfaceBase ImageMetadataInterfaceType;


Jonathan Guinet's avatar
Jonathan Guinet committed
152
153
154
155
156
157
158
159
class KMeansClassification: public Application
{
public:
  /** Standard class typedefs. */
  typedef KMeansClassification Self;
  typedef Application Superclass;
  typedef itk::SmartPointer<Self> Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;
160

Jonathan Guinet's avatar
Jonathan Guinet committed
161
162
  /** Standard macro */
  itkNewMacro(Self);
163

Jonathan Guinet's avatar
Jonathan Guinet committed
164
  itkTypeMacro(KMeansClassification, otb::Application);
165

Jonathan Guinet's avatar
Jonathan Guinet committed
166
private:
167
  /**
   * Declare the application: name, documentation, tags, parameters
   * (with their defaults and mandatory status) and the documentation
   * example.
   */
  void DoInit() ITK_OVERRIDE
  {
    // --- Identity and documentation ---
    SetName("KMeansClassification");
    SetDescription("Unsupervised KMeans image classification");

    SetDocName("Unsupervised KMeans image classification");
    SetDocLongDescription("Performs unsupervised KMeans image classification.");
    SetDocAuthors("OTB-Team");
    SetDocLimitations("None");
    SetDocSeeAlso(" ");

    AddDocTag(Tags::Learning);
    AddDocTag(Tags::Segmentation);

    // --- Input / output images ---
    AddParameter(ParameterType_InputImage, "in", "Input Image");
    SetParameterDescription("in", "Input image to classify.");

    AddParameter(ParameterType_OutputImage, "out", "Output Image");
    SetParameterDescription("out", "Output image containing the class indexes.");
    SetDefaultOutputPixelType("out",ImagePixelType_uint8);

    AddRAMParameter();

    // --- Optional validity mask ---
    AddParameter(ParameterType_InputImage, "vm", "Validity Mask");
    SetParameterDescription("vm", "Validity mask. Only non-zero pixels will be used to estimate KMeans modes.");
    MandatoryOff("vm");

    // --- Training set size ---
    AddParameter(ParameterType_Int, "ts", "Training set size");
    SetDefaultParameterInt("ts", 100);
    SetParameterDescription("ts", "Size of the training set (in pixels).");
    MandatoryOff("ts");

    // --- Number of classes ---
    AddParameter(ParameterType_Int, "nc", "Number of classes");
    SetDefaultParameterInt("nc", 5);
    SetParameterDescription("nc", "Number of modes, which will be used to generate class membership.");

    // --- Learning stop criteria ---
    AddParameter(ParameterType_Int, "maxit", "Maximum number of iterations");
    SetDefaultParameterInt("maxit", 1000);
    SetParameterDescription("maxit", "Maximum number of iterations for the learning step.");
    MandatoryOff("maxit");

    AddParameter(ParameterType_Float, "ct", "Convergence threshold");
    SetDefaultParameterFloat("ct", 0.0001);
    SetParameterDescription("ct", "Convergence threshold for class centroid  (L2 distance, by default 0.0001).");
    MandatoryOff("ct");

    // --- Optional centroid text file ---
    AddParameter(ParameterType_OutputFilename, "outmeans", "Centroid filename");
    SetParameterDescription("outmeans", "Output text file containing centroid positions");
    MandatoryOff("outmeans");

    AddRANDParameter();

    // Doc example parameter settings
    SetDocExampleParameterValue("in", "QB_1_ortho.tif");
    SetDocExampleParameterValue("ts", "1000");
    SetDocExampleParameterValue("nc", "5");
    SetDocExampleParameterValue("maxit", "1000");
    SetDocExampleParameterValue("ct", "0.0001");
    SetDocExampleParameterValue("out", "ClassificationFilterOutput.tif");
  }
221

222
  /**
   * Bound the training set size ("ts") once an input image is known.
   *
   * A RAM-driven stripped streaming manager is asked how many splits
   * the image would need with the RAM given by the "ram" parameter; the
   * number of pixels of the image divided by that split count becomes
   * the maximum allowed value of "ts". A larger current value is
   * reduced to that maximum, with a warning.
   */
  void DoUpdateParameters() ITK_OVERRIDE
  {
    // Nothing can be computed before an input image is provided.
    if (!HasValue("in"))
      {
      return;
      }

    FloatVectorImageType::Pointer inImage = GetParameterImage("in");

    RAMDrivenStrippedStreamingManagerType::Pointer streamingManager = RAMDrivenStrippedStreamingManagerType::New();
    int availableRAM = GetParameterInt("ram");
    streamingManager->SetAvailableRAMInMB(availableRAM);
    float bias = 1.5; // empirical value
    streamingManager->SetBias(bias);

    FloatVectorImageType::RegionType largestRegion = inImage->GetLargestPossibleRegion();
    FloatVectorImageType::SizeType largestRegionSize = largestRegion.GetSize();
    streamingManager->PrepareStreaming(inImage, largestRegion);

    // Guard the division below against a degenerate split count.
    unsigned long nbDivisions = streamingManager->GetNumberOfSplits();
    if (nbDivisions == 0)
      {
      nbDivisions = 1;
      }

    const unsigned long largestPixNb = largestRegionSize[0] * largestRegionSize[1];
    const unsigned long maxPixNb = largestPixNb / nbDivisions;

    // "ts" is an int parameter: clamp the bound before converting so a
    // very large image cannot wrap around to a negative or garbage value.
    const unsigned long intMax = static_cast<unsigned long>(std::numeric_limits<int>::max());
    const int maxTrainingSize = static_cast<int>(maxPixNb < intMax ? maxPixNb : intMax);

    if (GetParameterInt("ts") > maxTrainingSize)
      {
      otbAppLogWARNING("The available RAM is too small to process this sample size of " << GetParameterInt("ts") <<
          " pixels. The sample size will be reduced to " << maxTrainingSize << " pixels." << std::endl);
      this->SetParameterInt("ts",maxTrainingSize, false);
      }

    this->SetMaximumParameterIntValue("ts", maxTrainingSize);
  }

255
  void DoExecute() ITK_OVERRIDE
256
  {
Jonathan Guinet's avatar
Jonathan Guinet committed
257
258
259
    GetLogger()->Debug("Entering DoExecute\n");

    m_InImage = GetParameterImage("in");
260
    m_InImage->UpdateOutputInformation();
261
    UInt8ImageType::Pointer maskImage;
262

Jonathan Guinet's avatar
Jonathan Guinet committed
263
264
    std::ostringstream message("");

265
266
    int nbsamples = GetParameterInt("ts");
    const unsigned int nbClasses = GetParameterInt("nc");
Jonathan Guinet's avatar
Jonathan Guinet committed
267
268
269

    /*******************************************/
    /*           Sampling data                 */
270
271
272
    /*******************************************/

    otbAppLogINFO("-- SAMPLING DATA --"<<std::endl);
Jonathan Guinet's avatar
Jonathan Guinet committed
273
274
275
276

    // Update input images information
    m_InImage->UpdateOutputInformation();

Jonathan Guinet's avatar
Jonathan Guinet committed
277
278
279
280
    bool maskFlag = IsParameterEnabled("vm");
    if (maskFlag)
      {
      otbAppLogINFO("sample choice using mask "<<std::endl);
281
      maskImage = GetParameterUInt8Image("vm");
Jonathan Guinet's avatar
Jonathan Guinet committed
282
283
284
285
      maskImage->UpdateOutputInformation();
      if (m_InImage->GetLargestPossibleRegion() != maskImage->GetLargestPossibleRegion())
        {
        GetLogger()->Error("Mask image and input image have different sizes.");
286
        return;
Jonathan Guinet's avatar
Jonathan Guinet committed
287
288
        }
      }
289

Jonathan Guinet's avatar
Jonathan Guinet committed
290
291
    // Training sample lists
    ListSampleType::Pointer sampleList = ListSampleType::New();
292
    sampleList->SetMeasurementVectorSize(m_InImage->GetNumberOfComponentsPerPixel());
293

294
    //unsigned int init_means_index = 0;
Jonathan Guinet's avatar
Jonathan Guinet committed
295
296

    // Sample dimension and max dimension
297
298
    const unsigned int nbComp = m_InImage->GetNumberOfComponentsPerPixel();
    unsigned int sampleSize = nbComp;
Jonathan Guinet's avatar
Jonathan Guinet committed
299
300
    unsigned int totalSamples = 0;

301
    // sampleSize = std::min(nbComp, maxDim);
Jonathan Guinet's avatar
Jonathan Guinet committed
302

303
304
305
306
307
308
309
    EstimatorType::ParametersType initialMeans(nbComp * nbClasses);
    initialMeans.Fill(0);

    // use image and mask shrink

    ImageSamplingFilterType::Pointer imageSampler = ImageSamplingFilterType::New();
    imageSampler->SetInput(m_InImage);
310

311
    double theoricNBSamplesForKMeans = nbsamples;
312

313
314
    const double upperThresholdNBSamplesForKMeans = 1000 * 1000;
    const double actualNBSamplesForKMeans = std::min(theoricNBSamplesForKMeans, upperThresholdNBSamplesForKMeans);
315

316
    otbAppLogINFO(<< actualNBSamplesForKMeans << " is the maximum sample size that will be used." << std::endl);
317

318
319
320
321
322
323
    const double shrinkFactor = vcl_floor(
                                          vcl_sqrt(
                                                   m_InImage->GetLargestPossibleRegion().GetNumberOfPixels()
                                                       / actualNBSamplesForKMeans));
    imageSampler->SetShrinkFactor(shrinkFactor);
    imageSampler->Update();
324

325
326
    MaskSamplingFilterType::Pointer maskSampler;
    LabeledIteratorType m_MaskIt;
Jonathan Guinet's avatar
Jonathan Guinet committed
327
    if (maskFlag)
328
329
330
331
332
      {
      maskSampler = MaskSamplingFilterType::New();
      maskSampler->SetInput(maskImage);
      maskSampler->SetShrinkFactor(shrinkFactor);
      maskSampler->Update();
Jonathan Guinet's avatar
Jonathan Guinet committed
333
      m_MaskIt = LabeledIteratorType(maskSampler->GetOutput(), maskSampler->GetOutput()->GetLargestPossibleRegion());
334
335
      m_MaskIt.GoToBegin();
      }
336
    // Then, build the sample list
337

338
    IteratorType it(imageSampler->GetOutput(), imageSampler->GetOutput()->GetLargestPossibleRegion());
339

340
341
342
343
344
345
    it.GoToBegin();

    SampleType min;
    SampleType max;
    SampleType sample;
    //first sample
346

347
    itk::Statistics::MersenneTwisterRandomVariateGenerator::Pointer randGen=itk::Statistics::MersenneTwisterRandomVariateGenerator::GetInstance();
348

349
    //randGen->Initialize();
350

Jonathan Guinet's avatar
Jonathan Guinet committed
351
    if (maskFlag)
352
      {
353
      while (!it.IsAtEnd() && !m_MaskIt.IsAtEnd() && (m_MaskIt.Get() <= 0))
Jonathan Guinet's avatar
Jonathan Guinet committed
354
355
356
357
        {
        ++it;
        ++m_MaskIt;
        }
358
359
360
361
362
363
364

      // If the mask is empty after the subsampling
      if (m_MaskIt.IsAtEnd())
        {
        GetLogger()->Error("The mask image is empty after subsampling. Please increase the training set size.");
        return;
        }
365
      }
366

367
368
369
    min = it.Get();
    max = it.Get();
    sample = it.Get();
370

371
372
373
    sampleList->PushBack(sample);

    ++it;
374

Jonathan Guinet's avatar
Jonathan Guinet committed
375
    if (maskFlag)
376
      {
Jonathan Guinet's avatar
Jonathan Guinet committed
377
      ++m_MaskIt;
378
379
      }

380
    totalSamples = 1;
381
    bool selectSample;
Jonathan Guinet's avatar
Jonathan Guinet committed
382
    while (!it.IsAtEnd())
383
      {
Jonathan Guinet's avatar
Jonathan Guinet committed
384
      if (maskFlag)
385
        {
Jonathan Guinet's avatar
Jonathan Guinet committed
386
        selectSample = (m_MaskIt.Get() > 0);
387
388
        ++m_MaskIt;
        }
Jonathan Guinet's avatar
Jonathan Guinet committed
389
      else selectSample = true;
390
391

      if (selectSample)
Jonathan Guinet's avatar
Jonathan Guinet committed
392
        {
393
394
395
396
397
398
399
        totalSamples++;

        sample = it.Get();

        sampleList->PushBack(sample);

        for (unsigned int i = 0; i < nbComp; ++i)
400
          {
401
402
403
404
405
          if (min[i] > sample[i])
            {
            min[i] = sample[i];
            }
          if (max[i] < sample[i])
Jonathan Guinet's avatar
Jonathan Guinet committed
406
            {
407
            max[i] = sample[i];
Jonathan Guinet's avatar
Jonathan Guinet committed
408
            }
409
410
          }
        }
411
      ++it;
Jonathan Guinet's avatar
Jonathan Guinet committed
412
413
      }

414
415
    // Next, initialize centroids by random sampling in the generated
    // list of samples
Jonathan Guinet's avatar
Jonathan Guinet committed
416

417
418
    for (unsigned int classIndex = 0; classIndex < nbClasses; ++classIndex)
      {
419
420
      SampleType newCentroid = sampleList->GetMeasurementVector(randGen->GetIntegerVariate(sampleList->Size()-1));

421
422
      for (unsigned int compIndex = 0; compIndex < sampleSize; ++compIndex)
        {
423
        initialMeans[compIndex + classIndex * sampleSize] = newCentroid[compIndex];
424
425
        }
      }
426
    otbAppLogINFO(<< totalSamples << " samples will be used as estimator input." << std::endl);
Jonathan Guinet's avatar
Jonathan Guinet committed
427
428
429
430

    /*******************************************/
    /*           Learning                      */
    /*******************************************/
431
432
433
434

    otbAppLogINFO("-- LEARNING --" << std::endl);
    otbAppLogINFO("Initial centroids are: " << std::endl);

435
    message.str("");
436
    message << std::endl;
437
    for (unsigned int i = 0; i < nbClasses; i++)
Jonathan Guinet's avatar
Jonathan Guinet committed
438
439
      {
      message << "Class " << i << ": ";
440
      for (unsigned int j = 0; j < sampleSize; j++)
441
        {
442
        message << std::setw(8) << initialMeans[i * sampleSize + j] << "   ";
443
        }
Jonathan Guinet's avatar
Jonathan Guinet committed
444
      message << std::endl;
445
      }
Jonathan Guinet's avatar
Jonathan Guinet committed
446
447
    message << std::endl;
    GetLogger()->Info(message.str());
448
    message.str("");
449
    otbAppLogINFO("Starting optimization." << std::endl);
Jonathan Guinet's avatar
Jonathan Guinet committed
450
    EstimatorType::Pointer estimator = EstimatorType::New();
451

Jonathan Guinet's avatar
Jonathan Guinet committed
452
453
    TreeGeneratorType::Pointer treeGenerator = TreeGeneratorType::New();
    treeGenerator->SetSample(sampleList);
454
455

    treeGenerator->SetBucketSize(10000);
Jonathan Guinet's avatar
Jonathan Guinet committed
456
    treeGenerator->Update();
457

Jonathan Guinet's avatar
Jonathan Guinet committed
458
459
    estimator->SetParameters(initialMeans);
    estimator->SetKdTree(treeGenerator->GetOutput());
460
461
462
    int maxIt = GetParameterInt("maxit");
    estimator->SetMaximumIteration(maxIt);
    estimator->SetCentroidPositionChangesThreshold(GetParameterFloat("ct"));
Jonathan Guinet's avatar
Jonathan Guinet committed
463
    estimator->StartOptimization();
464

Jonathan Guinet's avatar
Jonathan Guinet committed
465
    EstimatorType::ParametersType estimatedMeans = estimator->GetParameters();
466
467
468
469

    otbAppLogINFO("Optimization completed." );
    if (estimator->GetCurrentIteration() == maxIt)
      {
470
      otbAppLogWARNING("The estimator reached the maximum iteration number." << std::endl);
471
      }
472
    message.str("");
Jonathan Guinet's avatar
Jonathan Guinet committed
473
    message << "Estimated centroids are: " << std::endl;
474
    message << std::endl;
475
    for (unsigned int i = 0; i < nbClasses; i++)
Jonathan Guinet's avatar
Jonathan Guinet committed
476
477
      {
      message << "Class " << i << ": ";
478
      for (unsigned int j = 0; j < sampleSize; j++)
Jonathan Guinet's avatar
Jonathan Guinet committed
479
        {
480
        message << std::setw(8) << estimatedMeans[i * sampleSize + j] << "   ";
Jonathan Guinet's avatar
Jonathan Guinet committed
481
482
483
        }
      message << std::endl;
      }
484

Jonathan Guinet's avatar
Jonathan Guinet committed
485
486
487
488
    message << std::endl;
    message << "Learning completed." << std::endl;
    message << std::endl;
    GetLogger()->Info(message.str());
489

Jonathan Guinet's avatar
Jonathan Guinet committed
490
491
    /*******************************************/
    /*           Classification                */
492
493
    /*******************************************/
    otbAppLogINFO("-- CLASSIFICATION --" << std::endl);
494

495
496
    // Finally, update the KMeans filter
    KMeansFunctorType functor;
497

498
499
500
    for (unsigned int classIndex = 0; classIndex < nbClasses; ++classIndex)
      {
      SampleType centroid(sampleSize);
501

502
503
504
505
506
507
508
509
510
511
      for (unsigned int compIndex = 0; compIndex < sampleSize; ++compIndex)
        {
        centroid[compIndex] = estimatedMeans[compIndex + classIndex * sampleSize];
        }
      functor.AddCentroid(classIndex, centroid);
      }

    m_KMeansFilter = KMeansFilterType::New();
    m_KMeansFilter->SetFunctor(functor);
    m_KMeansFilter->SetInput(m_InImage);
512

513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
    // optional saving option -> lut

    if (IsParameterEnabled("outmeans"))
      {
      std::ofstream file;
      file.open(GetParameterString("outmeans").c_str());
      for (unsigned int i = 0; i < nbClasses; i++)
        {

        for (unsigned int j = 0; j < sampleSize; j++)
          {
          file << std::setw(8) << estimatedMeans[i * sampleSize + j] << " ";
          }
        file << std::endl;
        }

      file.close();
      }

532
    SetParameterOutputImage("out", m_KMeansFilter->GetOutput());
533
534
535

  }

536
  // KMeans filter
537
538
  // Nearest-centroid classification filter; created in DoExecute() and
  // its output is set as the "out" parameter.
  // NOTE(review): presumably kept as a member so the pipeline outlives
  // DoExecute() until the output is written - confirm.
  KMeansFilterType::Pointer           m_KMeansFilter;
  // Input image read from the "in" parameter in DoExecute().
  FloatVectorImageType::Pointer       m_InImage;
Jonathan Guinet's avatar
Jonathan Guinet committed
539
540

};
541

542

Jonathan Guinet's avatar
Jonathan Guinet committed
543
544
}
}
545

Jonathan Guinet's avatar
Jonathan Guinet committed
546
OTB_APPLICATION_EXPORT(otb::Wrapper::KMeansClassification)
547
548