Skip to content
Snippets Groups Projects
Commit 741c8a23 authored by Cédric Traizet's avatar Cédric Traizet
Browse files

ENH: add method for splitting training and validation set

parent 81dc7185
No related branches found
No related tags found
No related merge requests found
......@@ -42,7 +42,11 @@ public:
/** Standard macro */
itkTypeMacro(TrainImagesRegression, Superclass);
/** filters typedefs*/
typedef otb::OGRDataToSamplePositionFilter<FloatVectorImageType, UInt8ImageType, otb::PeriodicSampler> PeriodicSamplerType;
typedef otb::SamplingRateCalculator::MapRateType MapRateType;
private:
void DoInit() override
{
......@@ -351,6 +355,57 @@ private:
ExecuteInternal("extraction");
}
}
void SplitTrainingAndValidationSamples(FloatVectorImageType *image, std::string sampleFileName,
std::string sampleTrainFileName,
std::string sampleValidFileName,
std::string ratesTrainFileName)
{
// Split between training and validation
ogr::DataSource::Pointer source = ogr::DataSource::New( sampleFileName, ogr::DataSource::Modes::Read );
ogr::DataSource::Pointer destTrain = ogr::DataSource::New( sampleTrainFileName, ogr::DataSource::Modes::Overwrite );
ogr::DataSource::Pointer destValid = ogr::DataSource::New( sampleValidFileName, ogr::DataSource::Modes::Overwrite );
// read sampling rates from ratesTrainOutputs
SamplingRateCalculator::Pointer rateCalculator = SamplingRateCalculator::New();
rateCalculator->Read( ratesTrainFileName );
// Compute sampling rates for train and valid
const otb::SamplingRateCalculator::MapRateType &inputRates = rateCalculator->GetRatesByClass();
otb::SamplingRateCalculator::MapRateType trainRates;
otb::SamplingRateCalculator::MapRateType validRates;
otb::SamplingRateCalculator::TripletType tpt;
for( MapRateType::const_iterator it = inputRates.begin(); it != inputRates.end(); ++it )
{
double vtr = GetParameterFloat( "sample.vtr" );
unsigned long total = std::min( it->second.Required, it->second.Tot );
unsigned long neededValid = static_cast<unsigned long>(( double ) total * vtr );
unsigned long neededTrain = total - neededValid;
tpt.Tot = total;
tpt.Required = neededTrain;
tpt.Rate = ( 1.0 - vtr );
trainRates[it->first] = tpt;
tpt.Tot = neededValid;
tpt.Required = neededValid;
tpt.Rate = 1.0;
validRates[it->first] = tpt;
}
// Use an otb::OGRDataToSamplePositionFilter with 2 outputs
PeriodicSamplerType::SamplerParameterType param;
param.Offset = 0;
param.MaxJitter = 0;
PeriodicSamplerType::Pointer splitter = PeriodicSamplerType::New();
splitter->SetInput( image );
splitter->SetOGRData( source );
splitter->SetOutputPositionContainerAndRates( destTrain, trainRates, 0 );
splitter->SetOutputPositionContainerAndRates( destValid, validRates, 1 );
splitter->SetFieldName( this->GetParameterStringList( "sample.vfn" )[0] );
splitter->SetLayerIndex( 0 );
splitter->SetOriginFieldName( std::string( "" ) );
splitter->SetSamplerParameters( param );
splitter->GetStreamer()->SetAutomaticTiledStreaming( static_cast<unsigned int>(this->GetParameterInt( "ram" )) );
AddProcess( splitter->GetStreamer(), "Split samples between training and validation..." );
splitter->Update();
}
/** Configure and execute TrainVectorClassifier. Note that many parameters of TrainVectorClassifier
* are shared with the main application during initialization. */
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment