// cbDimensionalityReductionTrainer.cxx
// Authors: Cédric Traizet (Traizet Cedric)
    #include "otbWrapperApplication.h"
    #include "otbWrapperApplicationFactory.h"

    #include "otbOGRDataSourceWrapper.h"
    #include "otbOGRFeatureWrapper.h"

    #include "itkVariableLengthVector.h"

    #include "otbShiftScaleSampleListFilter.h"
    #include "otbStatisticsXMLFileReader.h"

    #include <fstream> // write the model file
    #include <string>
    #include <vector>

    #include "DimensionalityReductionModelFactory.h"

    namespace otb
    {
    namespace Wrapper
    {

    /** \class CbDimensionalityReductionTrainer
     *  \brief Application that trains a dimensionality reduction model from vector data.
     *
     *  For every feature of the first layer of "io.vd", the numeric fields named in
     *  "feat" are read as one sample. Samples are optionally shifted/scaled with the
     *  "mean" and "stddev" vectors of the "io.stats" XML file (shift 0 / scale 1
     *  otherwise), and the resulting list sample is passed to Train() together with
     *  the "io.out" path.
     *
     *  NOTE(review): Train() and the model-specific parameters are presumably provided
     *  by cbLearningApplicationBaseDR, whose header is not included in this file —
     *  confirm it is pulled in transitively (e.g. via DimensionalityReductionModelFactory.h).
     */
    class CbDimensionalityReductionTrainer : public cbLearningApplicationBaseDR<float,float>
    {
    public:
      /** Standard class typedefs. */
      typedef CbDimensionalityReductionTrainer          Self;
      typedef cbLearningApplicationBaseDR<float, float> Superclass;
      typedef itk::SmartPointer<Self>                   Pointer;

      /** Standard macros. */
      itkNewMacro(Self);
      itkTypeMacro(CbDimensionalityReductionTrainer, otb::Application);

      /** Sample types taken from the learning base class. */
      typedef Superclass::SampleType      SampleType;
      typedef Superclass::ListSampleType  ListSampleType;
      typedef Superclass::SampleImageType SampleImageType;

      /** Scalar / vector types used to build one sample from the OGR fields. */
      typedef double                               ValueType;
      typedef itk::VariableLengthVector<ValueType> MeasurementType;

      typedef otb::StatisticsXMLFileReader<SampleType> StatisticsReader;
      typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType;

      typedef otb::DimensionalityReductionModelFactory<ValueType, ValueType> ModelFactoryType;

    private:
      /** Declares the application parameters: the "io" group (vector data, output
       *  model, optional statistics file), the "feat" field list, the parameters added
       *  by Superclass::DoInit(), and the RAM parameter. */
      void DoInit()
      {
        SetName("CbDimensionalityReductionTrainer");
        SetDescription("Trainer for the dimensionality reduction algorithms used in the cbDimensionalityReduction application.");

        AddParameter(ParameterType_Group, "io", "Input and output data");
        SetParameterDescription("io", "This group of parameters allows setting input and output data.");

        AddParameter(ParameterType_InputVectorData, "io.vd", "Input Vector Data");
        SetParameterDescription("io.vd", "Input geometries used for training (note : all geometries from the layer will be used)");

        AddParameter(ParameterType_OutputFilename, "io.out", "Output model");
        SetParameterDescription("io.out", "Output file containing the model estimated (.txt format).");

        // Optional: without it, samples are used as-is (shift 0, scale 1 in DoExecute).
        AddParameter(ParameterType_InputFilename, "io.stats", "Input XML image statistics file");
        MandatoryOff("io.stats");
        SetParameterDescription("io.stats", "XML file containing mean and variance of each feature.");

        AddParameter(ParameterType_StringList, "feat", "Field names to be calculated.");
        SetParameterDescription("feat","List of field names in the input vector data used as features for training.");

        // Presumably registers the model-specific parameters — confirm in cbLearningApplicationBaseDR.
        Superclass::DoInit();

        AddRAMParameter();
      }

      /** Nothing to update: this application does not adjust parameters dynamically. */
      void DoUpdateParameters()
      {
      }

      /** Builds the training list sample from the vector data, normalizes it and trains.
       *  Fails (otbAppLogFATAL) when "feat" is empty, when the layer yields no sample,
       *  or when the statistics vectors do not have one value per feature. */
      void DoExecute()
      {
        otbAppLogINFO(<< "Appli Training!");

        // Fetch the field list once instead of once per (geometry, field) pair.
        const std::vector<std::string> featureNames = GetParameterStringList("feat");
        if (featureNames.empty())
          {
          otbAppLogFATAL(<< "Parameter 'feat' must list at least one field name.");
          }
        const unsigned int nbFeatures = static_cast<unsigned int>(featureNames.size());

        const std::string shapefile = GetParameterString("io.vd");
        otb::ogr::DataSource::Pointer source = otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Read);
        otb::ogr::Layer layer = source->GetLayer(0);

        // One sample per geometry of the first layer, one component per requested field.
        ListSampleType::Pointer input = ListSampleType::New();
        input->SetMeasurementVectorSize(nbFeatures);
        otb::ogr::Layer::const_iterator it = layer.cbegin();
        otb::ogr::Layer::const_iterator itEnd = layer.cend();
        for ( ; it != itEnd; ++it)
          {
          MeasurementType mv;
          mv.SetSize(nbFeatures);
          for (unsigned int idx = 0; idx < nbFeatures; ++idx)
            {
            mv[idx] = (*it)[featureNames[idx]].GetValue<double>();
            }
          input->PushBack(mv);
          }

        if (input->Size() == 0)
          {
          otbAppLogFATAL(<< "No training sample found in the first layer of " << shapefile);
          }

        // Shift/scale vectors: read from the statistics file, or identity transform.
        MeasurementType meanMeasurementVector;
        MeasurementType stddevMeasurementVector;
        if (HasValue("io.stats") && IsParameterEnabled("io.stats"))
          {
          const std::string xmlFile = GetParameterString("io.stats");
          StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
          statisticsReader->SetFileName(xmlFile);
          meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
          stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");

          // Fail early with a message naming both inputs rather than normalizing with
          // statistics computed for a different feature set.
          if (meanMeasurementVector.Size() != nbFeatures || stddevMeasurementVector.Size() != nbFeatures)
            {
            otbAppLogFATAL(<< "Statistics file " << xmlFile << " provides "
                           << meanMeasurementVector.Size() << " mean and "
                           << stddevMeasurementVector.Size() << " stddev values, but "
                           << nbFeatures << " feature fields were given in 'feat'.");
            }
          }
        else
          {
          meanMeasurementVector.SetSize(nbFeatures);
          meanMeasurementVector.Fill(0.);
          stddevMeasurementVector.SetSize(nbFeatures);
          stddevMeasurementVector.Fill(1.);
          }

        ShiftScaleFilterType::Pointer trainingShiftScaleFilter = ShiftScaleFilterType::New();
        trainingShiftScaleFilter->SetInput(input);
        trainingShiftScaleFilter->SetShifts(meanMeasurementVector);
        trainingShiftScaleFilter->SetScales(stddevMeasurementVector);
        trainingShiftScaleFilter->Update();

        ListSampleType::Pointer trainingListSample = trainingShiftScaleFilter->GetOutput();

        this->Train(trainingListSample, GetParameterString("io.out"));
      }
    };

    } // end namespace Wrapper
    } // end namespace otb
    
    // Exports CbDimensionalityReductionTrainer through the OTB application factory
    // (otbWrapperApplicationFactory.h); presumably this is what lets the OTB launcher
    // load the application as a plugin by name — confirm against the OTB docs.
    OTB_APPLICATION_EXPORT(otb::Wrapper::CbDimensionalityReductionTrainer)