otbKMeansClassification.cxx 12.8 KB
Newer Older
Jonathan Guinet's avatar
Jonathan Guinet committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
/*=========================================================================

 Program:   ORFEO Toolbox
 Language:  C++
 Date:      $Date$
 Version:   $Revision$


 Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
 See OTBCopyright.txt for details.


 This software is distributed WITHOUT ANY WARRANTY; without even
 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
 PURPOSE.  See the above copyright notices for more information.

 =========================================================================*/

#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"

22 23 24 25 26 27 28 29 30 31
#include "otbVectorImage.h"
#include "otbImage.h"
#include "itkEuclideanDistance.h"
#include "itkImageRegionSplitter.h"
#include "otbStreamingTraits.h"
#include "otbKMeansImageClassificationFilter.h"
#include "itkImageRegionConstIterator.h"
#include "itkListSample.h"
#include "itkWeightedCentroidKdTreeGenerator.h"
#include "itkKdTreeBasedKmeansEstimator.h"
32
#include "itkMersenneTwisterRandomVariateGenerator.h"
33

Jonathan Guinet's avatar
Jonathan Guinet committed
34 35 36
namespace otb
{
namespace Wrapper
37 38
{

39
typedef FloatImageType::PixelType PixelType;
40

Jonathan Guinet's avatar
Jonathan Guinet committed
41 42 43 44 45
typedef itk::FixedArray<PixelType, 108> SampleType;
typedef itk::Statistics::ListSample<SampleType> ListSampleType;
typedef itk::Statistics::WeightedCentroidKdTreeGenerator<ListSampleType> TreeGeneratorType;
typedef TreeGeneratorType::KdTreeType TreeType;
typedef itk::Statistics::KdTreeBasedKmeansEstimator<TreeType> EstimatorType;
46

Jonathan Guinet's avatar
Jonathan Guinet committed
47 48
typedef otb::StreamingTraits<FloatVectorImageType> StreamingTraitsType;
typedef itk::ImageRegionSplitter<2> SplitterType;
49
typedef FloatImageType::RegionType RegionType;
50

Jonathan Guinet's avatar
Jonathan Guinet committed
51
typedef itk::ImageRegionConstIterator<FloatVectorImageType> IteratorType;
52
typedef itk::ImageRegionConstIterator<UInt8ImageType> LabeledIteratorType;
53

54
typedef otb::KMeansImageClassificationFilter<FloatVectorImageType, UInt8ImageType, 108> ClassificationFilterType;
55

Jonathan Guinet's avatar
Jonathan Guinet committed
56 57 58 59 60 61 62 63
class KMeansClassification: public Application
{
public:
  /** Standard class typedefs. */
  typedef KMeansClassification Self;
  typedef Application Superclass;
  typedef itk::SmartPointer<Self> Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;
64

Jonathan Guinet's avatar
Jonathan Guinet committed
65 66
  /** Standard macro */
  itkNewMacro(Self);
67

Jonathan Guinet's avatar
Jonathan Guinet committed
68
  itkTypeMacro(KMeansClassification, otb::Application);
69

Jonathan Guinet's avatar
Jonathan Guinet committed
70 71 72 73
private:
  KMeansClassification()
  {
    SetName("KMeansClassification");
74 75 76 77 78 79 80
    SetDescription("Unsupervised KMeans image classification");

    SetDocName("Unsupervised KMeans image classification Application");
    SetDocLongDescription("Performs Unsupervised KMeans image classification.");
    SetDocLimitations("None");
    SetDocAuthors("OTB-Team");
    SetDocSeeAlso(" ");
81
  
82 83
    AddDocTag(Tags::Segmentation);
    AddDocTag(Tags::Learning);
Jonathan Guinet's avatar
Jonathan Guinet committed
84
  }
85

Jonathan Guinet's avatar
Jonathan Guinet committed
86 87 88
  virtual ~KMeansClassification()
  {
  }
89

Jonathan Guinet's avatar
Jonathan Guinet committed
90 91
  void DoCreateParameters()
  {
92

Jonathan Guinet's avatar
Jonathan Guinet committed
93
    AddParameter(ParameterType_InputImage, "in", "Input Image");
94
    SetParameterDescription("in","Input image filename.");
Jonathan Guinet's avatar
Jonathan Guinet committed
95
    AddParameter(ParameterType_OutputImage, "out", "Output Image");
96
    SetParameterDescription("out","Output image filename.");
97 98 99
    AddParameter(ParameterType_RAM, "ram", "Available RAM");
    SetDefaultParameterInt("ram", 256);
    MandatoryOff("ram");
Jonathan Guinet's avatar
Jonathan Guinet committed
100
    AddParameter(ParameterType_InputImage, "vm", "Validity Mask");
101 102 103
    SetParameterDescription("vm","Validity mask. Only non-zero pixels will be used to estimate KMeans modes.");
    AddParameter(ParameterType_Int, "ts", "Training set size");
    SetParameterDescription("ts", "Size of the training set.");
104
    SetDefaultParameterInt("ts", 100);
105 106
    AddParameter(ParameterType_Float, "tp", "Training set sample selection probability");
    SetParameterDescription("tp", "Probability for a sample to be selected in the training set.");
107
    SetDefaultParameterFloat("tp", 0.5);
Jonathan Guinet's avatar
Jonathan Guinet committed
108
    AddParameter(ParameterType_Int, "nc", "Number of classes");
109
    SetParameterDescription("nc","number of modes, which will be used to generate class membership.");
110
    SetDefaultParameterInt("nc", 3);
111 112
    AddParameter(ParameterType_Float, "cp", "Initial class centroid probability");
    SetParameterDescription("cp", "Probability for a pixel to be selected as an initial class centroid");
113
    SetDefaultParameterFloat("cp", 0.8);
Jonathan Guinet's avatar
Jonathan Guinet committed
114
    AddParameter(ParameterType_Int, "sl", "Number of lines for each streaming block");
115
    SetParameterDescription("sl","input image will be divided into sl lines.");
116
    SetDefaultParameterInt("sl", 1000);
117

118
    // Doc example parameter settings
119 120 121 122 123 124 125 126
    SetDocExampleParameterValue("in", "poupees_sub.png");
    SetDocExampleParameterValue("vm", "mask_KMeans.png");
    SetDocExampleParameterValue("ts", "100");
    SetDocExampleParameterValue("tp", "0.5");
    SetDocExampleParameterValue("nc", "5");
    SetDocExampleParameterValue("cp", "0.9");
    SetDocExampleParameterValue("sl", "100");
    SetDocExampleParameterValue("out", "ClassificationFilterOuptut.tif");
Jonathan Guinet's avatar
Jonathan Guinet committed
127
  }
128

Jonathan Guinet's avatar
Jonathan Guinet committed
129
  void DoUpdateParameters()
130
  {
Jonathan Guinet's avatar
Jonathan Guinet committed
131
    // Nothing to do here : all parameters are independent
132 133
  }

Jonathan Guinet's avatar
Jonathan Guinet committed
134
  void DoExecute()
135
  {
Jonathan Guinet's avatar
Jonathan Guinet committed
136 137 138 139 140 141
    GetLogger()->Debug("Entering DoExecute\n");

    // initiating random number generation
    itk::Statistics::MersenneTwisterRandomVariateGenerator::Pointer
        randomGen = itk::Statistics::MersenneTwisterRandomVariateGenerator::New();
    m_InImage = GetParameterImage("in");
142 143 144
    std::cout<<"mask in progress"<<std::endl;

    UInt8ImageType::Pointer maskImage = GetParameterUInt8Image("vm");
Jonathan Guinet's avatar
Jonathan Guinet committed
145

146
    std::cout<<"mask in progress done"<<std::endl;
Jonathan Guinet's avatar
Jonathan Guinet committed
147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167
    std::ostringstream message("");

    const unsigned int nbsamples = GetParameterInt("ts");
    const double trainingProb = GetParameterFloat("tp");
    const double initProb = GetParameterFloat("cp");
    const unsigned int nb_classes = GetParameterInt("nc");
    const unsigned int nbLinesForStreaming = GetParameterInt("sl");

    /*******************************************/
    /*           Sampling data                 */
    /*******************************************/
    GetLogger()->Info("-- SAMPLING DATA --");

    // Update input images information
    m_InImage->UpdateOutputInformation();
    maskImage->UpdateOutputInformation();

    if (m_InImage->GetLargestPossibleRegion() != maskImage->GetLargestPossibleRegion())
      {
      GetLogger()->Error("Mask image and input image have different sizes.");
      }
168

Jonathan Guinet's avatar
Jonathan Guinet committed
169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
    RegionType largestRegion = m_InImage->GetLargestPossibleRegion();

    // Setting up local streaming capabilities
    SplitterType::Pointer splitter = SplitterType::New();
    unsigned int
        numberOfStreamDivisions = StreamingTraitsType::CalculateNumberOfStreamDivisions(
                                                                                        m_InImage,
                                                                                        largestRegion,
                                                                                        splitter,
                                                                                        otb::SET_BUFFER_NUMBER_OF_LINES,
                                                                                        0, 0, nbLinesForStreaming);

    message.clear();
    message << "The images will be streamed into " << numberOfStreamDivisions << " parts.";
    GetLogger()->Info(message.str());

    // Training sample lists
    ListSampleType::Pointer sampleList = ListSampleType::New();
    EstimatorType::ParametersType initialMeans(108 * nb_classes);
    initialMeans.Fill(0);
    unsigned int init_means_index = 0;

    // Sample dimension and max dimension
    unsigned int maxDimension = SampleType::Dimension;
    unsigned int sampleSize = std::min(m_InImage->GetNumberOfComponentsPerPixel(), maxDimension);
    unsigned int totalSamples = 0;

    message.clear();
    message << "Sample max possible dimension: " << maxDimension << std::endl;
    GetLogger()->Info(message.str());
    message.clear();
    message << "The following sample size will be used: " << sampleSize << std::endl;
    GetLogger()->Info(message.str());
    // local streaming variables
    unsigned int piece = 0;
    RegionType streamingRegion;

    while ((totalSamples < nbsamples) && (init_means_index < 108 * nb_classes))
      {
      double random = randomGen->GetVariateWithClosedRange();
      piece = static_cast<unsigned int> (random * numberOfStreamDivisions);
210

Jonathan Guinet's avatar
Jonathan Guinet committed
211
      streamingRegion = splitter->GetSplit(piece, numberOfStreamDivisions, largestRegion);
212

Jonathan Guinet's avatar
Jonathan Guinet committed
213 214 215
      message.clear();
      message << "Processing region: " << streamingRegion << std::endl;
      GetLogger()->Info(message.str());
216

Jonathan Guinet's avatar
Jonathan Guinet committed
217 218 219
      m_InImage->SetRequestedRegion(streamingRegion);
      m_InImage->PropagateRequestedRegion();
      m_InImage->UpdateOutputData();
220

Jonathan Guinet's avatar
Jonathan Guinet committed
221 222 223
      maskImage->SetRequestedRegion(streamingRegion);
      maskImage->PropagateRequestedRegion();
      maskImage->UpdateOutputData();
224

Jonathan Guinet's avatar
Jonathan Guinet committed
225 226
      IteratorType it(m_InImage, streamingRegion);
      LabeledIteratorType m_MaskIt(maskImage, streamingRegion);
227

Jonathan Guinet's avatar
Jonathan Guinet committed
228 229
      it.GoToBegin();
      m_MaskIt.GoToBegin();
230

Jonathan Guinet's avatar
Jonathan Guinet committed
231
      unsigned int localNbSamples = 0;
232

Jonathan Guinet's avatar
Jonathan Guinet committed
233 234 235 236 237 238
      // Loop on the image
      while (!it.IsAtEnd() && !m_MaskIt.IsAtEnd() && (totalSamples < nbsamples)
          && (init_means_index < (108 * nb_classes)))
        {
        // If the current pixel is labeled
        if (m_MaskIt.Get() > 0)
239
          {
Jonathan Guinet's avatar
Jonathan Guinet committed
240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263
          if ((rand() < trainingProb * RAND_MAX))
            {
            SampleType newSample;

            // build the sample
            newSample.Fill(0);
            for (unsigned int i = 0; i < sampleSize; ++i)
              {
              newSample[i] = it.Get()[i];
              }
            // Update the the sample lists
            sampleList->PushBack(newSample);
            ++totalSamples;
            ++localNbSamples;
            }
          else
            if ((init_means_index < 108 * nb_classes) && (rand() < initProb * RAND_MAX))
              {
              for (unsigned int i = 0; i < sampleSize; ++i)
                {
                initialMeans[init_means_index + i] = it.Get()[i];
                }
              init_means_index += 108;
              }
264
          }
Jonathan Guinet's avatar
Jonathan Guinet committed
265 266
        ++it;
        ++m_MaskIt;
267
        }
Jonathan Guinet's avatar
Jonathan Guinet committed
268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294

      message.clear();
      message << localNbSamples << " samples added to the training set." << std::endl;
      GetLogger()->Info(message.str());

      }

    message.clear();
    message << "The final training set contains " << totalSamples << " samples." << std::endl;
    GetLogger()->Info(message.str());

    message.clear();
    message << "Data sampling completed." << std::endl;
    GetLogger()->Info(message.str());

    /*******************************************/
    /*           Learning                      */
    /*******************************************/
    message.clear();
    message << "-- LEARNING --" << std::endl;
    message << "Initial centroids are: " << std::endl;
    GetLogger()->Info(message.str());
    message.clear();
    for (unsigned int i = 0; i < nb_classes; ++i)
      {
      message << "Class " << i << ": ";
      for (unsigned int j = 0; j < sampleSize; ++j)
295
        {
Jonathan Guinet's avatar
Jonathan Guinet committed
296
        message << initialMeans[i * 108 + j] << "\t";
297
        }
Jonathan Guinet's avatar
Jonathan Guinet committed
298
      message << std::endl;
299
      }
Jonathan Guinet's avatar
Jonathan Guinet committed
300
    message << std::endl;
301

Jonathan Guinet's avatar
Jonathan Guinet committed
302 303 304 305
    message.clear();
    message << "Starting optimization." << std::endl;
    message << std::endl;
    GetLogger()->Info(message.str());
306

Jonathan Guinet's avatar
Jonathan Guinet committed
307
    EstimatorType::Pointer estimator = EstimatorType::New();
308

Jonathan Guinet's avatar
Jonathan Guinet committed
309 310 311 312
    TreeGeneratorType::Pointer treeGenerator = TreeGeneratorType::New();
    treeGenerator->SetSample(sampleList);
    treeGenerator->SetBucketSize(100);
    treeGenerator->Update();
313

Jonathan Guinet's avatar
Jonathan Guinet committed
314 315 316 317 318
    estimator->SetParameters(initialMeans);
    estimator->SetKdTree(treeGenerator->GetOutput());
    estimator->SetMaximumIteration(100000000);
    estimator->SetCentroidPositionChangesThreshold(0.001);
    estimator->StartOptimization();
319

Jonathan Guinet's avatar
Jonathan Guinet committed
320 321 322 323 324
    EstimatorType::ParametersType estimatedMeans = estimator->GetParameters();
    message.clear();
    message << "Optimization completed." << std::endl;
    message << std::endl;
    message << "Estimated centroids are: " << std::endl;
325

Jonathan Guinet's avatar
Jonathan Guinet committed
326 327 328 329 330 331 332 333 334
    for (unsigned int i = 0; i < nb_classes; ++i)
      {
      message << "Class " << i << ": ";
      for (unsigned int j = 0; j < sampleSize; ++j)
        {
        message << estimatedMeans[i * 108 + j] << "\t";
        }
      message << std::endl;
      }
335

Jonathan Guinet's avatar
Jonathan Guinet committed
336 337 338 339
    message << std::endl;
    message << "Learning completed." << std::endl;
    message << std::endl;
    GetLogger()->Info(message.str());
340

Jonathan Guinet's avatar
Jonathan Guinet committed
341 342 343 344 345 346 347
    /*******************************************/
    /*           Classification                */
    /*******************************************/
    message.clear();
    message << "-- CLASSIFICATION --" << std::endl;
    message << std::endl;
    GetLogger()->Info(message.str());
348

Jonathan Guinet's avatar
Jonathan Guinet committed
349
    m_Classifier = ClassificationFilterType::New();
350

Jonathan Guinet's avatar
Jonathan Guinet committed
351 352
    m_Classifier->SetInput(m_InImage);
    m_Classifier->SetInputMask(maskImage);
353

Jonathan Guinet's avatar
Jonathan Guinet committed
354
    m_Classifier->SetCentroids(estimator->GetParameters());
355

356
    SetParameterOutputImage<UInt8ImageType> ("out", m_Classifier->GetOutput());
357 358 359

  }

Jonathan Guinet's avatar
Jonathan Guinet committed
360 361 362
  ClassificationFilterType::Pointer m_Classifier;
  FloatVectorImageType::Pointer m_InImage;

363

Jonathan Guinet's avatar
Jonathan Guinet committed
364
};
365

Jonathan Guinet's avatar
Jonathan Guinet committed
366 367
}
}
368

Jonathan Guinet's avatar
Jonathan Guinet committed
369
// Hand the application class to the OTB_APPLICATION_EXPORT macro
// (presumably registers it with the application factory included at the top
// of this file - confirm against otbWrapperApplicationFactory.h).
OTB_APPLICATION_EXPORT(otb::Wrapper::KMeansClassification)
370 371