// otbVectorClassifier.cxx
/*
 * Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
 *
 * This file is part of Orfeo Toolbox
 *
 *     https://www.orfeo-toolbox.org/
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"

#include "otbOGRDataSourceWrapper.h"
#include "otbOGRFeatureWrapper.h"

#include "itkVariableLengthVector.h"
#include "otbStatisticsXMLFileReader.h"

#include "itkListSample.h"
#include "otbShiftScaleSampleListFilter.h"

#include "otbMachineLearningModelFactory.h"

#include "otbMachineLearningModel.h"

#include <time.h>

#include <algorithm>
#include <cctype>
#include <string>
#include <vector>

namespace otb
{
namespace Wrapper
{

/**
 * Utility predicate negating std::isalnum.
 *
 * \param c character to test
 * \return true when \a c is NOT an alphanumeric character
 *
 * The argument is converted to unsigned char before the call: passing a
 * negative value (e.g. a non-ASCII byte stored in a signed char) to
 * std::isalnum is undefined behaviour.
 */
bool IsNotAlphaNum(char c)
  {
  return !std::isalnum(static_cast<unsigned char>(c));
  }

/**
 * \class VectorClassifier
 * \brief Application classifying the features of a vector data file with a
 *        machine learning model loaded from a model file.
 *
 * The selected numeric fields of every feature of the first input layer are
 * gathered into a sample list, optionally shifted/scaled with statistics
 * read from an XML file, and passed to the model. The predicted label (and
 * optionally a confidence value) is written either to a new output file or
 * back into the input file when no output is given.
 */
class VectorClassifier : public Application
{
public:
  /** Standard class typedefs. */
  typedef VectorClassifier              Self;
  typedef Application                   Superclass;
  typedef itk::SmartPointer<Self>       Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;

  /** Standard macro */
  itkNewMacro(Self);

  itkTypeMacro(Self, Application)

  /** Filters typedef */
  typedef float                                         ValueType;
  typedef unsigned int                                  LabelType;
  typedef itk::FixedArray<LabelType,1>                  LabelSampleType;
  typedef itk::Statistics::ListSample<LabelSampleType>  LabelListSampleType;

  typedef otb::MachineLearningModel<ValueType,LabelType>          MachineLearningModelType;
  typedef otb::MachineLearningModelFactory<ValueType, LabelType>  MachineLearningModelFactoryType;
  typedef MachineLearningModelType::Pointer                       ModelPointerType;
  typedef MachineLearningModelType::ConfidenceListSampleType      ConfidenceListSampleType;

  /** Statistics Filters typedef */
  typedef itk::VariableLengthVector<ValueType>                    MeasurementType;
  typedef otb::StatisticsXMLFileReader<MeasurementType>           StatisticsReader;

  typedef itk::VariableLengthVector<ValueType>                    InputSampleType;
  typedef itk::Statistics::ListSample<InputSampleType>            ListSampleType;
  typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType;

