diff --git a/Modules/Applications/AppClassification/app/otbTrainImagesRegression.cxx b/Modules/Applications/AppClassification/app/otbTrainImagesRegression.cxx index a1d3c04a7e05c71c586cbf35968a30ffd8056a39..d17462be96ecd5b0a6d46e034716330915713beb 100644 --- a/Modules/Applications/AppClassification/app/otbTrainImagesRegression.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainImagesRegression.cxx @@ -42,7 +42,11 @@ public: /** Standard macro */ itkTypeMacro(TrainImagesRegression, Superclass); - + + /** filters typedefs*/ + typedef otb::OGRDataToSamplePositionFilter<FloatVectorImageType, UInt8ImageType, otb::PeriodicSampler> PeriodicSamplerType; + typedef otb::SamplingRateCalculator::MapRateType MapRateType; + private: void DoInit() override { @@ -351,6 +355,57 @@ private: ExecuteInternal("extraction"); } } + + void SplitTrainingAndValidationSamples(FloatVectorImageType *image, std::string sampleFileName, + std::string sampleTrainFileName, + std::string sampleValidFileName, + std::string ratesTrainFileName) + { + // Split between training and validation + ogr::DataSource::Pointer source = ogr::DataSource::New( sampleFileName, ogr::DataSource::Modes::Read ); + ogr::DataSource::Pointer destTrain = ogr::DataSource::New( sampleTrainFileName, ogr::DataSource::Modes::Overwrite ); + ogr::DataSource::Pointer destValid = ogr::DataSource::New( sampleValidFileName, ogr::DataSource::Modes::Overwrite ); + // read sampling rates from ratesTrainOutputs + SamplingRateCalculator::Pointer rateCalculator = SamplingRateCalculator::New(); + rateCalculator->Read( ratesTrainFileName ); + // Compute sampling rates for train and valid + const otb::SamplingRateCalculator::MapRateType &inputRates = rateCalculator->GetRatesByClass(); + otb::SamplingRateCalculator::MapRateType trainRates; + otb::SamplingRateCalculator::MapRateType validRates; + otb::SamplingRateCalculator::TripletType tpt; + for( MapRateType::const_iterator it = inputRates.begin(); it != inputRates.end(); ++it ) + { + double vtr = GetParameterFloat( "sample.vtr" ); + unsigned long total = std::min( it->second.Required, it->second.Tot ); + unsigned long neededValid = static_cast<unsigned long>(( double ) total * vtr ); + unsigned long neededTrain = total - neededValid; + tpt.Tot = total; + tpt.Required = neededTrain; + tpt.Rate = ( 1.0 - vtr ); + trainRates[it->first] = tpt; + tpt.Tot = neededValid; + tpt.Required = neededValid; + tpt.Rate = 1.0; + validRates[it->first] = tpt; + } + + // Use an otb::OGRDataToSamplePositionFilter with 2 outputs + PeriodicSamplerType::SamplerParameterType param; + param.Offset = 0; + param.MaxJitter = 0; + PeriodicSamplerType::Pointer splitter = PeriodicSamplerType::New(); + splitter->SetInput( image ); + splitter->SetOGRData( source ); + splitter->SetOutputPositionContainerAndRates( destTrain, trainRates, 0 ); + splitter->SetOutputPositionContainerAndRates( destValid, validRates, 1 ); + splitter->SetFieldName( this->GetParameterStringList( "sample.vfn" )[0] ); + splitter->SetLayerIndex( 0 ); + splitter->SetOriginFieldName( std::string( "" ) ); + splitter->SetSamplerParameters( param ); + splitter->GetStreamer()->SetAutomaticTiledStreaming( static_cast<unsigned int>(this->GetParameterInt( "ram" )) ); + AddProcess( splitter->GetStreamer(), "Split samples between training and validation..." ); + splitter->Update(); + } /** Configure and execute TrainVectorClassifier. Note that many parameters of TrainVectorClassifier * are shared with the main application during initialization. */