Skip to content
Snippets Groups Projects
Commit 078b7a5a authored by Ludovic Hussonnois's avatar Ludovic Hussonnois
Browse files

ENH: Update TrainImagesClustering and Classifier to use SampleSelection.

SampleSelection can now be used without input vector data.
parent aeb69fd5
No related branches found
No related tags found
No related merge requests found
......@@ -115,7 +115,7 @@ public:
ComputePolygonStatistics(imageList, validationVectorFileList, fileNames.polyStatValidOutputs);
ComputeSamplingRate(fileNames.polyStatValidOutputs, fileNames.rateValidOut, rates.fmv);
}
SelectAndExtractValidationSamples(fileNames, imageList, validationVectorFileList, dedicatedValidation);
SelectAndExtractValidationSamples(fileNames, imageList, validationVectorFileList);
// Then train the model with extracted samples
......
......@@ -34,6 +34,14 @@ public:
InitSampling();
InitClassification( false );
AddParameter( ParameterType_Float, "sample.percent", "Percentage of samples extract in images for "
"training and validation when only images are provided." );
SetParameterDescription( "sample.percent", "Percentage of samples extract in images for "
"training and validation when only images are provided. This parameter is disable when vector data are provided" );
SetDefaultParameterFloat( "sample.percent", 100.0 );
SetMinimumParameterFloatValue( "sample.percent", 0.0 );
SetMaximumParameterFloatValue( "sample.percent", 100.0 );
// Doc example parameter settings
SetDocExampleParameterValue( "io.il", "QB_1_ortho.tif" );
SetDocExampleParameterValue( "io.vd", "VectorData_QB1.shp" );
......@@ -51,8 +59,13 @@ public:
{
if( HasValue( "io.vd" ) )
{
MandatoryOff( "sample.percent" );
UpdatePolygonClassStatisticsParameters();
}
else
{
MandatoryOn( "sample.percent" );
}
}
void DoExecute() ITK_OVERRIDE
......@@ -60,12 +73,12 @@ public:
TrainFileNamesHandler fileNames;
FloatVectorImageListType *imageList = GetParameterImageList( "io.il" );
bool HasInputVector = IsParameterEnabled( "io.vd" ) && HasValue( "io.vd" );
std::vector<std::string> vectorFileList = GetVectorFileList( GetParameterString( "io.out" ), fileNames );
std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" );
unsigned long nbInputs = imageList->Size();
if( nbInputs > vectorFileList.size() )
if( !vectorFileList.empty() && nbInputs > vectorFileList.size() )
{
otbAppLogFATAL( "Missing input vector data files to match number of images (" << nbInputs << ")." );
}
......@@ -90,28 +103,29 @@ public:
SamplingRates rates = ComputeFinalMaximumSamplingRates( dedicatedValidation );
if( HasInputVector )
{
{
// Select and Extract samples for training with computed statistics and rates
ComputePolygonStatistics( imageList, vectorFileList, fileNames.polyStatTrainOutputs );
ComputeSamplingRate( fileNames.polyStatTrainOutputs, fileNames.rateTrainOut, rates.fmt );
SelectAndExtractTrainSamples( fileNames, imageList, vectorFileList, SamplingStrategy::CLASS );
}
}
else
{
{
SelectAndExtractTrainSamples( fileNames, imageList, vectorFileList, SamplingStrategy::GEOMETRIC );
}
}
// Select and Extract samples for validation with computed statistics and rates
// Validation samples could be empty if sample.vrt == 0 and if no dedicated validation are provided
if( dedicatedValidation ) {
ComputePolygonStatistics(imageList, validationVectorFileList, fileNames.polyStatValidOutputs);
ComputeSamplingRate(fileNames.polyStatValidOutputs, fileNames.rateValidOut, rates.fmv);
if( dedicatedValidation )
{
ComputePolygonStatistics( imageList, validationVectorFileList, fileNames.polyStatValidOutputs );
ComputeSamplingRate( fileNames.polyStatValidOutputs, fileNames.rateValidOut, rates.fmv );
}
SelectAndExtractValidationSamples(fileNames, imageList, validationVectorFileList, dedicatedValidation);
SelectAndExtractValidationSamples( fileNames, imageList, validationVectorFileList );
// Then train the model with extracted samples
TrainModel( imageList, fileNames.sampleTrainOutputs, fileNames.sampleValidOutputs);
TrainModel( imageList, fileNames.sampleTrainOutputs, fileNames.sampleValidOutputs );
// cleanup
if( IsParameterEnabled( "cleanup" ) )
......@@ -130,63 +144,6 @@ private :
UpdateInternalParameters( "polystat" );
}
/**
* Retrieve input vector data if provided otherwise generate a default vector shape file for each image.
* \param output vector file path
* \param fileNames
* \return list of input vector data file names
*/
std::vector<std::string> GetVectorFileList(std::string output, TrainFileNamesHandler &fileNames)
{
std::vector<std::string> vectorFileList;
bool HasInputVector = IsParameterEnabled( "io.vd" ) && HasValue( "io.vd" );
// Retrieve provided input vector data if available.
if( !HasInputVector )
{
FloatVectorImageListType *imageList = GetParameterImageList( "io.il" );
unsigned int nbInputs = static_cast<unsigned int>(imageList->Size());
for( unsigned int i = 0; i < nbInputs; ++i )
{
std::string name = output + "_vector_" + std::to_string( i ) + ".shp";
GenerateVectorDataFile( imageList->GetNthElement( i ), name );
fileNames.tmpVectorFileList.push_back( name );
}
vectorFileList = fileNames.tmpVectorFileList;
SetParameterStringList( "io.vd", vectorFileList, false );
UpdatePolygonClassStatisticsParameters();
GetInternalApplication( "polystat" )->SetParameterString( "field", "fid" );
}
else
{
vectorFileList = GetParameterStringList( "io.vd" );
}
return vectorFileList;
}
void GenerateVectorDataFile(const FloatVectorImageListType::ObjectPointerType &floatVectorImage, std::string name)
{
typedef otb::ImageToEnvelopeVectorDataFilter<FloatVectorImageType, VectorDataType> ImageToEnvelopeFilterType;
typedef ImageToEnvelopeFilterType::OutputVectorDataType OutputVectorData;
typedef otb::VectorDataFileWriter<OutputVectorData> VectorDataWriter;
ImageToEnvelopeFilterType::Pointer imageToEnvelopeVectorData = ImageToEnvelopeFilterType::New();
imageToEnvelopeVectorData->SetInput( floatVectorImage );
imageToEnvelopeVectorData->SetOutputProjectionRef( floatVectorImage->GetProjectionRef().c_str() );
OutputVectorData::Pointer vectorData = imageToEnvelopeVectorData->GetOutput();
// write temporary generated vector file to disk.
VectorDataWriter::Pointer vectorDataFileWriter = VectorDataWriter::New();
vectorDataFileWriter->SetInput( vectorData );
vectorDataFileWriter->SetFileName( name.c_str() );
vectorDataFileWriter->Write();
}
};
}
......
......@@ -327,7 +327,6 @@ protected:
std::string statisticsFileName, std::string ratesFileName, SamplingStrategy strategy)
{
GetInternalApplication( "select" )->SetParameterInputImage( "in", image );
GetInternalApplication( "select" )->SetParameterString( "vec", vectorFileName, false );
GetInternalApplication( "select" )->SetParameterString( "out", sampleFileName, false );
GetInternalApplication( "extraction" )->SetParameterString( "outfield", "prefix", false );
......@@ -338,10 +337,12 @@ protected:
{
case GEOMETRIC:
GetInternalApplication( "select" )->SetParameterString( "sampler", "random", false );
GetInternalApplication( "select" )->SetParameterString( "strategy", "all", false );
GetInternalApplication( "select" )->SetParameterString( "strategy", "percent", false );
GetInternalApplication( "select" )->SetParameterFloat("strategy.percent.p", GetParameterFloat("sample.percent"), false);
break;
case CLASS:
default:
GetInternalApplication( "select" )->SetParameterString( "vec", vectorFileName, false );
GetInternalApplication( "select" )->SetParameterString( "instats", statisticsFileName, false );
GetInternalApplication( "select" )->SetParameterString( "sampler", "periodic", false );
GetInternalApplication( "select" )->SetParameterInt( "sampler.periodic.jitter", 50 );
......@@ -365,19 +366,19 @@ protected:
for( unsigned int i = 0; i < imageList->Size(); ++i )
{
SelectAndExtractSamples( imageList->GetNthElement( i ), vectorFileNames[i], fileNames.sampleOutputs[i],
std::string vectorFileName = vectorFileNames.empty() ? "" : vectorFileNames[i];
SelectAndExtractSamples( imageList->GetNthElement( i ), vectorFileName, fileNames.sampleOutputs[i],
fileNames.polyStatTrainOutputs[i], fileNames.ratesTrainOutputs[i], strategy );
}
}
void SelectAndExtractValidationSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList,
const std::vector<std::string> &validationVectorFileList,
bool dedicatedValidation)
const std::vector<std::string> &validationVectorFileList)
{
// In dedicated validation mode the by class sampling strategy and statistics are used.
// Otherwise simply split training to validation samples corresponding to sample.vtr percentage.
if( dedicatedValidation )
if( !validationVectorFileList.empty() )
{
for( unsigned int i = 0; i < imageList->Size(); ++i )
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment