Skip to content
Snippets Groups Projects
Commit 90e161c2 authored by Guillaume Pasero's avatar Guillaume Pasero
Browse files

ENH: handle CSV input

parent bc75ebb6
No related branches found
No related tags found
No related merge requests found
......@@ -37,6 +37,8 @@
// Balancing ListSample
#include "otbListSampleToBalancedListSampleFilter.h"
#include "itkMersenneTwisterRandomVariateGenerator.h"
// Elevation handler
#include "otbWrapperElevationParametersHandler.h"
......@@ -83,6 +85,8 @@ public:
typedef itk::PreOrderTreeIterator<VectorDataType::DataTreeType> TreeIteratorType;
typedef itk::Statistics::MersenneTwisterRandomVariateGenerator RandomGeneratorType;
protected:
TrainRegression()
{
......@@ -120,7 +124,7 @@ void DoInit()
AddParameter(ParameterType_Group, "io", "Input and output data");
SetParameterDescription("io", "This group of parameters allows to set input and output data.");
AddParameter(ParameterType_InputImageList, "io.il", "Input Image List");
SetParameterDescription("io.il", "A list of input images. Last band should contain the output value to predict.");
SetParameterDescription("io.il", "A list of input images. First (n-1) bands should contain the predictor. The last band should contain the output value to predict.");
AddParameter(ParameterType_InputFilename, "io.csv", "Input CSV file");
SetParameterDescription("io.csv","Input CSV file containing the predictors, and the output values in last column. Only used when no input image is given");
MandatoryOff("io.csv");
......@@ -129,7 +133,7 @@ void DoInit()
MandatoryOff("io.imstat");
SetParameterDescription("io.imstat",
"Input XML file containing the mean and the standard deviation of the input images.");
AddParameter(ParameterType_OutputFilename, "io.out", "Output model");
AddParameter(ParameterType_OutputFilename, "io.out", "Output regression model");
SetParameterDescription("io.out", "Output file containing the model estimated (.txt format).");
AddParameter(ParameterType_Float,"io.mse","Mean Square Error");
......@@ -164,7 +168,83 @@ void DoInit()
void DoUpdateParameters()
{
// Nothing to do here : all parameters are independent
if (HasValue("io.csv") && IsParameterEnabled("io.csv"))
{
MandatoryOff("io.il");
}
else
{
MandatoryOn("io.il");
}
}
void ParseCSVPredictors(std::string path, ListSampleType* outputList)
{
std::ifstream ifs;
ifs.open(path.c_str());
unsigned int nbCols = 0;
char sep = '\t';
std::istringstream iss;
SampleType elem;
while(!ifs.eof())
{
std::string line;
std::getline(ifs,line);
// filter current line
while (!line.empty() && (line[0] == ' ' || line[0] == '\t'))
{
line.erase(line.begin());
}
while (!line.empty() && ( *(line.end()-1) == ' ' || *(line.end()-1) == '\t' || *(line.end()-1) == '\r'))
{
line.erase(line.end()-1);
}
// Avoid commented lines or too short ones
if (!line.empty() && line[0] != '#')
{
std::vector<itksys::String> words = itksys::SystemTools::SplitString(line.c_str(),sep);
if (nbCols == 0)
{
// detect separator and feature size
if (words.size() < 2)
{
sep = ' ';
words = itksys::SystemTools::SplitString(line.c_str(),sep);
}
if (words.size() < 2)
{
sep = ';';
words = itksys::SystemTools::SplitString(line.c_str(),sep);
}
if (words.size() < 2)
{
sep = ',';
words = itksys::SystemTools::SplitString(line.c_str(),sep);
}
if (words.size() < 2)
{
otbAppLogFATAL(<< "Can't parse CSV file : less than 2 columns or unknonw separator (knowns ones are tab, space, comma and semi-colon)");
}
nbCols = words.size();
elem.SetSize(nbCols,false);
outputList->SetMeasurementVectorSize(nbCols);
}
else if (words.size() != nbCols )
{
otbAppLogWARNING(<< "Skip CSV line, wrong number of columns : got "<<words.size() << ", expected "<<nbCols);
continue;
}
elem.Fill(0.0);
for (unsigned int i=0 ; i<nbCols ; ++i)
{
iss.str(words[i]);
iss >> elem[i];
}
outputList->PushBack(elem);
}
}
ifs.close();
}
void DoExecute()
......@@ -254,6 +334,57 @@ void DoExecute()
concatenateTrainingSamples->AddInput(sampleGenerator->GetTrainingListSample());
concatenateValidationSamples->AddInput(sampleGenerator->GetValidationListSample());
}
// if no input image, try CSV
if (imageList->Size() == 0)
{
if (HasValue("io.csv") && IsParameterEnabled("io.csv"))
{
ListSampleType::Pointer csvListSample = ListSampleType::New();
this->ParseCSVPredictors(this->GetParameterString("io.csv"), csvListSample);
unsigned int totalCSVSize = csvListSample->Size();
if (totalCSVSize == 0)
{
otbAppLogFATAL("No input image and empty CSV file. Missing input data");
}
nbBands = csvListSample->GetMeasurementVectorSize();
nbFeatures = static_cast<unsigned int>(static_cast<int>(nbBands) - 1);
ListSampleType::Pointer csvTrainListSample = ListSampleType::New();
ListSampleType::Pointer csvValidListSample = ListSampleType::New();
csvTrainListSample->SetMeasurementVectorSize(nbBands);
csvValidListSample->SetMeasurementVectorSize(nbBands);
double ratio = this->GetParameterFloat("sample.vtr");
int trainSize = static_cast<int>(static_cast<double>(totalCSVSize)*(1.0-ratio));
int validSize = static_cast<int>(static_cast<double>(totalCSVSize)*(ratio));
if (trainSize > this->GetParameterInt("sample.mt"))
{
trainSize = this->GetParameterInt("sample.mt");
}
if (validSize > this->GetParameterInt("sample.mv"))
{
validSize = this->GetParameterInt("sample.mv");
}
double probaTrain = static_cast<double>(trainSize)/static_cast<double>(totalCSVSize);
double probaValid = static_cast<double>(validSize)/static_cast<double>(totalCSVSize);
RandomGeneratorType::Pointer randomGenerator = RandomGeneratorType::GetInstance();
for (unsigned int i=0; i<totalCSVSize; ++i)
{
double random = randomGenerator->GetUniformVariate(0.0, 1.0);
if (random < probaTrain)
{
csvTrainListSample->PushBack(csvListSample->GetMeasurementVector(i));
}
else if (random < probaTrain + probaValid)
{
csvValidListSample->PushBack(csvListSample->GetMeasurementVector(i));
}
}
concatenateTrainingSamples->AddInput(csvTrainListSample);
concatenateValidationSamples->AddInput(csvValidListSample);
}
}
// Update
concatenateTrainingSamples->Update();
concatenateValidationSamples->Update();
......@@ -376,7 +507,7 @@ void DoExecute()
this->Classify(performanceListSample, predictedList, GetParameterString("io.out"));
otbAppLogINFO("training performances");
otbAppLogINFO("Training performances");
double mse=0.0;
TargetListSampleType::MeasurementVectorType predictedElem;
for (TargetListSampleType::InstanceIdentifier i=0; i<performanceListSample->Size() ; ++i)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment