/*
 * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
 *
 * This file is part of Orfeo Toolbox
 *
 *     https://www.orfeo-toolbox.org/
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "otbWrapperCompositeApplication.h"
#include "otbWrapperApplicationFactory.h"

#include "otbOGRDataToSamplePositionFilter.h"

#include <algorithm>
#include <fstream>
#include <iomanip>
#include <sstream>
#include <string>
#include <vector>

namespace otb
{
namespace Wrapper
29 30
{

31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61
class KMeansApplicationBase : public CompositeApplication
{
public:
  /** Standard class typedefs. */
  typedef KMeansApplicationBase Self;
  typedef CompositeApplication Superclass;
  typedef itk::SmartPointer<Self> Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;

  /** Standard macro */
  itkTypeMacro( KMeansApplicationBase, Superclass )

protected:
  void InitKMParams()
  {
    AddApplication("ImageEnvelope", "imgenvelop", "mean shift smoothing");
    AddApplication("PolygonClassStatistics", "polystats", "Polygon Class Statistics");
    AddApplication("SampleSelection", "select", "Sample selection");
    AddApplication("SampleExtraction", "extraction", "Sample extraction");

    AddApplication("TrainVectorClassifier", "training", "Model training");
    AddApplication("ComputeImagesStatistics", "imgstats", "Compute Images second order statistics");
    AddApplication("ImageClassifier", "classif", "Performs a classification of the input image");

    ShareParameter("in", "imgenvelop.in");
    ShareParameter("out", "classif.out");

    InitKMSampling();
    InitKMClassification();

    // init at the end cleanup
62
    AddParameter( ParameterType_Bool, "cleanup", "Temporary files cleaning" );
63 64
    SetParameterDescription( "cleanup",
                           "If activated, the application will try to clean all temporary files it created" );
65
    SetParameterInt("cleanup", 1);
66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
  }

  void InitKMSampling()
  {
    AddParameter(ParameterType_Int, "nc", "Number of classes");
    SetParameterDescription("nc", "Number of modes, which will be used to generate class membership.");
    SetDefaultParameterInt("nc", 5);

    AddParameter(ParameterType_Int, "ts", "Training set size");
    SetParameterDescription("ts", "Size of the training set (in pixels).");
    SetDefaultParameterInt("ts", 100);
    MandatoryOff("ts");

    AddParameter(ParameterType_Int, "maxit", "Maximum number of iterations");
    SetParameterDescription("maxit", "Maximum number of iterations for the learning step.");
    SetDefaultParameterInt("maxit", 1000);
    MandatoryOff("maxit");

    AddParameter(ParameterType_OutputFilename, "outmeans", "Centroid filename");
    SetParameterDescription("outmeans", "Output text file containing centroid positions");
    MandatoryOff("outmeans");

    ShareKMSamplingParameters();
    ConnectKMSamplingParams();
  }

  void InitKMClassification()
  {
    ShareKMClassificationParams();
    ConnectKMClassificationParams();
  }

  /**
   * Expose the sampling-related parameters of the internal applications at the
   * level of this application: "ram", "sampler" and the validity mask "vm".
   */
  void ShareKMSamplingParameters()
  {
    ShareParameter("ram", "polystats.ram");
    ShareParameter("sampler", "select.sampler");

    const std::string maskName = "Validity Mask";
    const std::string maskDescription =
      "Validity mask, only non-zero pixels "
      "will be used to estimate KMeans modes.";
    ShareParameter("vm", "polystats.mask", maskName, maskDescription);
  }

  /**
   * Expose the "nodatalabel" parameter of the internal classifier, with a
   * name and description specific to this application.
   */
  void ShareKMClassificationParams()
  {
    ShareParameter("nodatalabel", "classif.nodatalabel", "Label mask value",
      "By default, hidden pixels will have the assigned label 0 in the output image. "
      "It's possible to define the label mask by another value, "
      "but be careful to not take a label from another class. "
      "This application initialize the labels from 0 to N-1, "
      "N is the number of class (defined by 'nc' parameter).");
  }

  /**
   * Connect the parameters of the sampling applications (polygon statistics,
   * sample selection, sample extraction) to each other.
   */
  void ConnectKMSamplingParams()
  {
    // Each row holds the two keys handed to Connect(), in call order.
    const char* const links[][2] = {
      {"polystats.in", "imgenvelop.in"},

      {"select.in", "polystats.in"},
      {"select.vec", "polystats.vec"},
      {"select.ram", "polystats.ram"},

      {"extraction.in", "select.in"},
      {"extraction.field", "select.field"},
      {"extraction.vec", "select.out"},
      {"extraction.ram", "polystats.ram"}
    };

    for (const auto& link : links)
    {
      Connect(link[0], link[1]);
    }
  }

  /**
   * Connect the parameters of the training and classification applications to
   * the applications that provide their values.
   */
  void ConnectKMClassificationParams()
  {
    // Each row holds the two keys handed to Connect(), in call order.
    const char* const links[][2] = {
      {"training.cfield", "extraction.field"},
      {"training.io.stats", "imgstats.out"},

      {"classif.in", "imgenvelop.in"},
      {"classif.model", "training.io.out"},
      {"classif.ram", "polystats.ram"},
      {"classif.imstat", "imgstats.out"}
    };

    for (const auto& link : links)
    {
      Connect(link[0], link[1]);
    }
  }

  /**
   * Log that a mask is in use and connect the mask parameters of the sample
   * selection and classification applications.
   */
  void ConnectKMClassificationMask()
  {
    otbAppLogINFO("Using input mask ...");

    const std::string selectionMaskKey("select.mask");
    Connect(selectionMaskKey, "polystats.mask");
    Connect("classif.mask", selectionMaskKey);
  }

  void ComputeImageEnvelope(const std::string &vectorFileName)
  {
150
    GetInternalApplication("imgenvelop")->SetParameterString("out", vectorFileName);
151 152 153 154 155 156 157 158 159 160 161
    GetInternalApplication("imgenvelop")->ExecuteAndWriteOutput();
  }

  void ComputeAddField(const std::string &vectorFileName,
                       const std::string &fieldName)
  {
    otbAppLogINFO("add field in the layer ...");
    otb::ogr::DataSource::Pointer ogrDS;
    ogrDS = otb::ogr::DataSource::New(vectorFileName, otb::ogr::DataSource::Modes::Update_LayerUpdate);
    otb::ogr::Layer layer = ogrDS->GetLayer(0);

162 163 164 165 166 167 168 169 170
    OGRFieldDefn classField(fieldName.c_str(), OFTInteger);
    classField.SetWidth(classField.GetWidth());
    classField.SetPrecision(classField.GetPrecision());
    ogr::FieldDefn classFieldDefn(classField);
    layer.CreateField(classFieldDefn);

    otb::ogr::Layer::const_iterator it = layer.cbegin();
    otb::ogr::Layer::const_iterator itEnd = layer.cend();
    for( ; it!=itEnd ; ++it)
171
    {
172 173 174 175 176
      ogr::Feature dstFeature(layer.GetLayerDefn());
      dstFeature.SetFrom( *it, TRUE);
      dstFeature.SetFID(it->GetFID());
      dstFeature[fieldName].SetValue<int>(0);
      layer.SetFeature(dstFeature);
177 178 179 180 181 182 183 184 185 186 187 188
    }
    const OGRErr err = layer.ogr().CommitTransaction();
    if (err != OGRERR_NONE)
      itkExceptionMacro(<< "Unable to commit transaction for OGR layer " << layer.ogr().GetName() << ".");
    ogrDS->SyncToDisk();
  }

  void ComputePolygonStatistics(const std::string &statisticsFileName,
                                const std::string &fieldName)
  {
    std::vector<std::string> fieldList = {fieldName};

189 190
    GetInternalApplication("polystats")->SetParameterStringList("field", fieldList);
    GetInternalApplication("polystats")->SetParameterString("out", statisticsFileName);
191 192 193 194 195 196 197 198 199 200

    ExecuteInternal("polystats");
  }

  void SelectAndExtractSamples(const std::string &statisticsFileName,
                               const std::string &fieldName,
                               const std::string &sampleFileName,
                               int NBSamples)
  {
    /* SampleSelection */
201
    GetInternalApplication("select")->SetParameterString("out", sampleFileName);
202 203

    UpdateInternalParameters("select");
204 205
    GetInternalApplication("select")->SetParameterString("instats", statisticsFileName);
    GetInternalApplication("select")->SetParameterString("field", fieldName);
206

207 208
    GetInternalApplication("select")->SetParameterString("strategy", "constant");
    GetInternalApplication("select")->SetParameterInt("strategy.constant.nb", NBSamples);
209 210

    if( IsParameterEnabled("rand"))
211
      GetInternalApplication("select")->SetParameterInt("rand", GetParameterInt("rand"));
212 213 214 215 216 217 218

    // select sample positions
    ExecuteInternal("select");

    /* SampleExtraction */
    UpdateInternalParameters("extraction");

219 220
    GetInternalApplication("extraction")->SetParameterString("outfield", "prefix");
    GetInternalApplication("extraction")->SetParameterString("outfield.prefix.name", "value_");
221 222 223 224 225 226 227 228 229 230

    // extract sample descriptors
    GetInternalApplication("extraction")->ExecuteAndWriteOutput();
  }

  /**
   * Configure and execute the internal "training" application with the
   * "sharkkm" classifier.
   *
   * \param image               image whose number of components per pixel
   *                            gives the number of feature fields
   * \param sampleTrainFileName only entry of the "io.vd" string list
   * \param modelFileName       value given to the "io.out" parameter
   */
  void TrainKMModel(FloatVectorImageType *image,
                    const std::string &sampleTrainFileName,
                    const std::string &modelFileName)
  {
    std::vector<std::string> extractOutputList = {sampleTrainFileName};
    GetInternalApplication("training")->SetParameterStringList("io.vd", extractOutputList);
    UpdateInternalParameters("training");

    // set field names: one "<prefix><band index>" entry per image band,
    // using the prefix configured on the extraction application
    const std::string selectPrefix =
      GetInternalApplication("extraction")->GetParameterString("outfield.prefix.name");
    const unsigned int nbBands = image->GetNumberOfComponentsPerPixel();
    std::vector<std::string> selectedNames;
    selectedNames.reserve(nbBands);
    for( unsigned int i = 0; i < nbBands; i++ )
    {
      selectedNames.push_back( selectPrefix + std::to_string(i) );
    }
    GetInternalApplication("training")->SetParameterStringList("feat", selectedNames);

    GetInternalApplication("training")->SetParameterString("classifier", "sharkkm");
    GetInternalApplication("training")->SetParameterInt("classifier.sharkkm.maxiter",
                                                        GetParameterInt("maxit"));
    GetInternalApplication("training")->SetParameterInt("classifier.sharkkm.k",
                                                        GetParameterInt("nc"));

    // Forward the random seed only when the user enabled it.
    if( IsParameterEnabled("rand"))
    {
      GetInternalApplication("training")->SetParameterInt("rand", GetParameterInt("rand"));
    }
    GetInternalApplication("training")->GetParameterByKey("v")->SetActive(false);

    GetInternalApplication("training")->SetParameterString("io.out", modelFileName);

    ExecuteInternal( "training" );
    otbAppLogINFO("output model : " << GetInternalApplication("training")->GetParameterString("io.out"));
  }

