// otbTrainOGRLayersClassifier.cxx
/*=========================================================================
 Program:   ORFEO Toolbox
 Language:  C++
 Date:      $Date$
 Version:   $Revision$


 Copyright (c) Centre National d'Etudes Spatiales. All rights reserved.
 See OTBCopyright.txt for details.


 This software is distributed WITHOUT ANY WARRANTY; without even
 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
 PURPOSE.  See the above copyright notices for more information.

 =========================================================================*/
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"

#include "otbOGRDataSourceWrapper.h"
#include "otbOGRFeatureWrapper.h"
#include "otbStatisticsXMLFileWriter.h"

#include "itkVariableLengthVector.h"
#include "otbStatisticsXMLFileReader.h"

#include "itkListSample.h"
#include "otbShiftScaleSampleListFilter.h"

#ifdef OTB_USE_LIBSVM
#include "otbLibSVMMachineLearningModel.h"
#endif

#include <time.h>

namespace otb
{
namespace Wrapper
{
class TrainOGRLayersClassifier : public Application
{
public:
  typedef TrainOGRLayersClassifier Self;
  typedef Application Superclass;
  typedef itk::SmartPointer<Self> Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;
  itkNewMacro(Self)
;

  itkTypeMacro(TrainOGRLayersClassifier, otb::Application)
;

private:
54
  void DoInit() ITK_OVERRIDE
55 56 57 58 59 60 61 62 63 64 65
  {
    SetName("TrainOGRLayersClassifier");
    SetDescription("Train a SVM classifier based on labeled geometries and a list of features to consider.");

    SetDocName("TrainOGRLayersClassifier");
    SetDocLongDescription("This application trains a SVM classifier based on labeled geometries and a list of features to consider for classification.");
    SetDocLimitations("Experimental. For now only shapefiles are supported. Tuning of SVM classifier is not available.");
    SetDocAuthors("David Youssefi during internship at CNES");
    SetDocSeeAlso("OGRLayerClassifier,ComputeOGRLayersFeaturesStatistics");
    AddDocTag(Tags::Segmentation);
  
66
    AddParameter(ParameterType_InputVectorData, "inshp", "Name of the input shapefile");
67 68
    SetParameterDescription("inshp","Name of the input shapefile");

69 70
    AddParameter(ParameterType_InputFilename, "instats", "XML file containing mean and variance of each feature.");
    SetParameterDescription("instats", "XML file containing mean and variance of each feature.");
71 72 73 74 75 76 77 78 79 80 81

    AddParameter(ParameterType_OutputFilename, "outsvm", "Output model filename.");
    SetParameterDescription("outsvm", "Output model filename.");

    AddParameter(ParameterType_ListView,  "feat", "List of features to consider for classification.");
    SetParameterDescription("feat","List of features to consider for classification.");

    AddParameter(ParameterType_String,"cfield","Field containing the class id for supervision");
    SetParameterDescription("cfield","Field containing the class id for supervision. Only geometries with this field available will be taken into account.");
    SetParameterString("cfield","class");

82 83 84 85 86 87 88
    // Doc example parameter settings
    SetDocExampleParameterValue("inshp", "vectorData.shp");
    SetDocExampleParameterValue("instats", "meanVar.xml");
    SetDocExampleParameterValue("outsvm", "svmModel.svm");
    SetDocExampleParameterValue("feat", "perimeter");
    SetDocExampleParameterValue("cfield", "predicted");

89 90
  }

91
  void DoUpdateParameters() ITK_OVERRIDE
92 93
  {
    if ( HasValue("inshp") )
OTB Bot's avatar
STYLE  
OTB Bot committed
94
      {
95
      std::string shapefile = GetParameterString("inshp");
OTB Bot's avatar
STYLE  
OTB Bot committed
96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118

       otb::ogr::DataSource::Pointer ogrDS;
       otb::ogr::Layer layer(NULL, false);

       OGRSpatialReference oSRS("");
       std::vector<std::string> options;
       
       ogrDS = otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Read);
       std::string layername = itksys::SystemTools::GetFilenameName(shapefile);
       layername = layername.substr(0,layername.size()-4);
       layer = ogrDS->GetLayer(0);

       otb::ogr::Feature feature = layer.ogr().GetNextFeature();
       ClearChoices("feat");
       for(int iField=0; iField<feature.ogr().GetFieldCount(); iField++)
         {
           std::string key, item = feature.ogr().GetFieldDefnRef(iField)->GetNameRef();
           key = item;
           key.erase(std::remove(key.begin(), key.end(), ' '), key.end());
           std::transform(key.begin(), key.end(), key.begin(), tolower);
           key="feat."+key;
           AddChoice(key,item);
         }
119 120 121
      }
  }

122
  void DoExecute() ITK_OVERRIDE
123
  {
124
    #ifdef OTB_USE_LIBSVM 
125 126
    clock_t tic = clock();

127 128 129
    std::string shapefile = GetParameterString("inshp");
    std::string XMLfile = GetParameterString("instats");
    std::string modelfile = GetParameterString("outsvm");
130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146

    typedef double ValueType;
    typedef itk::VariableLengthVector<ValueType> MeasurementType;
    typedef itk::Statistics::ListSample <MeasurementType> ListSampleType;
    typedef otb::StatisticsXMLFileReader<MeasurementType> StatisticsReader;
  
    typedef unsigned int LabelPixelType;
    typedef itk::FixedArray<LabelPixelType,1> LabelSampleType;
    typedef itk::Statistics::ListSample <LabelSampleType> LabelListSampleType;
  
    typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType;

    StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
    statisticsReader->SetFileName(XMLfile);

    MeasurementType meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
    MeasurementType stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
147
   
OTB Bot's avatar
STYLE  
OTB Bot committed
148 149
    otb::ogr::DataSource::Pointer source = otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Read);
    otb::ogr::Layer layer = source->GetLayer(0);
150 151 152 153 154 155
    bool goesOn = true;
    otb::ogr::Feature feature = layer.ogr().GetNextFeature();