  /** Destructor: asks the model factory to clean its registered factories. */
  ~VectorClassifier() ITK_OVERRIDE
  {
    MachineLearningModelFactoryType::CleanFactories();
  }

private:
  /** Declare the application name, documentation and parameters. */
  void DoInit() ITK_OVERRIDE
  {
    SetName("VectorClassifier");
    SetDescription("Performs a classification of the input vector data according to a model file.");

    SetDocName("Vector Classification");
    SetDocAuthors("OTB-Team");
    // Adjacent literals are concatenated verbatim: each piece must carry its
    // own trailing space so that sentences do not run together.
    SetDocLongDescription("This application performs a vector data classification "
      "based on a model file produced by the TrainVectorClassifier application. "
      "Features of the vector data output will contain the class labels decided by the classifier "
      "(maximal class label = 65535). \n"
      "There are two modes: \n"
        "1) Update mode: add of the 'cfield' field containing the predicted class in the input file. \n"
        "2) Write mode: copies the existing fields of the input file in the output file "
           " and add the 'cfield' field containing the predicted class. \n"
      "If you have declared the output file, the write mode applies. "
      "Otherwise, the input file update mode will be applied.");

    SetDocLimitations("Shapefiles are supported. But the SQLite format is only supported in update mode.");
    SetDocSeeAlso("TrainVectorClassifier");
    AddDocTag(Tags::Learning);

    AddParameter(ParameterType_InputVectorData, "in", "Name of the input vector data");
    SetParameterDescription("in","The input vector data file to classify.");

    AddParameter(ParameterType_InputFilename, "instat", "Statistics file");
    SetParameterDescription("instat", "A XML file containing mean and standard deviation to center "
      "and reduce samples before classification, produced by ComputeImagesStatistics application.");
    MandatoryOff("instat");

    AddParameter(ParameterType_InputFilename, "model", "Model file");
    SetParameterDescription("model", "Model file produced by TrainVectorClassifier application.");

    AddParameter(ParameterType_String,"cfield","Field class");
    SetParameterDescription("cfield","Field containing the predicted class. "
      "Only geometries with this field available will be taken into account.\n"
      "The field is added either in the input file (if 'out' off) or in the output file.\n"
      "Caution, the 'cfield' must not exist in the input file if you are updating the file.");
    SetParameterString("cfield","predicted", false);

    AddParameter(ParameterType_ListView, "feat", "Field names to be calculated.");
    SetParameterDescription("feat","List of field names in the input vector data used as features for training. "
      "Put the same field names as the TrainVectorClassifier application.");

    AddParameter(ParameterType_Empty, "confmap",  "Confidence map");
    SetParameterDescription( "confmap", "Confidence map of the produced classification. "
      "The confidence index depends on the model : \n"
      "  - LibSVM : difference between the two highest probabilities "
           "(needs a model with probability estimates, so that classes probabilities can be computed for each sample)\n"
      "  - OpenCV\n"
      "    * Boost : sum of votes\n"
      "    * DecisionTree : (not supported)\n"
      "    * GradientBoostedTree : (not supported)\n"
      "    * KNearestNeighbors : number of neighbors with the same label\n"
      "    * NeuralNetwork : difference between the two highest responses\n"
      "    * NormalBayes : (not supported)\n"
      "    * RandomForest : Confidence (proportion of votes for the majority class). "
             "Margin (normalized difference of the votes of the 2 majority classes) is not available for now.\n"
      "    * SVM : distance to margin (only works for 2-class models).\n");
    MandatoryOff("confmap");

    AddParameter(ParameterType_OutputFilename, "out", "Output vector data file containing class labels");
    SetParameterDescription("out","Output vector data file storing sample values (OGR format). "
      "If not given, the input vector data file is updated.");
    MandatoryOff("out");

    // Doc example parameter settings
    SetDocExampleParameterValue("in", "vectorData.shp");
    SetDocExampleParameterValue("instat", "meanVar.xml");
    SetDocExampleParameterValue("model", "svmModel.svm");
    SetDocExampleParameterValue("out", "vectorDataLabeledVector.shp");
    SetDocExampleParameterValue("feat", "perimeter  area  width");
    SetDocExampleParameterValue("cfield", "predicted");

    SetOfficialDocLink();
  }