262
  void ComputeImageStatistics( ImageBaseType * img,
263 264
                                               const std::string &imagesStatsFileName)
  {
265 266
    // std::vector<std::string> imageFileNameList = {imageFileName};
    GetInternalApplication("imgstats")->SetParameterImageBase("il", img);
267
    GetInternalApplication("imgstats")->SetParameterString("out", imagesStatsFileName);
268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293

    ExecuteInternal( "imgstats" );
    otbAppLogINFO("image statistics file : " << GetInternalApplication("imgstats")->GetParameterString("out"));
  }


  void KMeansClassif()
  {
    ExecuteInternal( "classif" );
  }

  /**
   * Write the centroids found in the model file to the "outmeans" text file,
   * one class per line and one column per band. Does nothing when the
   * "outmeans" parameter is not enabled.
   *
   * \param image         image whose number of components per pixel gives the
   *                      number of values per centroid
   * \param modelFileName model file to read
   * \param nbClasses     number of centroids to write
   *
   * Throws (itkExceptionMacro) when the model file cannot be opened, when it
   * holds fewer than nbClasses * nbBands centroid values, or when the output
   * file cannot be opened.
   */
  void CreateOutMeansFile(FloatVectorImageType *image,
                          const std::string &modelFileName,
                          unsigned int nbClasses)
  {
    if (!IsParameterEnabled("outmeans"))
    {
      return;
    }

    typedef std::vector<std::string> TokenListType;

    const unsigned int nbBands = image->GetNumberOfComponentsPerPixel();
    const TokenListType::size_type nbElements =
      static_cast<TokenListType::size_type>(nbClasses) * nbBands;

    // get the line in model file that contains the centroids positions
    std::ifstream infile(modelFileName);
    if(!infile)
    {
      itkExceptionMacro(<< "File : " << modelFileName << " couldn't be opened");
    }

    // get the line with the centroids (starts with "2 ")
    std::string line, centroidLine;
    while(std::getline(infile,line))
    {
      if (line.size() > 2 && line[0] == '2' && line[1] == ' ')
      {
        centroidLine = line;
        break;
      }
    }

    TokenListType centroidElm;
    boost::split(centroidElm,centroidLine,boost::is_any_of(" "));

    // Without this check a missing or truncated centroid line would yield a
    // negative erase offset below, which is undefined behaviour.
    if (centroidElm.size() < nbElements)
    {
      itkExceptionMacro(<< "File : " << modelFileName << " doesn't contain the "
                        << nbElements << " expected centroid values");
    }

    // remove the first elements, not the centroids positions: only the last
    // nbElements words are kept
    const TokenListType::difference_type beginCentroid =
      static_cast<TokenListType::difference_type>(centroidElm.size() - nbElements);
    centroidElm.erase(centroidElm.begin(), centroidElm.begin()+beginCentroid);

    // write in the output file
    const std::string outMeansFileName = GetParameterString("outmeans");
    std::ofstream outfile(outMeansFileName);
    if(!outfile)
    {
      itkExceptionMacro(<< "File : " << outMeansFileName << " couldn't be opened");
    }

    for (unsigned int i = 0; i < nbClasses; i++)
    {
      for (unsigned int j = 0; j < nbBands; j++)
      {
        outfile << std::setw(8) << centroidElm[i * nbBands + j] << " ";
      }
      outfile << std::endl;
    }
  }

  class KMeansFileNamesHandler
    {
    public :
      KMeansFileNamesHandler(const std::string &outPath)
      {
        tmpVectorFile = outPath + "_imgEnvelope.shp";
        polyStatOutput = outPath + "_polyStats.xml";
        sampleOutput = outPath + "_sampleSelect.shp";
        modelFile = outPath + "_model.txt";
        imgStatOutput = outPath + "_imgstats.xml";
      }

      void clear()
      {
        RemoveFile(tmpVectorFile);
        RemoveFile(polyStatOutput);
        RemoveFile(sampleOutput);
        RemoveFile(modelFile);
        RemoveFile(imgStatOutput);
      }

      std::string tmpVectorFile;
      std::string polyStatOutput;
      std::string sampleOutput;
      std::string modelFile;
      std::string imgStatOutput;

    private:
      bool RemoveFile(const std::string &filePath)
      {
        bool res = true;
359
        if( itksys::SystemTools::FileExists( filePath ) )
360 361 362 363 364 365 366 367 368 369 370
          {
          size_t posExt = filePath.rfind( '.' );
          if( posExt != std::string::npos && filePath.compare( posExt, std::string::npos, ".shp" ) == 0 )
            {
            std::string shxPath = filePath.substr( 0, posExt ) + std::string( ".shx" );
            std::string dbfPath = filePath.substr( 0, posExt ) + std::string( ".dbf" );
            std::string prjPath = filePath.substr( 0, posExt ) + std::string( ".prj" );
            RemoveFile( shxPath );
            RemoveFile( dbfPath );
            RemoveFile( prjPath );
            }
371
          res = itksys::SystemTools::RemoveFile( filePath );
372 373 374 375 376 377 378 379 380 381 382 383 384 385
          if( !res )
            {
            //otbAppLogINFO( <<"Unable to remove file  "<<filePath );
            }
          }
        return res;
      }

    };

};


