// otbTrainDimensionalityReduction.cxx
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"

#include "otbOGRDataSourceWrapper.h"
#include "otbOGRFeatureWrapper.h"

#include "itkVariableLengthVector.h"

#include "otbShiftScaleSampleListFilter.h"
#include "otbStatisticsXMLFileReader.h"

//#include "otbSharkUtils.h"

#include <fstream> // write the model file

#include "otbDimensionalityReductionModelFactory.h"
#include "otbTrainDimensionalityReductionApplicationBase.h"
namespace otb
{
namespace Wrapper
{

/**
 * \class TrainDimensionalityReduction
 * \brief Application that trains a dimensionality reduction model from the
 *        attribute fields of a vector data file.
 *
 * Processing steps (see DoExecute()):
 *  - open the vector data given by "io.vd" and read its first layer;
 *  - build one measurement vector per OGR feature from the fields listed
 *    in "feat";
 *  - shift/scale the samples with the "mean" and "stddev" vectors read from
 *    "io.stats" (shift 0 and scale 1 are used when no file is given);
 *  - pass the resulting sample list and the "io.out" file name to the
 *    Train() method inherited from the base class.
 */
class TrainDimensionalityReduction : public TrainDimensionalityReductionApplicationBase<float,float>
{
public:
  /** Standard class typedefs. */
  typedef TrainDimensionalityReduction Self;
  typedef TrainDimensionalityReductionApplicationBase<float, float> Superclass;
  typedef itk::SmartPointer<Self> Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;

  /** Standard ITK object creation and run-time type information macros. */
  itkNewMacro(Self);
  itkTypeMacro(TrainDimensionalityReduction, otb::Application);

  /** Sample types inherited from the base class. */
  typedef Superclass::SampleType              SampleType;
  typedef Superclass::ListSampleType          ListSampleType;
  typedef Superclass::SampleImageType         SampleImageType;

  /** Scalar type and vector type used to hold one sample / one statistic. */
  typedef float ValueType;
  typedef itk::VariableLengthVector<ValueType> MeasurementType;

  /** Reader for the optional XML statistics file ("io.stats"). */
  typedef otb::StatisticsXMLFileReader<SampleType> StatisticsReader;

  /** Filter applying the shift/scale normalisation to the training samples. */
  typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType;

  typedef otb::DimensionalityReductionModelFactory<ValueType, ValueType>  ModelFactoryType;

private:
  /**
   * Declare the application name, description and parameters.
   *
   * The parameters declared here are "io.vd", "io.out", "io.stats"
   * (optional) and "feat"; the base class then adds its own parameters
   * through Superclass::DoInit(), and the RAM parameter is added last.
   */
  void DoInit()
  {
    SetName("TrainDimensionalityReduction");
    SetDescription("Trainer for the dimensionality reduction algorithms used in the ImageDimensionalityReduction and VectorDimensionalityReduction applications.");

    AddParameter(ParameterType_Group, "io", "Input and output data");
    SetParameterDescription("io", "This group of parameters allows setting input and output data.");

    AddParameter(ParameterType_InputVectorData, "io.vd", "Input Vector Data");
    SetParameterDescription("io.vd", "Input geometries used for training (note : all geometries from the layer will be used)");

    AddParameter(ParameterType_OutputFilename, "io.out", "Output model");
    SetParameterDescription("io.out", "Output file containing the estimated model (.txt format).");

    AddParameter(ParameterType_InputFilename, "io.stats", "Input XML image statistics file");
    MandatoryOff("io.stats");
    SetParameterDescription("io.stats", "XML file containing mean and variance of each feature.");

    AddParameter(ParameterType_StringList, "feat", "Field names to be used for training.");
    SetParameterDescription("feat","List of field names in the input vector data used as features for training.");

    Superclass::DoInit();

    AddRAMParameter();
  }

  /** Nothing to update when a parameter changes. */
  void DoUpdateParameters()
  {
  }

  /**
   * Read the training samples, normalise them and train the model.
   *
   * Only the first layer (index 0) of the input vector data is read, and
   * every OGR feature of that layer contributes one sample.
   */
  void DoExecute()
  {
    const std::string shapefile = GetParameterString("io.vd");

    otb::ogr::DataSource::Pointer source = otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Read);
    otb::ogr::Layer layer = source->GetLayer(0);

    // Fetch the field names once: the list does not change while the layer
    // is traversed, so there is no need to query it for every field of
    // every OGR feature.
    const std::vector<std::string> fieldNames = GetParameterStringList("feat");
    const unsigned int nbFeatures = static_cast<unsigned int>(fieldNames.size());

    ListSampleType::Pointer input = ListSampleType::New();
    input->SetMeasurementVectorSize(nbFeatures);

    // Build one measurement vector per OGR feature of the layer.
    otb::ogr::Layer::const_iterator it = layer.cbegin();
    otb::ogr::Layer::const_iterator itEnd = layer.cend();
    for( ; it!=itEnd ; ++it)
      {
      MeasurementType mv;
      mv.SetSize(nbFeatures);
      for(unsigned int idx=0; idx < nbFeatures; ++idx)
        {
        // Fields are read as double and stored with the precision of ValueType.
        mv[idx] = static_cast<ValueType>((*it)[fieldNames[idx]].GetValue<double>());
        }
      input->PushBack(mv);
      }

    MeasurementType meanMeasurementVector;
    MeasurementType stddevMeasurementVector;

    if (HasValue("io.stats") && IsParameterEnabled("io.stats"))
      {
      // Use the "mean" and "stddev" vectors stored in the XML statistics file.
      StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
      const std::string XMLfile = GetParameterString("io.stats");
      statisticsReader->SetFileName(XMLfile);
      meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
      stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
      }
    else
      {
      // No statistics file: fall back to shift 0 and scale 1.
      meanMeasurementVector.SetSize(nbFeatures);
      meanMeasurementVector.Fill(0.);
      stddevMeasurementVector.SetSize(nbFeatures);
      stddevMeasurementVector.Fill(1.);
      }

    // Apply the shift/scale normalisation to the whole sample list.
    ShiftScaleFilterType::Pointer trainingShiftScaleFilter = ShiftScaleFilterType::New();
    trainingShiftScaleFilter->SetInput(input);
    trainingShiftScaleFilter->SetShifts(meanMeasurementVector);
    trainingShiftScaleFilter->SetScales(stddevMeasurementVector);
    trainingShiftScaleFilter->Update();

    ListSampleType::Pointer trainingListSample= trainingShiftScaleFilter->GetOutput();

    // Training itself is delegated to the base class.
    this->Train(trainingListSample,GetParameterString("io.out"));
  }
};

} // end namespace Wrapper
} // end namespace otb

// Export the application class so it can be instantiated by the OTB
// application framework. NOTE(review): the macro is presumably provided by
// otbWrapperApplicationFactory.h -- confirm.
OTB_APPLICATION_EXPORT(otb::Wrapper::TrainDimensionalityReduction)