Skip to content
Snippets Groups Projects
Commit 916ac460 authored by Cédric Traizet's avatar Cédric Traizet
Browse files

New application for vector dimensionality reductio

parent bc8b732a
No related branches found
No related tags found
No related merge requests found
...@@ -9,3 +9,9 @@ NAME CbDimensionalityReduction ...@@ -9,3 +9,9 @@ NAME CbDimensionalityReduction
SOURCES cbDimensionalityReduction.cxx SOURCES cbDimensionalityReduction.cxx
LINK_LIBRARIES ${${otb-module}_LIBRARIES} ${OTBCommon_LIBRARIES} ${OTBITK_LIBRARIES} ${OTBBoost_LIBRARIES} ${OTBShark_LIBRARIES} LINK_LIBRARIES ${${otb-module}_LIBRARIES} ${OTBCommon_LIBRARIES} ${OTBITK_LIBRARIES} ${OTBBoost_LIBRARIES} ${OTBShark_LIBRARIES}
) )
OTB_CREATE_APPLICATION(
NAME CbDimensionalityReductionVector
SOURCES cbDimensionalityReductionVector.cxx
LINK_LIBRARIES ${${otb-module}_LIBRARIES} ${OTBCommon_LIBRARIES} ${OTBITK_LIBRARIES} ${OTBBoost_LIBRARIES} ${OTBShark_LIBRARIES}
)
/*
* Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
*
* This file is part of Orfeo Toolbox
*
* https://www.orfeo-toolbox.org/
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "otbWrapperApplication.h"
#include "otbWrapperApplicationFactory.h"
#include "otbOGRDataSourceWrapper.h"
#include "otbOGRFeatureWrapper.h"
#include "itkVariableLengthVector.h"
#include "otbStatisticsXMLFileReader.h"
#include "itkListSample.h"
#include "otbShiftScaleSampleListFilter.h"
#include "DimensionalityReductionModelFactory.h"
#include "DimensionalityReductionModel.h"
#include <time.h>
namespace otb
{
namespace Wrapper
{
/** Utility function to negate std::isalnum */
/*bool IsNotAlphaNum(char c)
{
return !std::isalnum(c);
}*/
class CbDimensionalityReductionVector : public Application
{
public:
/** Standard class typedefs. */
typedef CbDimensionalityReductionVector Self;
typedef Application Superclass;
typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer;
/** Standard macro */
itkNewMacro(Self);
itkTypeMacro(Self, Application)
/** Filters typedef */
typedef double ValueType;
typedef DimensionalityReductionModel<ValueType,ValueType> DimensionalityReductionModelType;
typedef DimensionalityReductionModelFactory<ValueType, ValueType> DimensionalityReductionModelFactoryType;
typedef DimensionalityReductionModelType::Pointer ModelPointerType;
/** Statistics Filters typedef */
typedef itk::VariableLengthVector<ValueType> MeasurementType;
typedef otb::StatisticsXMLFileReader<MeasurementType> StatisticsReader;
typedef itk::VariableLengthVector<ValueType> InputSampleType;
typedef itk::Statistics::ListSample<InputSampleType> ListSampleType;
typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType;
~CbDimensionalityReductionVector() ITK_OVERRIDE
{
DimensionalityReductionModelFactoryType::CleanFactories();
}
private:
void DoInit() ITK_OVERRIDE
{
SetName("VectorDimensionalityReduction");
SetDescription("Performs dimensionality reduction of the input vector data according to a model file.");
SetDocName("Vector Dimensionality Reduction");
SetDocAuthors("OTB-Team");
SetDocLongDescription("This application performs a vector data dimensionality reduction based on a model file produced by the cbDimensionalityReductionTrainer application.");
SetDocSeeAlso("cbDimensionalityReductionTrainer");
AddDocTag(Tags::Learning);
AddParameter(ParameterType_InputVectorData, "in", "Name of the input vector data");
SetParameterDescription("in","The input vector data to reduce.");
AddParameter(ParameterType_InputFilename, "instat", "Statistics file");
SetParameterDescription("instat", "A XML file containing mean and standard deviation to center"
"and reduce samples before dimensionality reduction (produced by ComputeImagesStatistics application).");
MandatoryOff("instat");
AddParameter(ParameterType_InputFilename, "model", "Model file");
SetParameterDescription("model", "A model file (produced by cbDimensionalityReduction application,");
AddParameter(ParameterType_ListView, "feat", "Field names to be calculated."); //
SetParameterDescription("feat","List of field names in the input vector data used as features for training."); //
AddParameter(ParameterType_ListView, "feat_out", "Field names to be calculated."); //
SetParameterDescription("feat_out","List of field names in the input vector data used as features for training."); //
AddParameter(ParameterType_OutputFilename, "out", "Output vector data file containing the reduced vector");
SetParameterDescription("out","Output vector data file storing sample values (OGR format)."
"If not given, the input vector data file is updated.");
MandatoryOff("out");
// Doc example parameter settings
SetDocExampleParameterValue("in", "vectorData.shp");
SetDocExampleParameterValue("instat", "meanVar.xml");
SetDocExampleParameterValue("model", "model.txt");
SetDocExampleParameterValue("out", "vectorDataOut.shp");
SetDocExampleParameterValue("feat", "perimeter area width");
SetDocExampleParameterValue("feat_out", "perimeter area width");
//SetOfficialDocLink();
}
void DoUpdateParameters() ITK_OVERRIDE
{
/** I don't know what this does */
/*
if ( HasValue("in") )
{
std::string shapefile = GetParameterString("in");
otb::ogr::DataSource::Pointer ogrDS;
OGRSpatialReference oSRS("");
std::vector<std::string> options;
ogrDS = otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Read);
otb::ogr::Layer layer = ogrDS->GetLayer(0);
OGRFeatureDefn &layerDefn = layer.GetLayerDefn();
ClearChoices("feat");
for(int iField=0; iField< layerDefn.GetFieldCount(); iField++)
{
std::string item = layerDefn.GetFieldDefn(iField)->GetNameRef();
std::string key(item);
key.erase( std::remove_if(key.begin(),key.end(),IsNotAlphaNum), key.end());
std::transform(key.begin(), key.end(), key.begin(), tolower);
OGRFieldType fieldType = layerDefn.GetFieldDefn(iField)->GetType();
if(fieldType == OFTInteger || ogr::version_proxy::IsOFTInteger64(fieldType) || fieldType == OFTReal)
{
std::string tmpKey="feat."+key;
AddChoice(tmpKey,item);
}
}
}*/
}
void DoExecute() ITK_OVERRIDE
{
clock_t tic = clock();
std::string shapefile = GetParameterString("in");
otb::ogr::DataSource::Pointer source = otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Read);
otb::ogr::Layer layer = source->GetLayer(0);
ListSampleType::Pointer input = ListSampleType::New();
const int nbFeatures = GetSelectedItems("feat").size();
input->SetMeasurementVectorSize(nbFeatures);
otb::ogr::Layer::const_iterator it = layer.cbegin();
otb::ogr::Layer::const_iterator itEnd = layer.cend();
for( ; it!=itEnd ; ++it)
{
MeasurementType mv;
mv.SetSize(nbFeatures);
for(int idx=0; idx < nbFeatures; ++idx)
{
mv[idx] = (*it)[GetSelectedItems("feat")[idx]].GetValue<double>();
}
input->PushBack(mv);
}
/** Statistics for shift/scale */
MeasurementType meanMeasurementVector;
MeasurementType stddevMeasurementVector;
if (HasValue("instat") && IsParameterEnabled("instat"))
{
StatisticsReader::Pointer statisticsReader = StatisticsReader::New();
std::string XMLfile = GetParameterString("instat");
statisticsReader->SetFileName(XMLfile);
meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
}
else
{
meanMeasurementVector.SetSize(nbFeatures);
meanMeasurementVector.Fill(0.);
stddevMeasurementVector.SetSize(nbFeatures);
stddevMeasurementVector.Fill(1.);
}
ShiftScaleFilterType::Pointer trainingShiftScaleFilter = ShiftScaleFilterType::New();
trainingShiftScaleFilter->SetInput(input);
trainingShiftScaleFilter->SetShifts(meanMeasurementVector);
trainingShiftScaleFilter->SetScales(stddevMeasurementVector);
trainingShiftScaleFilter->Update();
otbAppLogINFO("mean used: " << meanMeasurementVector);
otbAppLogINFO("standard deviation used: " << stddevMeasurementVector);
otbAppLogINFO("Loading model");
/** Read the model */
m_Model = DimensionalityReductionModelFactoryType::CreateDimensionalityReductionModel(GetParameterString("model"),
DimensionalityReductionModelFactoryType::ReadMode);
if (m_Model.IsNull())
{
otbAppLogFATAL(<< "Error when loading model " << GetParameterString("model") << " : unsupported model type");
}
m_Model->Load(GetParameterString("model"));
otbAppLogINFO("Model loaded");
/** Perform Dimensionality Reduction */
ListSampleType::Pointer listSample = trainingShiftScaleFilter->GetOutput();
ListSampleType::Pointer target;
target = m_Model->PredictBatch(listSample);
/** Create/Update Output Shape file */
ogr::DataSource::Pointer output;
ogr::DataSource::Pointer buffer = ogr::DataSource::New();
bool updateMode = false;
if (IsParameterEnabled("out") && HasValue("out"))
{
// Create new OGRDataSource
output = ogr::DataSource::New(GetParameterString("out"), ogr::DataSource::Modes::Overwrite);
otb::ogr::Layer newLayer = output->CreateLayer(
GetParameterString("out"),
const_cast<OGRSpatialReference*>(layer.GetSpatialRef()),
layer.GetGeomType());
// Copy existing fields
OGRFeatureDefn &inLayerDefn = layer.GetLayerDefn();
for (int k=0 ; k<inLayerDefn.GetFieldCount() ; k++)
{
OGRFieldDefn fieldDefn(inLayerDefn.GetFieldDefn(k));
newLayer.CreateField(fieldDefn);
}
}
else
{
// Update mode
updateMode = true;
otbAppLogINFO("Update input vector data.");
// fill temporary buffer for the transfer
otb::ogr::Layer inputLayer = layer;
layer = buffer->CopyLayer(inputLayer, std::string("Buffer"));
// close input data source
source->Clear();
// Re-open input data source in update mode
output = otb::ogr::DataSource::New(shapefile, otb::ogr::DataSource::Modes::Update_LayerUpdate);
}
otb::ogr::Layer outLayer = output->GetLayer(0);
OGRErr errStart = outLayer.ogr().StartTransaction();
if (errStart != OGRERR_NONE)
{
itkExceptionMacro(<< "Unable to start transaction for OGR layer " << outLayer.ogr().GetName() << ".");
}
// Add the field of prediction in the output layer if field not exist
OGRFeatureDefn &layerDefn = layer.GetLayerDefn();
int idx = layerDefn.GetFieldIndex(GetParameterString("feat_out").c_str());
if (idx >= 0)
{
if (layerDefn.GetFieldDefn(idx)->GetType() != OFTInteger)
itkExceptionMacro("Field name "<< GetParameterString("feat_out") << " already exists with a different type!");
}
else
{
OGRFieldDefn predictedField(GetParameterString("feat_out").c_str(), OFTInteger);
ogr::FieldDefn predictedFieldDef(predictedField);
outLayer.CreateField(predictedFieldDef);
}
// Fill output layer
unsigned int count=0;
std::string classfieldname = GetParameterString("feat_out");
it = layer.cbegin();
itEnd = layer.cend();
for( ; it!=itEnd ; ++it, ++count)
{
ogr::Feature dstFeature(outLayer.GetLayerDefn());
dstFeature.SetFrom( *it , TRUE);
dstFeature.SetFID(it->GetFID());
dstFeature[classfieldname].SetValue<int>(target->GetMeasurementVector(count)[0]);
if (updateMode)
{
outLayer.SetFeature(dstFeature);
}
else
{
outLayer.CreateFeature(dstFeature);
}
}
if(outLayer.ogr().TestCapability("Transactions"))
{
const OGRErr errCommitX = outLayer.ogr().CommitTransaction();
if (errCommitX != OGRERR_NONE)
{
itkExceptionMacro(<< "Unable to commit transaction for OGR layer " << outLayer.ogr().GetName() << ".");
}
}
output->SyncToDisk();
clock_t toc = clock();
otbAppLogINFO( "Elapsed: "<< ((double)(toc - tic) / CLOCKS_PER_SEC)<<" seconds.");
}
ModelPointerType m_Model;
};
}
}
OTB_APPLICATION_EXPORT(otb::Wrapper::CbDimensionalityReductionVector)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment