Skip to content
Snippets Groups Projects
Commit 1f4d38b7 authored by Cédric Traizet's avatar Cédric Traizet
Browse files

ENH: add sampling operation for the validation set

parent 4960df39
Branches
Tags
No related merge requests found
...@@ -44,19 +44,25 @@ public: ...@@ -44,19 +44,25 @@ public:
itkTypeMacro( TrainImagesRegression, Superclass ); itkTypeMacro( TrainImagesRegression, Superclass );
protected: protected:
struct SamplingParameters
{
std::vector<std::string> inputVectorList;
unsigned int numberOfSamples = 0;
std::string filePrefix = "";
};
void AddRegressionField() void AddRegressionField(const std::vector<std::string> & inputFileNames,
const std::string & filePrefix)
{ {
auto setFieldAppli = GetInternalApplication("setfield"); auto setFieldAppli = GetInternalApplication("setfield");
auto inputFileNames = GetParameterStringList( "io.vd" ); auto& outputFileNames = m_FileHandler[ filePrefix+"inputWithClassField" ];
auto& outputFileNames = m_FileHandler[ "inputWithClassField" ];
setFieldAppli->SetParameterString("fn", m_ClassFieldName); setFieldAppli->SetParameterString("fn", m_ClassFieldName);
setFieldAppli->SetParameterString("fv", "0"); setFieldAppli->SetParameterString("fv", "0");
for (unsigned int i =0; i < inputFileNames.size(); i++) for (unsigned int i =0; i < inputFileNames.size(); i++)
{ {
outputFileNames.push_back(GetParameterString("io.out")+"_setfield"+ std::to_string(i) +".shp"); outputFileNames.push_back(GetParameterString("io.out")+"_"+filePrefix+"Withfield"+ std::to_string(i) +".shp");
setFieldAppli->SetParameterString("in", inputFileNames[i]); setFieldAppli->SetParameterString("in", inputFileNames[i]);
setFieldAppli->SetParameterString("out", outputFileNames[i]); setFieldAppli->SetParameterString("out", outputFileNames[i]);
...@@ -65,17 +71,16 @@ protected: ...@@ -65,17 +71,16 @@ protected:
} }
} }
void ComputePolygonStatistics() void ComputePolygonStatistics(const std::string & filePrefix)
{ {
auto polygonClassAppli = GetInternalApplication("polystat"); auto polygonClassAppli = GetInternalApplication("polystat");
auto& input = m_FileHandler[ "inputWithClassField" ]; auto& input = m_FileHandler[ filePrefix+"inputWithClassField" ];
auto& output = m_FileHandler[ "statsFiles" ]; auto& output = m_FileHandler[ filePrefix+"statsFiles" ];
FloatVectorImageListType* inputImageList = GetParameterImageList( "io.il" ); FloatVectorImageListType* inputImageList = GetParameterImageList( "io.il" );
for (unsigned int i =0; i < input.size(); i++) for (unsigned int i =0; i < input.size(); i++)
{ {
output.push_back(GetParameterString("io.out")+"_polygonstat"+ std::to_string(i) +".xml"); output.push_back(GetParameterString("io.out")+"_"+filePrefix+"PolygonStats"+ std::to_string(i) +".xml");
polygonClassAppli->SetParameterInputImage( "in", inputImageList->GetNthElement(i) ); polygonClassAppli->SetParameterInputImage( "in", inputImageList->GetNthElement(i) );
polygonClassAppli->SetParameterString( "vec", input[i]); polygonClassAppli->SetParameterString( "vec", input[i]);
...@@ -89,16 +94,16 @@ protected: ...@@ -89,16 +94,16 @@ protected:
} }
void ComputeSamplingRate() void ComputeSamplingRate(const std::string & filePrefix, unsigned int numberOfSamples)
{ {
auto samplingRateAppli = GetInternalApplication("rates"); auto samplingRateAppli = GetInternalApplication("rates");
samplingRateAppli->SetParameterStringList( "il", m_FileHandler[ "statsFiles" ]); samplingRateAppli->SetParameterStringList( "il", m_FileHandler[ filePrefix+"statsFiles" ]);
std::string outputFileName = GetParameterString("io.out")+"_rates.csv"; std::string outputFileName = GetParameterString("io.out")+"_"+filePrefix+"rates.csv";
samplingRateAppli->SetParameterString("out", outputFileName); samplingRateAppli->SetParameterString("out", outputFileName);
if (HasValue("sample.nt")) if (numberOfSamples)
{ {
samplingRateAppli->SetParameterString("strategy", "constant"); samplingRateAppli->SetParameterString("strategy", "constant");
...@@ -112,27 +117,27 @@ protected: ...@@ -112,27 +117,27 @@ protected:
ExecuteInternal( "rates"); ExecuteInternal( "rates");
auto& rateFiles = m_FileHandler["rateFiles"]; auto& rateFiles = m_FileHandler[filePrefix+"rateFiles"];
for (unsigned int i = 0; i< m_FileHandler["statsFiles"].size(); i++) for (unsigned int i = 0; i< m_FileHandler[filePrefix+"statsFiles"].size(); i++)
{ {
rateFiles.push_back(GetParameterString("io.out")+"_rates_"+std::to_string(i+1)+".csv"); rateFiles.push_back(GetParameterString("io.out")+"_"+filePrefix+"rates_"+std::to_string(i+1)+".csv");
} }
} }
void SelectSamples() void SelectSamples(const std::string & filePrefix)
{ {
auto sampleSelection = GetInternalApplication("select"); auto sampleSelection = GetInternalApplication("select");
FloatVectorImageListType* inputImageList = GetParameterImageList( "io.il" ); FloatVectorImageListType* inputImageList = GetParameterImageList( "io.il" );
auto& inputVectorFiles = m_FileHandler[ "inputWithClassField" ]; auto& inputVectorFiles = m_FileHandler[ filePrefix+"inputWithClassField" ];
auto& outputVectorFiles = m_FileHandler[ "samples" ]; auto& outputVectorFiles = m_FileHandler[ filePrefix+"samples" ];
auto& rateFiles = m_FileHandler ["rateFiles"]; auto& rateFiles = m_FileHandler [filePrefix+"rateFiles"];
auto& statFiles = m_FileHandler ["statsFiles"]; auto& statFiles = m_FileHandler [filePrefix+"statsFiles"];
for (unsigned int i =0; i < inputVectorFiles.size(); i++) for (unsigned int i =0; i < inputVectorFiles.size(); i++)
{ {
outputVectorFiles.push_back(GetParameterString("io.out")+"_samples"+std::to_string(i)+".shp"); outputVectorFiles.push_back(GetParameterString("io.out")+"_"+filePrefix+"samples"+std::to_string(i)+".shp");
sampleSelection->SetParameterInputImage("in", inputImageList->GetNthElement(i)); sampleSelection->SetParameterInputImage("in", inputImageList->GetNthElement(i));
sampleSelection->SetParameterString("vec", inputVectorFiles[i]); sampleSelection->SetParameterString("vec", inputVectorFiles[i]);
sampleSelection->SetParameterString("instats", statFiles[i]); sampleSelection->SetParameterString("instats", statFiles[i]);
...@@ -147,13 +152,13 @@ protected: ...@@ -147,13 +152,13 @@ protected:
} }
} }
void ExtractSamples() void ExtractSamples(const std::string & filePrefix)
{ {
auto sampleExtraction = GetInternalApplication("extraction"); auto sampleExtraction = GetInternalApplication("extraction");
FloatVectorImageListType* featureImageList = GetParameterImageList( "io.il" ); FloatVectorImageListType* featureImageList = GetParameterImageList( "io.il" );
FloatVectorImageListType* predictorImageList = GetParameterImageList("io.ip"); FloatVectorImageListType* predictorImageList = GetParameterImageList("io.ip");
auto& vectorFiles = m_FileHandler[ "samples" ]; auto& vectorFiles = m_FileHandler[ filePrefix+"samples" ];
for (unsigned int i =0; i < vectorFiles.size(); i++) for (unsigned int i =0; i < vectorFiles.size(); i++)
{ {
...@@ -179,14 +184,14 @@ protected: ...@@ -179,14 +184,14 @@ protected:
{ {
auto trainVectorRegression = GetInternalApplication("training"); auto trainVectorRegression = GetInternalApplication("training");
auto& sampleFileNameList = m_FileHandler["samples"]; auto& trainSampleFileNameList = m_FileHandler["trainsamples"];
std::vector<std::string> featureNames; std::vector<std::string> featureNames;
for (unsigned int i = 0; i<sampleFileNameList.size(); i++) for (unsigned int i = 0; i< GetParameterStringList("io.vd").size(); i++)
{ {
featureNames.push_back(m_FeaturePrefix+std::to_string(i)); featureNames.push_back(m_FeaturePrefix+std::to_string(i));
} }
trainVectorRegression->SetParameterStringList("io.vd", sampleFileNameList); trainVectorRegression->SetParameterStringList("io.vd", trainSampleFileNameList);
trainVectorRegression->UpdateParameters(); trainVectorRegression->UpdateParameters();
trainVectorRegression->SetParameterString("cfield", m_PredictionFieldName); trainVectorRegression->SetParameterString("cfield", m_PredictionFieldName);
trainVectorRegression->SetParameterStringList("feat", featureNames); trainVectorRegression->SetParameterStringList("feat", featureNames);
...@@ -209,6 +214,10 @@ protected: ...@@ -209,6 +214,10 @@ protected:
AddParameter( ParameterType_InputVectorDataList, "io.vd", "Input Vector Data List" ); AddParameter( ParameterType_InputVectorDataList, "io.vd", "Input Vector Data List" );
SetParameterDescription( "io.vd", "A list of vector data to select the training samples." ); SetParameterDescription( "io.vd", "A list of vector data to select the training samples." );
MandatoryOn( "io.vd" ); MandatoryOn( "io.vd" );
AddParameter( ParameterType_InputVectorDataList, "io.valid", "Validation Vector Data List" );
SetParameterDescription( "io.valid", "A list of vector data to select the validation samples." );
MandatoryOff( "io.valid" );
} }
void InitSampling() void InitSampling()
...@@ -224,6 +233,10 @@ protected: ...@@ -224,6 +233,10 @@ protected:
SetParameterDescription( "sample.nt", "Number of training samples." ); SetParameterDescription( "sample.nt", "Number of training samples." );
MandatoryOff( "sample.nt" ); MandatoryOff( "sample.nt" );
AddParameter( ParameterType_Int, "sample.nv", "Number of validation samples" );
SetParameterDescription( "sample.nv", "Number of validation samples." );
MandatoryOff( "sample.nv" );
AddApplication( "SampleSelection", "select", "Sample selection" ); AddApplication( "SampleSelection", "select", "Sample selection" );
AddApplication( "SampleExtraction", "extraction", "Sample extraction" ); AddApplication( "SampleExtraction", "extraction", "Sample extraction" );
...@@ -302,25 +315,43 @@ private: ...@@ -302,25 +315,43 @@ private:
return res; return res;
} }
void PerformSampling(const SamplingParameters & params)
{
std::vector<std::string> vectorData = params.inputVectorList;
std::string filePrefix = params.filePrefix;
unsigned int numberOfSamples = params.numberOfSamples;
AddRegressionField(vectorData, filePrefix);
ComputePolygonStatistics(filePrefix);
ComputeSamplingRate(filePrefix, numberOfSamples);
SelectSamples(filePrefix);
ExtractSamples(filePrefix);
}
void DoExecute() override void DoExecute() override
{ {
//TODO validation set ?? SamplingParameters trainParams;
trainParams.inputVectorList = GetParameterStringList("io.vd");
AddRegressionField(); trainParams.filePrefix="train";
if (HasValue("sample.nt"))
trainParams.numberOfSamples = GetParameterInt("sample.nt");
std::cout << "Regression field added" << std::endl; PerformSampling(trainParams);
ComputePolygonStatistics(); if( IsParameterEnabled( "io.valid" ) && HasValue( "io.valid" ) )
{
SamplingParameters validParams;
validParams.inputVectorList = GetParameterStringList("io.valid");
validParams.filePrefix="valid";
if (HasValue("sample.nv"))
validParams.numberOfSamples = GetParameterInt("sample.nv");
std::cout << "Polygon class statistic done" << std::endl; PerformSampling(validParams);
}
ComputeSamplingRate();
std::cout << "Sampling rate computation done" << std::endl;
SelectSamples();
ExtractSamples(); otbAppLogINFO( "Sampling Done." );
TrainModel(); TrainModel();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment