diff --git a/Modules/Applications/AppClassification/app/otbTrainImagesRegression.cxx b/Modules/Applications/AppClassification/app/otbTrainImagesRegression.cxx index d17462be96ecd5b0a6d46e034716330915713beb..b88b73c90860c1f8d5a0f23cb916a23e6d54d3c8 100644 --- a/Modules/Applications/AppClassification/app/otbTrainImagesRegression.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainImagesRegression.cxx @@ -160,6 +160,10 @@ private: SetParameterDescription("sample.nv", "Number of validation samples."); MandatoryOff("sample.nv"); + AddParameter(ParameterType_Float, "sample.ratio", "Training and validation sample ratio"); + SetParameterDescription("sample.ratio", "Ratio between training and validation samples."); + SetDefaultParameterFloat("sample.ratio", 0.5); + ShareParameter( "rand", "select.rand" ); ShareParameter( "ram", "polystat.ram" ); @@ -356,55 +360,65 @@ private: } } - void SplitTrainingAndValidationSamples(FloatVectorImageType *image, std::string sampleFileName, - std::string sampleTrainFileName, - std::string sampleValidFileName, - std::string ratesTrainFileName) + void SplitTrainingAndValidationSamples(const std::string & inputSampleFilePrefix) { - // Split between training and validation - ogr::DataSource::Pointer source = ogr::DataSource::New( sampleFileName, ogr::DataSource::Modes::Read ); - ogr::DataSource::Pointer destTrain = ogr::DataSource::New( sampleTrainFileName, ogr::DataSource::Modes::Overwrite ); - ogr::DataSource::Pointer destValid = ogr::DataSource::New( sampleValidFileName, ogr::DataSource::Modes::Overwrite ); - // read sampling rates from ratesTrainOutputs - SamplingRateCalculator::Pointer rateCalculator = SamplingRateCalculator::New(); - rateCalculator->Read( ratesTrainFileName ); - // Compute sampling rates for train and valid - const otb::SamplingRateCalculator::MapRateType &inputRates = rateCalculator->GetRatesByClass(); - otb::SamplingRateCalculator::MapRateType trainRates; - otb::SamplingRateCalculator::MapRateType validRates; - otb::SamplingRateCalculator::TripletType tpt; - for( MapRateType::const_iterator it = inputRates.begin(); it != inputRates.end(); ++it ) + auto ImageList = GetParameterImageList("io.il"); + const auto& inputSampleFiles = m_FileHandler[inputSampleFilePrefix + "samples"]; + auto& trainSampleFiles = m_FileHandler["trainsamples"]; + auto& validSampleFiles = m_FileHandler["validsamples"]; + const auto& rateFiles = m_FileHandler[inputSampleFilePrefix + "rateFiles"]; + + for (unsigned int i = 0; i < ImageList->Size(); i++) + { + trainSampleFiles.push_back(GetParameterString("io.out") + "_trainsamples" + std::to_string(i) + ".shp"); + validSampleFiles.push_back(GetParameterString("io.out") + "_validsamples" + std::to_string(i) + ".shp"); + + // Split between training and validation + auto image = ImageList->GetNthElement(i); + ogr::DataSource::Pointer source = ogr::DataSource::New( inputSampleFiles[i], ogr::DataSource::Modes::Read ); + ogr::DataSource::Pointer destTrain = ogr::DataSource::New( trainSampleFiles[i], ogr::DataSource::Modes::Overwrite ); + ogr::DataSource::Pointer destValid = ogr::DataSource::New( validSampleFiles[i], ogr::DataSource::Modes::Overwrite ); + // read sampling rates from ratesTrainOutputs + SamplingRateCalculator::Pointer rateCalculator = SamplingRateCalculator::New(); + rateCalculator->Read( rateFiles[i] ); + // Compute sampling rates for train and valid + const otb::SamplingRateCalculator::MapRateType &inputRates = rateCalculator->GetRatesByClass(); + otb::SamplingRateCalculator::MapRateType trainRates; + otb::SamplingRateCalculator::MapRateType validRates; + otb::SamplingRateCalculator::TripletType tpt; + for( MapRateType::const_iterator it = inputRates.begin(); it != inputRates.end(); ++it ) { - double vtr = GetParameterFloat( "sample.vtr" ); - unsigned long total = std::min( it->second.Required, it->second.Tot ); - unsigned long neededValid = static_cast<unsigned long>(( double ) total * vtr ); - unsigned long neededTrain = total - neededValid; - tpt.Tot = total; - tpt.Required = neededTrain; - tpt.Rate = ( 1.0 - vtr ); - trainRates[it->first] = tpt; - tpt.Tot = neededValid; - tpt.Required = neededValid; - tpt.Rate = 1.0; - validRates[it->first] = tpt; + double vtr = GetParameterFloat( "sample.ratio" ); + unsigned long total = std::min( it->second.Required, it->second.Tot ); + unsigned long neededValid = static_cast<unsigned long>(( double ) total * vtr ); + unsigned long neededTrain = total - neededValid; + tpt.Tot = total; + tpt.Required = neededTrain; + tpt.Rate = ( 1.0 - vtr ); + trainRates[it->first] = tpt; + tpt.Tot = neededValid; + tpt.Required = neededValid; + tpt.Rate = 1.0; + validRates[it->first] = tpt; } - // Use an otb::OGRDataToSamplePositionFilter with 2 outputs - PeriodicSamplerType::SamplerParameterType param; - param.Offset = 0; - param.MaxJitter = 0; - PeriodicSamplerType::Pointer splitter = PeriodicSamplerType::New(); - splitter->SetInput( image ); - splitter->SetOGRData( source ); - splitter->SetOutputPositionContainerAndRates( destTrain, trainRates, 0 ); - splitter->SetOutputPositionContainerAndRates( destValid, validRates, 1 ); - splitter->SetFieldName( this->GetParameterStringList( "sample.vfn" )[0] ); - splitter->SetLayerIndex( 0 ); - splitter->SetOriginFieldName( std::string( "" ) ); - splitter->SetSamplerParameters( param ); - splitter->GetStreamer()->SetAutomaticTiledStreaming( static_cast<unsigned int>(this->GetParameterInt( "ram" )) ); - AddProcess( splitter->GetStreamer(), "Split samples between training and validation..." ); - splitter->Update(); + // Use an otb::OGRDataToSamplePositionFilter with 2 outputs + PeriodicSamplerType::SamplerParameterType param; + param.Offset = 0; + param.MaxJitter = 0; + PeriodicSamplerType::Pointer splitter = PeriodicSamplerType::New(); + splitter->SetInput( image ); + splitter->SetOGRData( source ); + splitter->SetOutputPositionContainerAndRates( destTrain, trainRates, 0 ); + splitter->SetOutputPositionContainerAndRates( destValid, validRates, 1 ); + splitter->SetFieldName( m_ClassFieldName ); + splitter->SetLayerIndex( 0 ); + splitter->SetOriginFieldName( std::string( "" ) ); + splitter->SetSamplerParameters( param ); + splitter->GetStreamer()->SetAutomaticTiledStreaming( static_cast<unsigned int>(this->GetParameterInt( "ram" )) ); + AddProcess( splitter->GetStreamer(), "Split samples between training and validation..." ); + splitter->Update(); + } } /** Configure and execute TrainVectorClassifier. Note that many parameters of TrainVectorClassifier @@ -419,13 +433,13 @@ private: { featureNames.push_back(m_FeaturePrefix + std::to_string(i)); } - + trainVectorRegression->SetParameterStringList("io.vd", trainSampleFileNameList); trainVectorRegression->UpdateParameters(); trainVectorRegression->SetParameterString("cfield", m_PredictionFieldName); trainVectorRegression->SetParameterStringList("feat", featureNames); - if (IsParameterEnabled("io.valid") && HasValue("io.valid")) + if ((IsParameterEnabled("io.valid") && HasValue("io.valid")) || GetParameterFloat( "sample.ratio" ) >0) { trainVectorRegression->SetParameterStringList("valid.vd", m_FileHandler["validsamples"]); } @@ -484,7 +498,7 @@ private: void DoExecute() override { SamplingParameters trainParams; - trainParams.filePrefix = "train"; + trainParams.filePrefix = "vd"; if (HasValue("sample.nt")) trainParams.numberOfSamples = GetParameterInt("sample.nt"); @@ -502,17 +516,29 @@ private: PerformSampling(trainParams); + // User validation data if (IsParameterEnabled("io.valid") && HasValue("io.valid")) { + m_FileHandler["trainsamples"] = m_FileHandler[trainParams.filePrefix + "samples"]; + SamplingParameters validParams; validParams.inputVectorList = GetParameterStringList("io.valid"); validParams.filePrefix = "valid"; if (HasValue("sample.nv")) validParams.numberOfSamples = GetParameterInt("sample.nv"); - + PerformSampling(validParams); } - + // Split train and validation data + else if (GetParameterFloat( "sample.ratio" ) >0) + { + SplitTrainingAndValidationSamples( trainParams.filePrefix); + } + else + { + m_FileHandler["trainsamples"] = m_FileHandler[trainParams.filePrefix + "samples"]; + } + otbAppLogINFO("Sampling Done."); TrainModel();