class KMeansClassification: public KMeansApplicationBase
386 387 388 389
{
public:
  /** Standard class typedefs. */
  typedef KMeansClassification Self;
390
  typedef KMeansApplicationBase Superclass;
391 392
  typedef itk::SmartPointer<Self> Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;
393

394 395
  /** Standard macro */
  itkNewMacro(Self);

  // Name the class explicitly, as KMeansApplicationBase does: the first
  // argument is turned into the run-time class name, so passing the Self
  // typedef would report "Self".
  itkTypeMacro(KMeansClassification, Superclass);
398

399
private:
400
  void DoInit() override
401 402
  {
    SetName("KMeansClassification");
403 404
    SetDescription("Unsupervised KMeans image classification");

405
    SetDocName("Unsupervised KMeans image classification");
406
    SetDocLongDescription("Performs unsupervised KMeans image classification."
407 408 409
      "KMeansClassification is a composite application, "
      "using an existing training and classification application."
      "The SharkKMeans model is used.\n"
410 411
      "KMeansClassification application is only available if OTB is compiled with Shark support"
      "(CMake option OTB_USE_SHARK=ON)\n"
412 413 414 415 416
      "The steps of this composite application :\n"
        "1) ImageEnveloppe : create a shapefile (1 polygon),\n"
        "2) PolygonClassStatistics : compute the statistics,\n"
        "3) SampleSelection : select the samples by constant strategy in the shapefile "
            "(1000000 samples max),\n"
417
        "4) SamplesExtraction : extract the samples descriptors (update of SampleSelection output file),\n"
418 419
        "5) ComputeImagesStatistics : compute images second order statistics,\n"
        "6) TrainVectorClassifier : train the SharkKMeans model,\n"
420 421 422 423 424 425
        "7) ImageClassifier : performs the classification of the input image "
            "according to a model file.\n\n"
        "It's possible to choice random/periodic modes of the SampleSelection application.\n"
        "If you want keep the temporary files (sample selected, model file, ...), "
        "initialize cleanup parameter.\n"
        "For more information on shark KMeans algorithm [1].");
426

427
    SetDocLimitations("The application doesn't support NaN in the input image");
428
    SetDocAuthors("OTB-Team");
429 430 431
    SetDocSeeAlso("ImageEnveloppe PolygonClassStatistics SampleSelection SamplesExtraction "
      "PolygonClassStatistics TrainVectorClassifier ImageClassifier\n"
      "[1] http://image.diku.dk/shark/sphinx_pages/build/html/rest_sources/tutorials/algorithms/kmeans.html");
432

433
    AddDocTag(Tags::Learning);
434 435 436 437 438
    AddDocTag(Tags::Segmentation);

    // Perform initialization
    ClearApplications();

439
    // initialisation parameters and synchronizes parameters
440
    Superclass::InitKMParams();
441 442 443

    AddRANDParameter();

444
    // Doc example parameter settings
445
    SetDocExampleParameterValue("in", "QB_1_ortho.tif");
446
    SetDocExampleParameterValue("ts", "1000");
447
    SetDocExampleParameterValue("nc", "5");
448
    SetDocExampleParameterValue("maxit", "1000");
449
    SetDocExampleParameterValue("out", "ClassificationFilterOutput.tif uint8");
450

451
    SetOfficialDocLink();
452
  }
453

454
  void DoUpdateParameters() override
455 456 457
  {
  }

458
  void DoExecute() override
459
  {
460
    if (IsParameterEnabled("vm") && HasValue("vm")) Superclass::ConnectKMClassificationMask();
461

462
    KMeansFileNamesHandler fileNames(GetParameterString("out"));
463

464
    const std::string fieldName = "field";
465

466
    // Create an image envelope
467
    Superclass::ComputeImageEnvelope(fileNames.tmpVectorFile);
468
    // Add a new field at the ImageEnvelope output file
469
    Superclass::ComputeAddField(fileNames.tmpVectorFile, fieldName);
470

471
    // Compute PolygonStatistics app
472
    UpdateKMPolygonClassStatisticsParameters(fileNames.tmpVectorFile);
473
    Superclass::ComputePolygonStatistics(fileNames.polyStatOutput, fieldName);
474

475 476 477 478 479
    // Compute number of sample max for KMeans
    const int theoricNBSamplesForKMeans = GetParameterInt("ts");
    const int upperThresholdNBSamplesForKMeans = 1000 * 1000;
    const int actualNBSamplesForKMeans = std::min(theoricNBSamplesForKMeans,
                                                  upperThresholdNBSamplesForKMeans);
480 481
    otbAppLogINFO(<< actualNBSamplesForKMeans << " is the maximum sample size that will be used." \
                  << std::endl);
482

483
    // Compute SampleSelection and SampleExtraction app
484 485 486
    Superclass::SelectAndExtractSamples(fileNames.polyStatOutput, fieldName,
                                        fileNames.sampleOutput,
                                        actualNBSamplesForKMeans);
487

488
    // Compute Images second order statistics
489
    Superclass::ComputeImageStatistics(GetParameterImageBase("in"), fileNames.imgStatOutput);
490

491
    // Compute a train model with TrainVectorClassifier app
492 493
    Superclass::TrainKMModel(GetParameterImage("in"), fileNames.sampleOutput,
                             fileNames.modelFile);
494

495
    // Compute a classification of the input image according to a model file
496
    Superclass::KMeansClassif();
497

498
    // Create the output text file containing centroids positions
499
    Superclass::CreateOutMeansFile(GetParameterImage("in"), fileNames.modelFile, GetParameterInt("nc"));
500 501

    // Remove all tempory files
502
    if( GetParameterInt( "cleanup" ) )
503
      {
504
      otbAppLogINFO( <<"Final clean-up ..." );
505
      fileNames.clear();
506
      }
507 508
  }

509 510
  void UpdateKMPolygonClassStatisticsParameters(const std::string &vectorFileName)
  {
511
    GetInternalApplication( "polystats" )->SetParameterString( "vec", vectorFileName);
512 513
    UpdateInternalParameters( "polystats" );
  }
514

515
};
516

517 518
}
}
519

520
OTB_APPLICATION_EXPORT(otb::Wrapper::KMeansClassification)
521