    ListSampleType::Pointer input = ListSampleType::New();
    LabelListSampleType::Pointer target = LabelListSampleType::New();
    const int nbFeatures = GetSelectedItems("feat").size();
156 157 158

    input->SetMeasurementVectorSize(nbFeatures);
   
159 160
    if(feature.addr())
      while(goesOn)
OTB Bot's avatar
STYLE  
OTB Bot committed
161
       {
162
        if(feature.ogr().IsFieldSet(feature.ogr().GetFieldIndex(GetParameterString("cfield").c_str())))
OTB Bot's avatar
STYLE  
OTB Bot committed
163 164 165 166 167 168 169 170 171 172 173 174
           {
             MeasurementType mv; mv.SetSize(nbFeatures);
             
             for(int idx=0; idx < nbFeatures; ++idx)
              mv[idx] = feature.ogr().GetFieldAsDouble(GetSelectedItems("feat")[idx]);
             
             input->PushBack(mv);
             target->PushBack(feature.ogr().GetFieldAsInteger(GetParameterString("cfield").c_str()));
           }
         feature = layer.ogr().GetNextFeature();
         goesOn = feature.addr() != 0;
       }
175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203

    ShiftScaleFilterType::Pointer trainingShiftScaleFilter = ShiftScaleFilterType::New();
    trainingShiftScaleFilter->SetInput(input);
    trainingShiftScaleFilter->SetShifts(meanMeasurementVector);
    trainingShiftScaleFilter->SetScales(stddevMeasurementVector);
    trainingShiftScaleFilter->Update();
  
    ListSampleType::Pointer listSample;
    LabelListSampleType::Pointer labelListSample;

    listSample = trainingShiftScaleFilter->GetOutput();
    labelListSample = target;

    ListSampleType::Pointer trainingListSample = listSample;
    LabelListSampleType::Pointer trainingLabeledListSample = labelListSample;

    typedef otb::LibSVMMachineLearningModel<ValueType,LabelPixelType> LibSVMType;
    LibSVMType::Pointer libSVMClassifier = LibSVMType::New();
    libSVMClassifier->SetInputListSample(trainingListSample);
    libSVMClassifier->SetTargetListSample(trainingLabeledListSample);
    libSVMClassifier->SetParameterOptimization(true);
    libSVMClassifier->SetC(1.0);
    libSVMClassifier->SetKernelType(LINEAR);
    libSVMClassifier->Train();
    libSVMClassifier->Save(modelfile);

    clock_t toc = clock();

    otbAppLogINFO( "Elapsed: "<< ((double)(toc - tic) / CLOCKS_PER_SEC)<<" seconds.");
204
    
Christophe Palmann's avatar
Christophe Palmann committed
205
    #else
206 207 208
    otbAppLogFATAL("Module LIBSVM is not installed. You should consider turning OTB_USE_LIBSVM on during cmake configuration.");
    #endif
    
209 210 211 212 213 214 215 216 217
    }

};
}
}

// Export TrainOGRLayersClassifier as an OTB application so it can be created
// through the application factory (macro presumably provided by
// otbWrapperApplicationFactory.h, included above — confirm).
OTB_APPLICATION_EXPORT(otb::Wrapper::TrainOGRLayersClassifier)