/*=========================================================================
 Program:   ORFEO Toolbox
 Language:  C++
 Date:      $Date$
 Version:   $Revision$


 Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
 See OTBCopyright.txt for details.


 This software is distributed WITHOUT ANY WARRANTY; without even
 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
 PURPOSE.  See the above copyright notices for more information.

 =========================================================================*/
#include "otbLearningApplicationBase.h"
#include "otbWrapperApplicationFactory.h"

#include "otbListSampleGenerator.h"

#include "otbImageToEnvelopeVectorDataFilter.h"
#include "itkPreOrderTreeIterator.h"

// Statistic XML Reader
#include "otbStatisticsXMLFileReader.h"

#include "itkTimeProbe.h"
#include "otbStandardFilterWatcher.h"

// Normalize the samples
#include "otbShiftScaleSampleListFilter.h"

// List sample concatenation
#include "otbConcatenateSampleListFilter.h"

// Balancing ListSample
#include "otbListSampleToBalancedListSampleFilter.h"

#include "itkMersenneTwisterRandomVariateGenerator.h"

// Elevation handler
#include "otbWrapperElevationParametersHandler.h"

namespace otb
{
namespace Wrapper
{

class TrainRegression: public LearningApplicationBase<float,float>
{
public:
  /** Standard class typedefs. */
  typedef TrainRegression Self;
  typedef LearningApplicationBase<float,float> Superclass;
  typedef itk::SmartPointer<Self> Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;

  /** Standard macro */
  itkNewMacro(Self)

  itkTypeMacro(TrainRegression, otb::Wrapper::LearningApplicationBase)

  typedef Superclass::SampleType              SampleType;
  typedef Superclass::ListSampleType          ListSampleType;
  typedef Superclass::TargetSampleType        TargetSampleType;
  typedef Superclass::TargetListSampleType    TargetListSampleType;

  typedef Superclass::SampleImageType         SampleImageType;
  typedef SampleImageType::PixelType          PixelType;

  // SampleList manipulation
  typedef otb::ListSampleGenerator<SampleImageType, VectorDataType> ListSampleGeneratorType;

  typedef otb::Statistics::ConcatenateSampleListFilter<ListSampleType> ConcatenateListSampleFilterType;
  typedef otb::Statistics::ConcatenateSampleListFilter<TargetListSampleType> ConcatenateLabelListSampleFilterType;

  // Statistic XML file Reader
  typedef otb::StatisticsXMLFileReader<SampleType> StatisticsReader;

  // Enhance List Sample
  //typedef otb::Statistics::ListSampleToBalancedListSampleFilter<ListSampleType, LabelListSampleType> BalancingListSampleFilterType;
  typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType;

  typedef otb::ImageToEnvelopeVectorDataFilter<SampleImageType,VectorDataType> EnvelopeFilterType;

  typedef itk::PreOrderTreeIterator<VectorDataType::DataTreeType> TreeIteratorType;

  typedef itk::Statistics::MersenneTwisterRandomVariateGenerator RandomGeneratorType;

protected:
  TrainRegression()
    {
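    // Setting m_RegressionFlag (defined in LearningApplicationBase) to true puts the
    // base class in regression mode: targets are continuous values rather than class labels.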
    this->m_RegressionFlag = true;
    }

private:

void DoInit() ITK_OVERRIDE
{
  SetName("TrainRegression");
  SetDescription(
    "Train a classifier from multiple images to perform regression.");

  // Documentation
  SetDocName("Train a regression model");
  SetDocLongDescription(
    "This application trains a classifier from multiple input images or from a "
    "CSV file, in order to perform regression. Predictors are composed of the "
    "pixel values in each band, optionally centered and reduced using an XML "
    "statistics file produced by the ComputeImagesStatistics application.\n"
    "The output value for each predictor is assumed to be the last band "
    "(or the last column for CSV files). Training and validation predictor "
    "lists are built so that their sizes do not exceed the maximum bounds given "
    "by the user, and the split between them corresponds to the sample.vtr ratio "
    "parameter. Several classifier parameters can be set depending on the chosen "
    "classifier. In the validation step, the mean square error is computed.\n"
    "This application is based on LibSVM and on the OpenCV Machine Learning "
    "classifiers, and is compatible with OpenCV 2.3.1 and later.");
  SetDocLimitations("None");
  SetDocAuthors("OTB-Team");
  SetDocSeeAlso("OpenCV documentation for machine learning http://docs.opencv.org/modules/ml/doc/ml.html ");

  //Group IO
  AddParameter(ParameterType_Group, "io", "Input and output data");
  SetParameterDescription("io", "This group of parameters allows setting input and output data.");
  AddParameter(ParameterType_InputImageList, "io.il", "Input Image List");
  SetParameterDescription("io.il", "A list of input images. First (n-1) bands should contain the predictor. The last band should contain the output value to predict.");
  AddParameter(ParameterType_InputFilename, "io.csv", "Input CSV file");
  SetParameterDescription("io.csv","Input CSV file containing the predictors, and the output values in last column. Only used when no input image is given");
  MandatoryOff("io.csv");

  AddParameter(ParameterType_InputFilename, "io.imstat", "Input XML image statistics file");
  MandatoryOff("io.imstat");
  SetParameterDescription("io.imstat",
                          "Input XML file containing the mean and the standard deviation of the input images.");
  AddParameter(ParameterType_OutputFilename, "io.out", "Output regression model");
  SetParameterDescription("io.out", "Output file containing the model estimated (.txt format).");

  AddParameter(ParameterType_Float,"io.mse","Mean Square Error");
  SetParameterDescription("io.mse","Mean square error computed with the validation predictors");
  SetParameterRole("io.mse",Role_Output);
  DisableParameter("io.mse");

  //Group Sample list
  AddParameter(ParameterType_Group, "sample", "Training and validation samples parameters");
  SetParameterDescription("sample",
                          "This group of parameters allows you to set training and validation sample lists parameters.");

  AddParameter(ParameterType_Int, "sample.mt", "Maximum training predictors");
  //MandatoryOff("mt");
  SetDefaultParameterInt("sample.mt", 1000);
  SetParameterDescription("sample.mt", "Maximum number of training predictors (default = 1000) (no limit = -1).");

  AddParameter(ParameterType_Int, "sample.mv", "Maximum validation predictors");
  // MandatoryOff("mv");
  SetDefaultParameterInt("sample.mv", 1000);
  SetParameterDescription("sample.mv", "Maximum number of validation predictors (default = 1000) (no limit = -1).");

  AddParameter(ParameterType_Float, "sample.vtr", "Training and validation sample ratio");
  SetParameterDescription("sample.vtr",
                          "Ratio between training and validation samples (0.0 = all training, 1.0 = all validation) (default = 0.5).");
  SetParameterFloat("sample.vtr", 0.5);

  Superclass::DoInit();

  AddRANDParameter();

  // Doc example parameter settings
  SetDocExampleParameterValue("io.il", "training_dataset.tif");
  SetDocExampleParameterValue("io.out", "regression_model.txt");
  SetDocExampleParameterValue("io.imstat", "training_statistics.xml");
  SetDocExampleParameterValue("classifier", "libsvm");
}

void DoUpdateParameters() ITK_OVERRIDE
{
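  // The image list and the CSV file are alternative inputs: when a CSV file is
  // supplied and enabled, "io.il" is no longer mandatory, and vice versa.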
  if (HasValue("io.csv") && IsParameterEnabled("io.csv"))
    {
    MandatoryOff("io.il");
    }
  else
    {
    MandatoryOn("io.il");
    }
}

void ParseCSVPredictors(std::string path, ListSampleType* outputList)
{
  std::ifstream ifs;
  ifs.open(path.c_str());
  unsigned int nbCols = 0;
  char sep = '\t';
  std::istringstream iss;
  SampleType elem;
  std::string line;
  while(std::getline(ifs,line))
    {
    // filter current line
    while (!line.empty() && (line[0] == ' ' || line[0] == '\t'))
      {
      line.erase(line.begin());
      }
    while (!line.empty() && ( *(line.end()-1) == ' ' || *(line.end()-1) == '\t' || *(line.end()-1) == '\r'))
      {
      line.erase(line.end()-1);
      }

    // Avoid commented lines or too short ones
    if (!line.empty() && line[0] != '#')
      {
      std::vector<itksys::String> words = itksys::SystemTools::SplitString(line.c_str(),sep);
      if (nbCols == 0)
        {
        // detect separator and feature size
        if (words.size() < 2)
          {
          sep = ' ';
          words = itksys::SystemTools::SplitString(line.c_str(),sep);
          }
        if (words.size() < 2)
          {
          sep = ';';
          words = itksys::SystemTools::SplitString(line.c_str(),sep);
          }
        if (words.size() < 2)
          {
          sep = ',';
          words = itksys::SystemTools::SplitString(line.c_str(),sep);
          }
        if (words.size() < 2)
          {
          otbAppLogFATAL(<< "Can't parse CSV file : less than 2 columns or invalid separator (valid separators are tab, space, comma and semi-colon)");
          }
        nbCols = words.size();
        elem.SetSize(nbCols,false);
        outputList->SetMeasurementVectorSize(nbCols);
        }
      else if (words.size() != nbCols )
        {
        otbAppLogWARNING(<< "Skip CSV line, wrong number of columns : got "<<words.size() << ", expected "<<nbCols);
        continue;
        }
      elem.Fill(0.0);
      for (unsigned int i=0 ; i<nbCols ; ++i)
        {
        iss.clear(); // reset eof/fail flags left by the previous extraction
        iss.str(words[i]);
        iss >> elem[i];
        }
      outputList->PushBack(elem);
      }
    }
  ifs.close();
}

void DoExecute() ITK_OVERRIDE
{
  GetLogger()->Debug("Entering DoExecute\n");
  //Create training and validation for list samples and label list samples
  ConcatenateListSampleFilterType::Pointer concatenateTrainingSamples = ConcatenateListSampleFilterType::New();
  ConcatenateListSampleFilterType::Pointer concatenateValidationSamples = ConcatenateListSampleFilterType::New();

  SampleType meanMeasurementVector;
  SampleType stddevMeasurementVector;

  //--------------------------
  // Load measurements from images
  unsigned int nbBands = 0;
  unsigned int nbFeatures = 0;
  //Iterate over all input images

  FloatVectorImageListType* imageList = GetParameterImageList("io.il");

  //Iterate over all input images
  for (unsigned int imgIndex = 0; imgIndex < imageList->Size(); ++imgIndex)
    {
    FloatVectorImageType::Pointer image = imageList->GetNthElement(imgIndex);
    image->UpdateOutputInformation();

    if (imgIndex == 0)
      {
      nbBands = image->GetNumberOfComponentsPerPixel();
      nbFeatures = static_cast<unsigned int>(static_cast<int>(nbBands) - 1);
      if (nbBands < 2)
        {
        otbAppLogFATAL(<< "Need at least two bands per image, got "<<nbBands);
        }
      }
    else if (nbBands != image->GetNumberOfComponentsPerPixel())
      {
      // every subsequent image must have the same number of components as the first one
      otbAppLogFATAL(<< "Image has a different number of components than "
        "the first one, expected "<<nbBands<<", got "<< image->GetNumberOfComponentsPerPixel());
      }

    // Extract image envelope to feed in sampleGenerator
    EnvelopeFilterType::Pointer envelopeFilter = EnvelopeFilterType::New();
    envelopeFilter->SetInput(image);
    envelopeFilter->SetSamplingRate(0);
    if (!image->GetProjectionRef().empty())
      {
      envelopeFilter->SetOutputProjectionRef(image->GetProjectionRef());
      }

    // Setup the DEM Handler
    // otb::Wrapper::ElevationParametersHandler::SetupDEMHandlerFromElevationParameters(this,"elev");

    envelopeFilter->Update();

    VectorDataType::Pointer envelope = envelopeFilter->GetOutput();
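    // Tag the envelope polygon(s) with a "class" field so that the ListSampleGenerator
    // samples pixels over the whole image footprint: in regression there is no real
    // class information, so every pixel belongs to the single pseudo-class 1.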

    TreeIteratorType itVector(envelope->GetDataTree());
    for (itVector.GoToBegin(); !itVector.IsAtEnd(); ++itVector)
      {
      if (itVector.Get()->IsPolygonFeature())
        {
        itVector.Get()->SetFieldAsInt(std::string("class"),1);
        }
      }


    //Sample list generator
    ListSampleGeneratorType::Pointer sampleGenerator = ListSampleGeneratorType::New();

    sampleGenerator->SetInput(image);
    sampleGenerator->SetInputVectorData(envelope);

    sampleGenerator->SetClassKey("class");
    sampleGenerator->SetMaxTrainingSize(GetParameterInt("sample.mt"));
    sampleGenerator->SetMaxValidationSize(GetParameterInt("sample.mv"));
    sampleGenerator->SetValidationTrainingProportion(GetParameterFloat("sample.vtr"));
    sampleGenerator->SetBoundByMin(false);
    sampleGenerator->SetPolygonEdgeInclusion(true);

    sampleGenerator->Update();

    //Concatenate training and validation samples from the image
    concatenateTrainingSamples->AddInput(sampleGenerator->GetTrainingListSample());
    concatenateValidationSamples->AddInput(sampleGenerator->GetValidationListSample());
    }

  // if no input image, try CSV
  if (imageList->Size() == 0)
    {
    if (HasValue("io.csv") && IsParameterEnabled("io.csv"))
      {
      ListSampleType::Pointer csvListSample = ListSampleType::New();
      this->ParseCSVPredictors(this->GetParameterString("io.csv"), csvListSample);
      unsigned int totalCSVSize = csvListSample->Size();
      if (totalCSVSize == 0)
        {
        otbAppLogFATAL("No input image and empty CSV file. Missing input data");
        }
      nbBands = csvListSample->GetMeasurementVectorSize();
      nbFeatures = static_cast<unsigned int>(static_cast<int>(nbBands) - 1);
      ListSampleType::Pointer csvTrainListSample = ListSampleType::New();
      ListSampleType::Pointer csvValidListSample = ListSampleType::New();
      csvTrainListSample->SetMeasurementVectorSize(nbBands);
      csvValidListSample->SetMeasurementVectorSize(nbBands);
      double ratio = this->GetParameterFloat("sample.vtr");
      int trainSize = static_cast<int>(static_cast<double>(totalCSVSize)*(1.0-ratio));
      int validSize = static_cast<int>(static_cast<double>(totalCSVSize)*(ratio));
      if (trainSize > this->GetParameterInt("sample.mt"))
        {
        trainSize = this->GetParameterInt("sample.mt");
        }
      if (validSize > this->GetParameterInt("sample.mv"))
        {
        validSize = this->GetParameterInt("sample.mv");
        }
      double probaTrain = static_cast<double>(trainSize)/static_cast<double>(totalCSVSize);
      double probaValid = static_cast<double>(validSize)/static_cast<double>(totalCSVSize);

      RandomGeneratorType::Pointer randomGenerator = RandomGeneratorType::GetInstance();
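      // Randomly assign each CSV sample to the training set (probability probaTrain),
      // to the validation set (probability probaValid), or discard it, so that the
      // expected sizes respect sample.mt, sample.mv and the sample.vtr ratio.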
      for (unsigned int i=0; i<totalCSVSize; ++i)
        {
        double random = randomGenerator->GetUniformVariate(0.0, 1.0);
        if (random < probaTrain)
          {
          csvTrainListSample->PushBack(csvListSample->GetMeasurementVector(i));
          }
        else if (random < probaTrain + probaValid)
          {
          csvValidListSample->PushBack(csvListSample->GetMeasurementVector(i));
          }
        }
      concatenateTrainingSamples->AddInput(csvTrainListSample);
      concatenateValidationSamples->AddInput(csvValidListSample);
      }
    }

  // Update
  concatenateTrainingSamples->Update();
  concatenateValidationSamples->Update();

  if (concatenateTrainingSamples->GetOutput()->Size() == 0)
    {
    otbAppLogFATAL("No training samples, cannot perform training.");
    }

  if (concatenateValidationSamples->GetOutput()->Size() == 0)
    {
    otbAppLogWARNING("No validation samples.");
    }

  if (IsParameterEnabled("io.imstat"))
    {
    StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
    statisticsReader->SetFileName(GetParameterString("io.imstat"));
    meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
    stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
    // handle stat file without output normalization
    if (meanMeasurementVector.Size() == nbFeatures)
      {
      meanMeasurementVector.SetSize(nbBands,false);
      meanMeasurementVector[nbFeatures] = 0.0;
      stddevMeasurementVector.SetSize(nbBands,false);
      stddevMeasurementVector[nbFeatures] = 1.0;
      }
    }
  else
    {
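    // No statistics file given: use an identity normalization (shift 0, scale 1)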
    meanMeasurementVector.SetSize(nbBands);
    meanMeasurementVector.Fill(0.);
    stddevMeasurementVector.SetSize(nbBands);
    stddevMeasurementVector.Fill(1.);
    }

  // Shift scale the samples
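  // (Assuming the usual ShiftScaleSampleListFilter convention, each component is
  // transformed as (value - shift) / scale, i.e. centered and reduced when the
  // XML statistics are provided.)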
  ShiftScaleFilterType::Pointer trainingShiftScaleFilter = ShiftScaleFilterType::New();
  trainingShiftScaleFilter->SetInput(concatenateTrainingSamples->GetOutput());
  trainingShiftScaleFilter->SetShifts(meanMeasurementVector);
  trainingShiftScaleFilter->SetScales(stddevMeasurementVector);
  trainingShiftScaleFilter->Update();

  ListSampleType::Pointer rawValidationListSample=ListSampleType::New();

  //Test if the validation test is empty
  if ( concatenateValidationSamples->GetOutput()->Size() != 0 )
    {
    ShiftScaleFilterType::Pointer validationShiftScaleFilter = ShiftScaleFilterType::New();
    validationShiftScaleFilter->SetInput(concatenateValidationSamples->GetOutput());
    validationShiftScaleFilter->SetShifts(meanMeasurementVector);
    validationShiftScaleFilter->SetScales(stddevMeasurementVector);
    validationShiftScaleFilter->Update();
    rawValidationListSample = validationShiftScaleFilter->GetOutput();
    }

  // Split between predictors and output values
  ListSampleType::Pointer rawlistSample = trainingShiftScaleFilter->GetOutput();
  ListSampleType::Pointer listSample = ListSampleType::New();
  listSample->SetMeasurementVectorSize(nbFeatures);
  listSample->Resize(rawlistSample->Size());
  TargetListSampleType::Pointer labelListSample = TargetListSampleType::New();
  labelListSample->SetMeasurementVectorSize(1);
  labelListSample->Resize(rawlistSample->Size());

  ListSampleType::Pointer validationListSample = ListSampleType::New();
  validationListSample->SetMeasurementVectorSize(nbFeatures);
  validationListSample->Resize(rawValidationListSample->Size());
  TargetListSampleType::Pointer validationLabeledListSample = TargetListSampleType::New();
  validationLabeledListSample->SetMeasurementVectorSize(1);
  validationLabeledListSample->Resize(rawValidationListSample->Size());

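  // Copy each sample: the first nbFeatures components become the predictors and
  // the last component becomes the regression target.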
  ListSampleType::MeasurementVectorType elem;
  TargetListSampleType::MeasurementVectorType outElem;
  for (ListSampleType::InstanceIdentifier i=0; i<rawlistSample->Size() ; ++i)
    {
    elem = rawlistSample->GetMeasurementVector(i);
    outElem[0] = elem[nbFeatures];
    labelListSample->SetMeasurementVector(i,outElem);
    elem.SetSize(nbFeatures,false);
    listSample->SetMeasurementVector(i,elem);
    }
  for (ListSampleType::InstanceIdentifier i=0; i<rawValidationListSample->Size() ; ++i)
    {
    elem = rawValidationListSample->GetMeasurementVector(i);
    outElem[0] = elem[nbFeatures];
    validationLabeledListSample->SetMeasurementVector(i,outElem);
    elem.SetSize(nbFeatures,false);
    validationListSample->SetMeasurementVector(i,elem);
    }


  otbAppLogINFO("Number of training samples: " << concatenateTrainingSamples->GetOutput()->Size());
  //--------------------------
  // Split the data set into training/validation set
  ListSampleType::Pointer trainingListSample = listSample;
  TargetListSampleType::Pointer trainingLabeledListSample = labelListSample;

  otbAppLogINFO("Size of training set: " << trainingListSample->Size());
  otbAppLogINFO("Size of validation set: " << validationListSample->Size());

  //--------------------------
  // Estimate model
  //--------------------------
  this->Train(trainingListSample,trainingLabeledListSample,GetParameterString("io.out"));

  //--------------------------
  // Performances estimation
  //--------------------------
  ListSampleType::Pointer performanceListSample;
  TargetListSampleType::Pointer predictedList = TargetListSampleType::New();
  predictedList->SetMeasurementVectorSize(1);
  TargetListSampleType::Pointer performanceLabeledListSample;

  //Test the input validation set size
  if(validationLabeledListSample->Size() != 0)
    {
    performanceListSample = validationListSample;
    performanceLabeledListSample = validationLabeledListSample;
    }
  else
    {
    otbAppLogWARNING("The validation set is empty. The performance estimation is done using the input training set in this case.");
    performanceListSample = trainingListSample;
    performanceLabeledListSample = trainingLabeledListSample;
    }

  this->Classify(performanceListSample, predictedList, GetParameterString("io.out"));

  otbAppLogINFO("Training performances");
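  // Mean squared error over the performance set: MSE = (1/N) * sum_i (y_i - y_hat_i)^2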
  double mse=0.0;
  TargetListSampleType::MeasurementVectorType predictedElem;
  for (TargetListSampleType::InstanceIdentifier i=0; i<performanceListSample->Size() ; ++i)
    {
    outElem = performanceLabeledListSample->GetMeasurementVector(i);
    predictedElem = predictedList->GetMeasurementVector(i);
    mse += (outElem[0] - predictedElem[0]) * (outElem[0] - predictedElem[0]);
    }
  mse /= static_cast<double>(performanceListSample->Size());
  otbAppLogINFO("Mean Square Error = "<<mse);
  this->SetParameterFloat("io.mse",mse);
}

};

} // end namespace Wrapper
} // end namespace otb

OTB_APPLICATION_EXPORT(otb::Wrapper::TrainRegression)