  /**
   * Refresh the choices of the "feat" list from the fields of the first
   * layer of the input file. Only integer, 64-bit integer and real fields
   * are proposed.
   */
  void DoUpdateParameters() ITK_OVERRIDE
  {
    if ( HasValue("in") )
    {
      std::string shapefile = GetParameterString("in");

      otb::ogr::DataSource::Pointer ogrDS =
        otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Read);
      otb::ogr::Layer layer = ogrDS->GetLayer(0);
      OGRFeatureDefn &layerDefn = layer.GetLayerDefn();

      ClearChoices("feat");

      for(int iField=0; iField< layerDefn.GetFieldCount(); iField++)
      {
        OGRFieldDefn *fieldDefn = layerDefn.GetFieldDefn(iField);

        // The choice key is the field name stripped of non-alphanumeric
        // characters and lower-cased; the displayed item keeps the raw name.
        std::string item = fieldDefn->GetNameRef();
        std::string key(item);
        key.erase( std::remove_if(key.begin(),key.end(),IsNotAlphaNum), key.end());
        std::transform(key.begin(), key.end(), key.begin(), tolower);

        OGRFieldType fieldType = fieldDefn->GetType();
        if(fieldType == OFTInteger ||  ogr::version_proxy::IsOFTInteger64(fieldType) || fieldType == OFTReal)
        {
          std::string tmpKey="feat."+key;
          AddChoice(tmpKey,item);
        }
      }
    }
  }

  /**
   * Run the classification: read the samples, normalise them, load the
   * model, predict, then write the predicted class (and the confidence
   * value when requested and supported) in the output layer.
   */
  void DoExecute() ITK_OVERRIDE
  {
    clock_t tic = clock();

    std::string shapefile = GetParameterString("in");

    otb::ogr::DataSource::Pointer source = otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Read);
    otb::ogr::Layer layer = source->GetLayer(0);

    ListSampleType::Pointer input = ListSampleType::New();

    const int nbFeatures = GetSelectedItems("feat").size();
    input->SetMeasurementVectorSize(nbFeatures);

    // Resolve the selected field names once: the selection does not change
    // while the samples are read, so there is no need to query the
    // parameter for every feature of every sample.
    std::vector<std::string> featureFieldNames;
    featureFieldNames.reserve(nbFeatures);
    for(int idx=0; idx < nbFeatures; ++idx)
    {
      // Beware that itemIndex differs from ogr layer field index
      const unsigned int itemIndex = GetSelectedItems("feat")[idx];
      featureFieldNames.push_back(GetChoiceNames( "feat" )[itemIndex]);
    }

    // One measurement vector per feature of the layer.
    otb::ogr::Layer::const_iterator it = layer.cbegin();
    otb::ogr::Layer::const_iterator itEnd = layer.cend();
    for( ; it!=itEnd ; ++it)
    {
      MeasurementType mv;
      mv.SetSize(nbFeatures);
      for(int idx=0; idx < nbFeatures; ++idx)
      {
        mv[idx] = static_cast<ValueType>((*it)[featureFieldNames[idx]].GetValue<double>());
      }
      input->PushBack(mv);
    }

    // Statistics for shift/scale
    MeasurementType meanMeasurementVector;
    MeasurementType stddevMeasurementVector;
    if (HasValue("instat") && IsParameterEnabled("instat"))
    {
      StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
      std::string XMLfile = GetParameterString("instat");
      statisticsReader->SetFileName(XMLfile);
      meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
      stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
    }
    else
    {
      // No statistics file: shift of 0 and scale of 1 for every feature.
      meanMeasurementVector.SetSize(nbFeatures);
      meanMeasurementVector.Fill(0.);
      stddevMeasurementVector.SetSize(nbFeatures);
      stddevMeasurementVector.Fill(1.);
    }

    ShiftScaleFilterType::Pointer trainingShiftScaleFilter = ShiftScaleFilterType::New();
    trainingShiftScaleFilter->SetInput(input);
    trainingShiftScaleFilter->SetShifts(meanMeasurementVector);
    trainingShiftScaleFilter->SetScales(stddevMeasurementVector);
    trainingShiftScaleFilter->Update();
    otbAppLogINFO("mean used: " << meanMeasurementVector);
    otbAppLogINFO("standard deviation used: " << stddevMeasurementVector);

    otbAppLogINFO("Loading model");
    m_Model = MachineLearningModelFactoryType::CreateMachineLearningModel(GetParameterString("model"),
                                                MachineLearningModelFactoryType::ReadMode);

    if (m_Model.IsNull())
    {
      otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type");
    }

    m_Model->Load(GetParameterString("model"));
    otbAppLogINFO("Model loaded");

    ListSampleType::Pointer listSample = trainingShiftScaleFilter->GetOutput();

    ConfidenceListSampleType::Pointer quality;

    // Confidence is only computed when requested, supported by the model,
    // and the model is not in regression mode.
    bool computeConfidenceMap(IsParameterEnabled("confmap") && m_Model->HasConfidenceIndex()
                              && !m_Model->GetRegressionMode());

    if (!m_Model->HasConfidenceIndex() && IsParameterEnabled("confmap"))
    {
      otbAppLogWARNING("Confidence map requested but the classifier doesn't support it!");
    }

    LabelListSampleType::Pointer target;
    if (computeConfidenceMap)
    {
      quality = ConfidenceListSampleType::New();
      target = m_Model->PredictBatch(listSample, quality);
    }
    else
    {
      target = m_Model->PredictBatch(listSample);
    }

    ogr::DataSource::Pointer output;
    ogr::DataSource::Pointer buffer = ogr::DataSource::New();
    bool updateMode = false;
    if (IsParameterEnabled("out") && HasValue("out"))
    {
      // Create new OGRDataSource
      output = ogr::DataSource::New(GetParameterString("out"), ogr::DataSource::Modes::Overwrite);
      otb::ogr::Layer newLayer = output->CreateLayer(
        GetParameterString("out"),
        const_cast<OGRSpatialReference*>(layer.GetSpatialRef()),
        layer.GetGeomType());
      // Copy existing fields
      OGRFeatureDefn &inLayerDefn = layer.GetLayerDefn();
      for (int k=0 ; k<inLayerDefn.GetFieldCount() ; k++)
      {
        OGRFieldDefn fieldDefn(inLayerDefn.GetFieldDefn(k));
        newLayer.CreateField(fieldDefn);
      }
    }
    else
    {
      // Update mode
      updateMode = true;
      otbAppLogINFO("Update input vector data.");
      // fill temporary buffer for the transfer
      otb::ogr::Layer inputLayer = layer;
      layer = buffer->CopyLayer(inputLayer, std::string("Buffer"));
      // close input data source
      source->Clear();
      // Re-open input data source in update mode
      output = otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Update_LayerUpdate);
    }

    otb::ogr::Layer outLayer = output->GetLayer(0);

    OGRErr errStart = outLayer.ogr().StartTransaction();
    if (errStart != OGRERR_NONE)
    {
      itkExceptionMacro(<< "Unable to start transaction for OGR layer " << outLayer.ogr().GetName() << ".");
    }

    // Add the field of prediction in the output layer if field not exist
    OGRFeatureDefn &layerDefn = layer.GetLayerDefn();
    int idx = layerDefn.GetFieldIndex(GetParameterString("cfield").c_str());
    if (idx >= 0)
    {
      if (layerDefn.GetFieldDefn(idx)->GetType() != OFTInteger)
        itkExceptionMacro("Field name "<< GetParameterString("cfield") << " already exists with a different type!");
    }
    else
    {
      OGRFieldDefn predictedField(GetParameterString("cfield").c_str(), OFTInteger);
      ogr::FieldDefn predictedFieldDef(predictedField);
      outLayer.CreateField(predictedFieldDef);
    }

    // Add confidence field in the output layer
    std::string confFieldName("confidence");
    if (computeConfidenceMap)
    {
      idx = layerDefn.GetFieldIndex(confFieldName.c_str());
      if (idx >= 0)
      {
        if (layerDefn.GetFieldDefn(idx)->GetType() != OFTReal)
          itkExceptionMacro("Field name "<< confFieldName << " already exists with a different type!");
      }
      else
      {
        // The field keeps the width and precision it is constructed with.
        OGRFieldDefn confidenceField(confFieldName.c_str(), OFTReal);
        ogr::FieldDefn confFieldDefn(confidenceField);
        outLayer.CreateField(confFieldDefn);
      }
    }

    // Fill output layer
    // 'count' is the index of the current feature in the prediction lists,
    // which were built in the same iteration order as the layer.
    unsigned int count=0;
    std::string classfieldname = GetParameterString("cfield");
    it = layer.cbegin();
    itEnd = layer.cend();
    for( ; it!=itEnd ; ++it, ++count)
    {
      ogr::Feature dstFeature(outLayer.GetLayerDefn());
      dstFeature.SetFrom( *it , TRUE);
      dstFeature.SetFID(it->GetFID());
      dstFeature[classfieldname].SetValue<int>(target->GetMeasurementVector(count)[0]);
      if (computeConfidenceMap)
        dstFeature[confFieldName].SetValue<double>(quality->GetMeasurementVector(count)[0]);
      if (updateMode)
      {
        // Existing feature (same FID) is overwritten in the input file.
        outLayer.SetFeature(dstFeature);
      }
      else
      {
        outLayer.CreateFeature(dstFeature);
      }
    }

    // Commit only when the layer reports the "Transactions" capability.
    if(outLayer.ogr().TestCapability("Transactions"))
    {
      const OGRErr errCommitX = outLayer.ogr().CommitTransaction();
      if (errCommitX != OGRERR_NONE)
      {
        itkExceptionMacro(<< "Unable to commit transaction for OGR layer " << outLayer.ogr().GetName() << ".");
      }
    }

    output->SyncToDisk();

    clock_t toc = clock();
    otbAppLogINFO( "Elapsed: "<< (static_cast<double>(toc - tic) / CLOCKS_PER_SEC)<<" seconds.");

  }

  /** Model created and loaded by DoExecute(). */
  ModelPointerType m_Model;
};

}
}

OTB_APPLICATION_EXPORT(otb::Wrapper::VectorClassifier)