otbKMeansClassification.cxx 16.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
/*=========================================================================

 Program:   ORFEO Toolbox
 Language:  C++
 Date:      $Date$
 Version:   $Revision$


 Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
 See OTBCopyright.txt for details.


 This software is distributed WITHOUT ANY WARRANTY; without even
 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
 PURPOSE.  See the above copyright notices for more information.

 =========================================================================*/

#include <algorithm>
#include <fstream>
#include <iomanip>
#include <limits>
#include <map>
#include <sstream>
#include <string>

#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"

#include "otbVectorImage.h"
#include "otbStreamingTraits.h"
#include "itkImageRegionConstIterator.h"
#include "itkListSample.h"
#include "itkWeightedCentroidKdTreeGenerator.h"
#include "itkKdTreeBasedKmeansEstimator.h"
#include "otbStreamingShrinkImageFilter.h"
#include "otbChangeLabelImageFilter.h"
#include "otbRAMDrivenStrippedStreamingManager.h"

#include "itkLabelToRGBImageFilter.h"
#include "otbReliefColormapFunctor.h"
#include "itkScalarToRGBColormapImageFilter.h"

37

38 39
namespace otb
{
40 41


42
namespace Wrapper
43 44
{

45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96
namespace Functor
{
template <class TSample, class TLabel> class KMeansFunctor
{
public:
  /** operator */
  TLabel operator ()(const TSample& sample) const
  {
    typename CentroidMapType::const_iterator it = m_CentroidsMap.begin();

    if (it == m_CentroidsMap.end())
      {
      return 0;
      }

    TLabel resp = it->first;
    double minDist = m_Distance->Evaluate(sample, it->second);
    ++it;

    while (it != m_CentroidsMap.end())
      {
      double dist = m_Distance->Evaluate(sample, it->second);

      if (dist < minDist)
        {
        resp = it->first;
        minDist = dist;
        }
      ++it;
      }
    return resp;
  }

  /** Add a new centroid */
  void AddCentroid(const TLabel& label, const TSample& centroid)
  {
    m_CentroidsMap[label] = centroid;
  }

  /** Constructor */
  KMeansFunctor() : m_CentroidsMap(), m_Distance()
  {
    m_Distance = DistanceType::New();
  }

  bool operator !=(const KMeansFunctor& other) const
  {
    return m_CentroidsMap != other.m_CentroidsMap;
  }

private:
  typedef std::map<TLabel, TSample>                   CentroidMapType;
97
  typedef itk::Statistics::EuclideanDistanceMetric<TSample> DistanceType;
98 99 100 101 102 103 104

  CentroidMapType m_CentroidsMap;
  typename DistanceType::Pointer m_Distance;
};
}


105
typedef FloatImageType::PixelType PixelType;
106 107
typedef UInt8ImageType   LabeledImageType;

108 109 110 111 112 113
typedef UInt8VectorImageType        VectorImageType;
typedef VectorImageType::PixelType  VectorPixelType;
typedef UInt8RGBImageType           RGBImageType;
typedef RGBImageType::PixelType     RGBPixelType;


114 115 116 117
typedef LabeledImageType::PixelType LabelType;


typedef FloatVectorImageType::PixelType                               SampleType;
118 119 120 121
typedef itk::Statistics::ListSample<SampleType> ListSampleType;
typedef itk::Statistics::WeightedCentroidKdTreeGenerator<ListSampleType> TreeGeneratorType;
typedef TreeGeneratorType::KdTreeType TreeType;
typedef itk::Statistics::KdTreeBasedKmeansEstimator<TreeType> EstimatorType;
122
typedef RAMDrivenStrippedStreamingManager<FloatVectorImageType> RAMDrivenStrippedStreamingManagerType;
123 124


125
typedef itk::ImageRegionConstIterator<FloatVectorImageType> IteratorType;
126 127 128 129 130 131 132 133 134 135 136
typedef itk::ImageRegionConstIterator<LabeledImageType> LabeledIteratorType;

typedef otb::StreamingShrinkImageFilter<FloatVectorImageType,
     FloatVectorImageType>              ImageSamplingFilterType;

typedef otb::StreamingShrinkImageFilter<LabeledImageType,
    UInt8ImageType>              MaskSamplingFilterType;
typedef Functor::KMeansFunctor<SampleType, LabelType> KMeansFunctorType;
typedef itk::UnaryFunctorImageFilter<FloatVectorImageType,
    LabeledImageType, KMeansFunctorType>     KMeansFilterType;

137

138 139 140 141 142 143 144 145 146 147 148 149 150 151
// Manual label LUT
 typedef otb::ChangeLabelImageFilter
 <LabeledImageType, VectorImageType>    ChangeLabelFilterType;

 // Continuous LUT mapping
  typedef itk::ScalarToRGBColormapImageFilter<LabeledImageType, RGBImageType>      ColorMapFilterType;


  typedef otb::Functor::ReliefColormapFunctor
   <LabelType, RGBPixelType>           ReliefColorMapFunctorType;

  typedef otb::ImageMetadataInterfaceBase ImageMetadataInterfaceType;


152 153 154 155 156 157 158 159
class KMeansClassification: public Application
{
public:
  /** Standard class typedefs. */
  typedef KMeansClassification Self;
  typedef Application Superclass;
  typedef itk::SmartPointer<Self> Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;
160

161 162
  /** Standard macro */
  itkNewMacro(Self);
163

164
  itkTypeMacro(KMeansClassification, otb::Application);
165

166
private:
167
  void DoInit()
168 169
  {
    SetName("KMeansClassification");
170 171
    SetDescription("Unsupervised KMeans image classification");

172
    SetDocName("Unsupervised KMeans image classification");
173
    SetDocLongDescription("Performs unsupervised KMeans image classification.");
174 175 176
    SetDocLimitations("None");
    SetDocAuthors("OTB-Team");
    SetDocSeeAlso(" ");
177

178 179
    AddDocTag(Tags::Segmentation);
    AddDocTag(Tags::Learning);
180
    AddParameter(ParameterType_InputImage, "in", "Input Image");
181
    SetParameterDescription("in", "Input image to classify.");
182
    AddParameter(ParameterType_OutputImage, "out", "Output Image");
183
    SetParameterDescription("out", "Output image containing the class indexes.");
184
    SetDefaultOutputPixelType("out",ImagePixelType_uint8);
185 186 187

    AddRAMParameter();

188
    AddParameter(ParameterType_InputImage, "vm", "Validity Mask");
189
    SetParameterDescription("vm", "Validity mask. Only non-zero pixels will be used to estimate KMeans modes.");
190
    MandatoryOff("vm");
191
    AddParameter(ParameterType_Int, "ts", "Training set size");
192
    SetParameterDescription("ts", "Size of the training set (in pixels).");
193
    SetDefaultParameterInt("ts", 100);
194
    MandatoryOff("ts");
195
    AddParameter(ParameterType_Int, "nc", "Number of classes");
196
    SetParameterDescription("nc", "Number of modes, which will be used to generate class membership.");
197
    SetDefaultParameterInt("nc", 5);
198
    AddParameter(ParameterType_Int, "maxit", "Maximum number of iterations");
199
    SetParameterDescription("maxit", "Maximum number of iterations for the learning step.");
200 201
    SetDefaultParameterInt("maxit", 1000);
    MandatoryOff("maxit");
202
    AddParameter(ParameterType_Float, "ct", "Convergence threshold");
203
    SetParameterDescription("ct", "Convergence threshold for class centroid  (L2 distance, by default 0.0001).");
204 205
    SetDefaultParameterFloat("ct", 0.0001);
    MandatoryOff("ct");
206
    AddParameter(ParameterType_OutputFilename, "outmeans", "Centroid filename");
207
    SetParameterDescription("outmeans", "Output text file containing centroid positions");
208 209
    MandatoryOff("outmeans");

210
    AddRANDParameter();
211

212
    // Doc example parameter settings
213
    SetDocExampleParameterValue("in", "QB_1_ortho.tif");
214
    SetDocExampleParameterValue("ts", "1000");
215
    SetDocExampleParameterValue("nc", "5");
216
    SetDocExampleParameterValue("maxit", "1000");
217
    SetDocExampleParameterValue("ct", "0.0001");
218
    SetDocExampleParameterValue("out", "ClassificationFilterOutput.tif");
219
  }
220

221
  void DoUpdateParameters()
222
  {
223 224 225 226 227 228 229 230 231
    // test of input image //
    if (HasValue("in"))
      {
      // input image
      FloatVectorImageType::Pointer inImage = GetParameterImage("in");

      RAMDrivenStrippedStreamingManagerType::Pointer streamingManager = RAMDrivenStrippedStreamingManagerType::New();
      int availableRAM = GetParameterInt("ram");
      streamingManager->SetAvailableRAMInMB(availableRAM);
232
      float bias = 1.5; // empirical value
233 234 235 236 237 238 239 240 241 242 243 244
      streamingManager->SetBias(bias);
      FloatVectorImageType::RegionType largestRegion = inImage->GetLargestPossibleRegion();
      FloatVectorImageType::SizeType largestRegionSize = largestRegion.GetSize();
      streamingManager->PrepareStreaming(inImage, largestRegion);

      unsigned long nbDivisions = streamingManager->GetNumberOfSplits();
      unsigned long largestPixNb = largestRegionSize[0] * largestRegionSize[1];

      unsigned long maxPixNb = largestPixNb / nbDivisions;

      if (GetParameterInt("ts") > static_cast<int> (maxPixNb))
        {
245 246
        otbAppLogWARNING("The available RAM is too small to process this sample size of " << GetParameterInt("ts") <<
            " pixels. The sample size will be reduced to " << maxPixNb << " pixels." << std::endl);
247 248 249 250 251
        this->SetParameterInt("ts", maxPixNb);
        }

      this->SetMaximumParameterIntValue("ts", maxPixNb);
      }
252 253
  }

254
  void DoExecute()
255
  {
256 257 258
    GetLogger()->Debug("Entering DoExecute\n");

    m_InImage = GetParameterImage("in");
259
    m_InImage->UpdateOutputInformation();
260
    UInt8ImageType::Pointer maskImage;
261

262 263
    std::ostringstream message("");

264 265
    int nbsamples = GetParameterInt("ts");
    const unsigned int nbClasses = GetParameterInt("nc");
266 267 268

    /*******************************************/
    /*           Sampling data                 */
269 270 271
    /*******************************************/

    otbAppLogINFO("-- SAMPLING DATA --"<<std::endl);
272 273 274 275

    // Update input images information
    m_InImage->UpdateOutputInformation();

Jonathan Guinet's avatar
Jonathan Guinet committed
276 277 278 279 280 281 282 283 284
    bool maskFlag = IsParameterEnabled("vm");
    if (maskFlag)
      {
      otbAppLogINFO("sample choice using mask "<<std::endl);
      maskImage = GetParameterUInt8Image("vm");
      maskImage->UpdateOutputInformation();
      if (m_InImage->GetLargestPossibleRegion() != maskImage->GetLargestPossibleRegion())
        {
        GetLogger()->Error("Mask image and input image have different sizes.");
285
        return;
Jonathan Guinet's avatar
Jonathan Guinet committed
286 287
        }
      }
288

289 290
    // Training sample lists
    ListSampleType::Pointer sampleList = ListSampleType::New();
291
    sampleList->SetMeasurementVectorSize(m_InImage->GetNumberOfComponentsPerPixel());
292

293
    //unsigned int init_means_index = 0;
294 295

    // Sample dimension and max dimension
296 297
    const unsigned int nbComp = m_InImage->GetNumberOfComponentsPerPixel();
    unsigned int sampleSize = nbComp;
298 299
    unsigned int totalSamples = 0;

300
    // sampleSize = std::min(nbComp, maxDim);
301

302 303 304 305 306 307 308
    EstimatorType::ParametersType initialMeans(nbComp * nbClasses);
    initialMeans.Fill(0);

    // use image and mask shrink

    ImageSamplingFilterType::Pointer imageSampler = ImageSamplingFilterType::New();
    imageSampler->SetInput(m_InImage);
309

310
    double theoricNBSamplesForKMeans = nbsamples;
311

312 313
    const double upperThresholdNBSamplesForKMeans = 1000 * 1000;
    const double actualNBSamplesForKMeans = std::min(theoricNBSamplesForKMeans, upperThresholdNBSamplesForKMeans);
314

315
    otbAppLogINFO(<< actualNBSamplesForKMeans << " is the maximum sample size that will be used." << std::endl);
316

317 318 319 320 321 322
    const double shrinkFactor = vcl_floor(
                                          vcl_sqrt(
                                                   m_InImage->GetLargestPossibleRegion().GetNumberOfPixels()
                                                       / actualNBSamplesForKMeans));
    imageSampler->SetShrinkFactor(shrinkFactor);
    imageSampler->Update();
323

324 325
    MaskSamplingFilterType::Pointer maskSampler;
    LabeledIteratorType m_MaskIt;
Jonathan Guinet's avatar
Jonathan Guinet committed
326
    if (maskFlag)
327 328 329 330 331
      {
      maskSampler = MaskSamplingFilterType::New();
      maskSampler->SetInput(maskImage);
      maskSampler->SetShrinkFactor(shrinkFactor);
      maskSampler->Update();
Jonathan Guinet's avatar
Jonathan Guinet committed
332
      m_MaskIt = LabeledIteratorType(maskSampler->GetOutput(), maskSampler->GetOutput()->GetLargestPossibleRegion());
333 334
      m_MaskIt.GoToBegin();
      }
335
    // Then, build the sample list
336

337
    IteratorType it(imageSampler->GetOutput(), imageSampler->GetOutput()->GetLargestPossibleRegion());
338

339 340 341 342 343 344
    it.GoToBegin();

    SampleType min;
    SampleType max;
    SampleType sample;
    //first sample
345

346
    itk::Statistics::MersenneTwisterRandomVariateGenerator::Pointer randGen=itk::Statistics::MersenneTwisterRandomVariateGenerator::GetInstance();
347

348
    //randGen->Initialize();
349

Jonathan Guinet's avatar
Jonathan Guinet committed
350
    if (maskFlag)
351
      {
352
      while (!it.IsAtEnd() && !m_MaskIt.IsAtEnd() && (m_MaskIt.Get() <= 0))
Jonathan Guinet's avatar
Jonathan Guinet committed
353 354 355 356
        {
        ++it;
        ++m_MaskIt;
        }
357 358 359 360 361 362 363

      // If the mask is empty after the subsampling
      if (m_MaskIt.IsAtEnd())
        {
        GetLogger()->Error("The mask image is empty after subsampling. Please increase the training set size.");
        return;
        }
364
      }
365

366 367 368
    min = it.Get();
    max = it.Get();
    sample = it.Get();
369

370 371 372
    sampleList->PushBack(sample);

    ++it;
373

Jonathan Guinet's avatar
Jonathan Guinet committed
374
    if (maskFlag)
375
      {
Jonathan Guinet's avatar
Jonathan Guinet committed
376
      ++m_MaskIt;
377 378
      }

379
    totalSamples = 1;
380
    bool selectSample;
Jonathan Guinet's avatar
Jonathan Guinet committed
381
    while (!it.IsAtEnd())
382
      {
Jonathan Guinet's avatar
Jonathan Guinet committed
383
      if (maskFlag)
384
        {
Jonathan Guinet's avatar
Jonathan Guinet committed
385
        selectSample = (m_MaskIt.Get() > 0);
386 387
        ++m_MaskIt;
        }
Jonathan Guinet's avatar
Jonathan Guinet committed
388
      else selectSample = true;
389 390

      if (selectSample)
391
        {
392 393 394 395 396 397 398
        totalSamples++;

        sample = it.Get();

        sampleList->PushBack(sample);

        for (unsigned int i = 0; i < nbComp; ++i)
399
          {
400 401 402 403 404
          if (min[i] > sample[i])
            {
            min[i] = sample[i];
            }
          if (max[i] < sample[i])
405
            {
406
            max[i] = sample[i];
407
            }
408 409
          }
        }
410
      ++it;
411 412
      }

413 414
    // Next, initialize centroids by random sampling in the generated
    // list of samples
415

416 417
    for (unsigned int classIndex = 0; classIndex < nbClasses; ++classIndex)
      {
418 419
      SampleType newCentroid = sampleList->GetMeasurementVector(randGen->GetIntegerVariate(sampleList->Size()-1));

420 421
      for (unsigned int compIndex = 0; compIndex < sampleSize; ++compIndex)
        {
422
        initialMeans[compIndex + classIndex * sampleSize] = newCentroid[compIndex];
423 424
        }
      }
425
    otbAppLogINFO(<< totalSamples << " samples will be used as estimator input." << std::endl);
426 427 428 429

    /*******************************************/
    /*           Learning                      */
    /*******************************************/
430 431 432 433

    otbAppLogINFO("-- LEARNING --" << std::endl);
    otbAppLogINFO("Initial centroids are: " << std::endl);

434
    message.str("");
435
    message << std::endl;
436
    for (unsigned int i = 0; i < nbClasses; i++)
437 438
      {
      message << "Class " << i << ": ";
439
      for (unsigned int j = 0; j < sampleSize; j++)
440
        {
441
        message << std::setw(8) << initialMeans[i * sampleSize + j] << "   ";
442
        }
443
      message << std::endl;
444
      }
445 446
    message << std::endl;
    GetLogger()->Info(message.str());
447
    message.str("");
448
    otbAppLogINFO("Starting optimization." << std::endl);
449
    EstimatorType::Pointer estimator = EstimatorType::New();
450

451 452
    TreeGeneratorType::Pointer treeGenerator = TreeGeneratorType::New();
    treeGenerator->SetSample(sampleList);
453 454

    treeGenerator->SetBucketSize(10000);
455
    treeGenerator->Update();
456

457 458
    estimator->SetParameters(initialMeans);
    estimator->SetKdTree(treeGenerator->GetOutput());
459 460 461
    int maxIt = GetParameterInt("maxit");
    estimator->SetMaximumIteration(maxIt);
    estimator->SetCentroidPositionChangesThreshold(GetParameterFloat("ct"));
462
    estimator->StartOptimization();
463

464
    EstimatorType::ParametersType estimatedMeans = estimator->GetParameters();
465 466 467 468

    otbAppLogINFO("Optimization completed." );
    if (estimator->GetCurrentIteration() == maxIt)
      {
469
      otbAppLogWARNING("The estimator reached the maximum iteration number." << std::endl);
470
      }
471
    message.str("");
472
    message << "Estimated centroids are: " << std::endl;
473
    message << std::endl;
474
    for (unsigned int i = 0; i < nbClasses; i++)
475 476
      {
      message << "Class " << i << ": ";
477
      for (unsigned int j = 0; j < sampleSize; j++)
478
        {
479
        message << std::setw(8) << estimatedMeans[i * sampleSize + j] << "   ";
480 481 482
        }
      message << std::endl;
      }
483

484 485 486 487
    message << std::endl;
    message << "Learning completed." << std::endl;
    message << std::endl;
    GetLogger()->Info(message.str());
488

489 490
    /*******************************************/
    /*           Classification                */
491 492
    /*******************************************/
    otbAppLogINFO("-- CLASSIFICATION --" << std::endl);
493

494 495
    // Finally, update the KMeans filter
    KMeansFunctorType functor;
496

497 498 499
    for (unsigned int classIndex = 0; classIndex < nbClasses; ++classIndex)
      {
      SampleType centroid(sampleSize);
500

501 502 503 504 505 506 507 508 509 510
      for (unsigned int compIndex = 0; compIndex < sampleSize; ++compIndex)
        {
        centroid[compIndex] = estimatedMeans[compIndex + classIndex * sampleSize];
        }
      functor.AddCentroid(classIndex, centroid);
      }

    m_KMeansFilter = KMeansFilterType::New();
    m_KMeansFilter->SetFunctor(functor);
    m_KMeansFilter->SetInput(m_InImage);
511

512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530
    // optional saving option -> lut

    if (IsParameterEnabled("outmeans"))
      {
      std::ofstream file;
      file.open(GetParameterString("outmeans").c_str());
      for (unsigned int i = 0; i < nbClasses; i++)
        {

        for (unsigned int j = 0; j < sampleSize; j++)
          {
          file << std::setw(8) << estimatedMeans[i * sampleSize + j] << " ";
          }
        file << std::endl;
        }

      file.close();
      }

531
    SetParameterOutputImage("out", m_KMeansFilter->GetOutput());
532 533 534

  }

535
  // KMeans filter
536 537
  KMeansFilterType::Pointer           m_KMeansFilter;
  FloatVectorImageType::Pointer       m_InImage;
538 539

};
540

541

542 543
}
}
544

545
OTB_APPLICATION_EXPORT(otb::Wrapper::KMeansClassification)
546 547