Skip to content
Snippets Groups Projects
otbTrainDimensionalityReduction.cxx 4.78 KiB
Newer Older
  • Learn to ignore specific revisions
  • #include "otbWrapperApplication.h"
    #include "otbWrapperApplicationFactory.h"
    
    #include "otbOGRDataSourceWrapper.h"
    #include "otbOGRFeatureWrapper.h"
    
    #include "itkVariableLengthVector.h"
    
    #include "otbShiftScaleSampleListFilter.h"
    #include "otbStatisticsXMLFileReader.h"
    
    //#include "otbSharkUtils.h"
    
    #include <fstream> // write the model file
    
    
    #include "otbDimensionalityReductionModelFactory.h"
    #include "otbTrainDimensionalityReductionApplicationBase.h"
    
    /**
     * \class TrainDimensionalityReduction
     * \brief Application that trains a dimensionality reduction model from the
     *        attribute fields of a vector data file.
     *
     * The samples are read from the first layer of the input vector data
     * ("io.vd"), using the fields listed in "feat" as features. They are then
     * passed through a shift/scale filter configured from an optional XML
     * statistics file ("io.stats") and handed to Train(), which is not defined
     * in this class (presumably inherited from the superclass), together with
     * the output model file name ("io.out").
     */
    class TrainDimensionalityReduction : public TrainDimensionalityReductionApplicationBase<float,float>

    {
    public:
      /** Standard class typedefs. */
      typedef TrainDimensionalityReduction Self;

      typedef TrainDimensionalityReductionApplicationBase<float, float> Superclass;

      typedef itk::SmartPointer<Self> Pointer;
      typedef itk::SmartPointer<const Self> ConstPointer;

      /** Standard ITK object creation and run-time type macros. */
      itkNewMacro(Self);
      itkTypeMacro(TrainDimensionalityReduction, otb::Application);


      /** Sample types re-exported from the superclass. */
      typedef Superclass::SampleType              SampleType;
      typedef Superclass::ListSampleType          ListSampleType;
      typedef Superclass::SampleImageType         SampleImageType;

      /** Scalar type of one feature and vector type of one sample. */
      typedef float ValueType;
      typedef itk::VariableLengthVector<ValueType> MeasurementType;

      /** Reader for the optional XML statistics file ("io.stats"). */
      typedef otb::StatisticsXMLFileReader<SampleType> StatisticsReader;

      /** Filter applying the per-feature shift and scale to the samples. */
      typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType;

      typedef otb::DimensionalityReductionModelFactory<ValueType, ValueType>  ModelFactoryType;

    private:
      /**
       * Declare the application name, its description and its parameters:
       * the "io" group (input vector data, output model, optional statistics
       * file) and the "feat" field list. The superclass DoInit() is called
       * afterwards, followed by the RAM parameter.
       */
      void DoInit()
      {
        SetName("TrainDimensionalityReduction");
        SetDescription("Trainer for the dimensionality reduction algorithms used in the ImageDimensionalityReduction and VectorDimensionalityReduction applications.");

        AddParameter(ParameterType_Group, "io", "Input and output data");
        SetParameterDescription("io", "This group of parameters allows setting input and output data.");

        AddParameter(ParameterType_InputVectorData, "io.vd", "Input Vector Data");
        SetParameterDescription("io.vd", "Input geometries used for training (note : all geometries from the layer will be used)");

        AddParameter(ParameterType_OutputFilename, "io.out", "Output model");
        SetParameterDescription("io.out", "Output file containing the model estimated (.txt format).");

        // The statistics file is optional: without it DoExecute() falls back
        // to a zero shift and a unit scale.
        AddParameter(ParameterType_InputFilename, "io.stats", "Input XML image statistics file");
        MandatoryOff("io.stats");
        SetParameterDescription("io.stats", "XML file containing mean and variance of each feature.");

        AddParameter(ParameterType_StringList, "feat", "Field names to be calculated.");
        SetParameterDescription("feat","List of field names in the input vector data used as features for training.");

        Superclass::DoInit();

        AddRAMParameter();
      }

      /** Nothing to update when a parameter value changes. */
      void DoUpdateParameters()
      {
      }

      /**
       * Read the samples, apply the shift/scale filter and train the model.
       *
       * Every feature of layer 0 of the "io.vd" data source contributes one
       * sample whose components are the values of the fields listed in
       * "feat", in that order.
       */
      void DoExecute()
      {
        const std::string shapefile = GetParameterString("io.vd");

        otb::ogr::DataSource::Pointer source = otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Read);
        otb::ogr::Layer layer = source->GetLayer(0);
        ListSampleType::Pointer input = ListSampleType::New();

        // Fetch the field list once: it does not change while the layer is
        // being read, so there is no reason to query it again for every
        // component of every sample.
        const std::vector<std::string> featureNames = GetParameterStringList("feat");
        const unsigned int nbFeatures = static_cast<unsigned int>(featureNames.size());

        input->SetMeasurementVectorSize(nbFeatures);
        otb::ogr::Layer::const_iterator it = layer.cbegin();
        otb::ogr::Layer::const_iterator itEnd = layer.cend();
        for( ; it!=itEnd ; ++it)
          {
          MeasurementType mv;
          mv.SetSize(nbFeatures);
          for(unsigned int idx=0; idx < nbFeatures; ++idx)
            {
            // Fields are read as double and stored with the sample precision.
            mv[idx] = static_cast<ValueType>((*it)[featureNames[idx]].GetValue<double>());
            }
          input->PushBack(mv);
          }

        MeasurementType meanMeasurementVector;
        MeasurementType stddevMeasurementVector;

        if (HasValue("io.stats") && IsParameterEnabled("io.stats"))
          {
          // Use the "mean" and "stddev" vectors stored in the statistics file.
          StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
          const std::string XMLfile = GetParameterString("io.stats");
          statisticsReader->SetFileName(XMLfile);
          meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
          stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
          }
        else
          {
          // No statistics supplied: shift of 0 and scale of 1 for each feature.
          meanMeasurementVector.SetSize(nbFeatures);
          meanMeasurementVector.Fill(0.);
          stddevMeasurementVector.SetSize(nbFeatures);
          stddevMeasurementVector.Fill(1.);
          }

        ShiftScaleFilterType::Pointer trainingShiftScaleFilter = ShiftScaleFilterType::New();
        trainingShiftScaleFilter->SetInput(input);
        trainingShiftScaleFilter->SetShifts(meanMeasurementVector);
        trainingShiftScaleFilter->SetScales(stddevMeasurementVector);
        trainingShiftScaleFilter->Update();

        ListSampleType::Pointer trainingListSample= trainingShiftScaleFilter->GetOutput();

        this->Train(trainingListSample,GetParameterString("io.out"));
      }

    };
    
    
    } // presumably closes namespace Wrapper (opener not visible in this view)
    } // presumably closes namespace otb (opener not visible in this view)
    
    // Export the application under its fully qualified name.
    OTB_APPLICATION_EXPORT(otb::Wrapper::TrainDimensionalityReduction)