Skip to content
Snippets Groups Projects
Commit 90c05435 authored by Ludovic Hussonnois's avatar Ludovic Hussonnois
Browse files

ENH: Select strategy depending on provided Vector and do some refac.

parent 6f32ff41
No related branches found
No related tags found
No related merge requests found
...@@ -5,15 +5,130 @@ namespace otb ...@@ -5,15 +5,130 @@ namespace otb
namespace Wrapper namespace Wrapper
{ {
class TrainImagesClassifier : public TrainImagesBase<true> class TrainImagesClassifier : public TrainImagesBase
{ {
public: public:
typedef TrainImagesClassifier Self; typedef TrainImagesClassifier Self;
typedef TrainImagesBase<true> Superclass; typedef TrainImagesBase Superclass;
typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer; typedef itk::SmartPointer<const Self> ConstPointer;
itkNewMacro( Self ) itkNewMacro( Self )
itkTypeMacro( Self, Superclass ) itkTypeMacro( Self, Superclass )
void DoInit() ITK_OVERRIDE
{
SetName( "TrainImagesClassifier" );
SetDescription( "Train a classifier from multiple pairs of images and training vector data." );
// Documentation
SetDocName( "Train a classifier from multiple images" );
SetDocLongDescription(
"This application performs a classifier training from multiple pairs of input images and training vector data. "
"Samples are composed of pixel values in each band optionally centered and reduced using an XML statistics file produced by "
"the ComputeImagesStatistics application.\n The training vector data must contain polygons with a positive integer field "
"representing the class label. The name of this field can be set using the \"Class label field\" parameter. Training and validation "
"sample lists are built such that each class is equally represented in both lists. One parameter allows controlling the ratio "
"between the number of samples in training and validation sets. Two parameters allow managing the size of the training and "
"validation sets per class and per image.\n Several classifier parameters can be set depending on the chosen classifier. In the "
"validation process, the confusion matrix is organized the following way: rows = reference labels, columns = produced labels. "
"In the header of the optional confusion matrix output file, the validation (reference) and predicted (produced) class labels"
" are ordered according to the rows/columns of the confusion matrix.\n This application is based on LibSVM and OpenCV Machine Learning "
"(2.3.1 and later)." );
SetDocLimitations( "None" );
SetDocAuthors( "OTB-Team" );
SetDocSeeAlso( "OpenCV documentation for machine learning http://docs.opencv.org/modules/ml/doc/ml.html " );
AddDocTag( Tags::Learning );
// Perform initialization
ClearApplications();
InitIO();
InitSampling();
InitClassification( true );
// Doc example parameter settings
SetDocExampleParameterValue("io.il", "QB_1_ortho.tif");
SetDocExampleParameterValue("io.vd", "VectorData_QB1.shp");
SetDocExampleParameterValue("io.imstat", "EstimateImageStatisticsQB1.xml");
SetDocExampleParameterValue("sample.mv", "100");
SetDocExampleParameterValue("sample.mt", "100");
SetDocExampleParameterValue("sample.vtr", "0.5");
SetDocExampleParameterValue("sample.vfn", "Class");
SetDocExampleParameterValue("classifier", "libsvm");
SetDocExampleParameterValue("classifier.libsvm.k", "linear");
SetDocExampleParameterValue("classifier.libsvm.c", "1");
SetDocExampleParameterValue("classifier.libsvm.opt", "false");
SetDocExampleParameterValue("io.out", "svmModelQB1.txt");
SetDocExampleParameterValue("io.confmatout", "svmConfusionMatrixQB1.csv");
}
void DoUpdateParameters() ITK_OVERRIDE
{
if( HasValue( "io.vd" ) )
{
std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" );
GetInternalApplication( "polystat" )->SetParameterString( "vec", vectorFileList[0], false );
UpdateInternalParameters( "polystat" );
}
}
void DoExecute() ITK_OVERRIDE
{
TrainFileNamesHandler fileNames;
FloatVectorImageListType *imageList = GetParameterImageList( "io.il" );
std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" );
unsigned long nbInputs = imageList->Size();
if( nbInputs > vectorFileList.size() )
{
otbAppLogFATAL( "Missing input vector data files to match number of images (" << nbInputs << ")." );
}
// check if validation vectors are given
std::vector<std::string> validationVectorFileList;
bool dedicatedValidation = false;
if( IsParameterEnabled( "io.valid" ) && HasValue( "io.valid" ) )
{
validationVectorFileList = GetParameterStringList( "io.valid" );
if( nbInputs > validationVectorFileList.size() )
{
otbAppLogFATAL( "Missing validation vector data files to match number of images (" << nbInputs << ")." );
}
dedicatedValidation = true;
}
fileNames.CreateTemporaryFileNames( GetParameterString( "io.out" ), nbInputs, dedicatedValidation );
// Compute final maximum sampling rates for both training and validation samples
SamplingRates rates = ComputeFinalMaximumSamplingRates( dedicatedValidation );
// Select and Extract samples for training with computed statistics and rates
ComputePolygonStatistics(imageList, vectorFileList, fileNames.polyStatTrainOutputs);
ComputeSamplingRate(fileNames.polyStatTrainOutputs, fileNames.rateTrainOut, rates.fmt);
SelectAndExtractTrainSamples(fileNames, imageList, vectorFileList, SamplingStrategy::CLASS);
// Select and Extract samples for validation with computed statistics and rates
// Validation samples could be empty if sample.vrt == 0 and if no dedicated validation are provided
if( dedicatedValidation ) {
ComputePolygonStatistics(imageList, validationVectorFileList, fileNames.polyStatValidOutputs);
ComputeSamplingRate(fileNames.polyStatValidOutputs, fileNames.rateValidOut, rates.fmv);
}
SelectAndExtractValidationSamples(fileNames, imageList, validationVectorFileList, dedicatedValidation);
// Then train the model with extracted samples
TrainModel( imageList, fileNames.sampleTrainOutputs, fileNames.sampleValidOutputs);
// cleanup
if( IsParameterEnabled( "cleanup" ) )
{
otbAppLogINFO( <<"Final clean-up ..." );
fileNames.clear();
}
}
}; };
} }
......
...@@ -5,15 +5,188 @@ namespace otb ...@@ -5,15 +5,188 @@ namespace otb
namespace Wrapper namespace Wrapper
{ {
class TrainImagesClustering : public TrainImagesBase<false> class TrainImagesClustering : public TrainImagesBase
{ {
public: public:
typedef TrainImagesClustering Self; typedef TrainImagesClustering Self;
typedef TrainImagesBase<false> Superclass; typedef TrainImagesBase Superclass;
typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<Self> Pointer;
typedef itk::SmartPointer<const Self> ConstPointer; typedef itk::SmartPointer<const Self> ConstPointer;
itkNewMacro( Self ) itkNewMacro( Self )
itkTypeMacro( Self, Superclass ) itkTypeMacro( Self, Superclass )
void DoInit() ITK_OVERRIDE
{
SetName( "TrainImagesClustering" );
SetDescription( "Train a classifier from multiple pairs of images and optional input training vector data." );
// Documentation
SetDocName( "Train a classifier from multiple images" );
SetDocLongDescription( "TODO" );
SetDocLimitations( "None" );
SetDocAuthors( "OTB-Team" );
SetDocSeeAlso( "OpenCV documentation for machine learning http://docs.opencv.org/modules/ml/doc/ml.html " );
AddDocTag( Tags::Learning );
ClearApplications();
InitIO();
InitSampling();
InitClassification( false );
// Doc example parameter settings
SetDocExampleParameterValue( "io.il", "QB_1_ortho.tif" );
SetDocExampleParameterValue( "io.vd", "VectorData_QB1.shp" );
SetDocExampleParameterValue( "io.imstat", "EstimateImageStatisticsQB1.xml" );
SetDocExampleParameterValue( "sample.mv", "100" );
SetDocExampleParameterValue( "sample.mt", "100" );
SetDocExampleParameterValue( "sample.vtr", "0.5" );
SetDocExampleParameterValue( "sample.vfn", "Class" );
SetDocExampleParameterValue( "classifier", "sharkkm" );
SetDocExampleParameterValue( "classifier.sharkkm.k", "2" );
SetDocExampleParameterValue( "io.out", "sharkKMModelQB1.txt" );
}
void DoUpdateParameters() ITK_OVERRIDE
{
if( HasValue( "io.vd" ) )
{
UpdatePolygonClassStatisticsParameters();
}
}
void DoExecute() ITK_OVERRIDE
{
TrainFileNamesHandler fileNames;
FloatVectorImageListType *imageList = GetParameterImageList( "io.il" );
bool HasInputVector = IsParameterEnabled( "io.vd" ) && HasValue( "io.vd" );
std::vector<std::string> vectorFileList = GetVectorFileList( GetParameterString( "io.out" ), fileNames );
unsigned long nbInputs = imageList->Size();
if( nbInputs > vectorFileList.size() )
{
otbAppLogFATAL( "Missing input vector data files to match number of images (" << nbInputs << ")." );
}
// check if validation vectors are given
std::vector<std::string> validationVectorFileList;
bool dedicatedValidation = false;
if( IsParameterEnabled( "io.valid" ) && HasValue( "io.valid" ) )
{
validationVectorFileList = GetParameterStringList( "io.valid" );
if( nbInputs > validationVectorFileList.size() )
{
otbAppLogFATAL( "Missing validation vector data files to match number of images (" << nbInputs << ")." );
}
dedicatedValidation = true;
}
fileNames.CreateTemporaryFileNames( GetParameterString( "io.out" ), nbInputs, dedicatedValidation );
// Compute final maximum sampling rates for both training and validation samples
SamplingRates rates = ComputeFinalMaximumSamplingRates( dedicatedValidation );
if( HasInputVector )
{
// Select and Extract samples for training with computed statistics and rates
ComputePolygonStatistics( imageList, vectorFileList, fileNames.polyStatTrainOutputs );
ComputeSamplingRate( fileNames.polyStatTrainOutputs, fileNames.rateTrainOut, rates.fmt );
SelectAndExtractTrainSamples( fileNames, imageList, vectorFileList, SamplingStrategy::CLASS );
}
else
{
SelectAndExtractTrainSamples( fileNames, imageList, vectorFileList, SamplingStrategy::GEOMETRIC );
}
// Select and Extract samples for validation with computed statistics and rates
// Validation samples could be empty if sample.vrt == 0 and if no dedicated validation are provided
if( dedicatedValidation ) {
ComputePolygonStatistics(imageList, validationVectorFileList, fileNames.polyStatValidOutputs);
ComputeSamplingRate(fileNames.polyStatValidOutputs, fileNames.rateValidOut, rates.fmv);
}
SelectAndExtractValidationSamples(fileNames, imageList, validationVectorFileList, dedicatedValidation);
// Then train the model with extracted samples
TrainModel( imageList, fileNames.sampleTrainOutputs, fileNames.sampleValidOutputs);
// cleanup
if( IsParameterEnabled( "cleanup" ) )
{
otbAppLogINFO( <<"Final clean-up ..." );
fileNames.clear();
}
}
private :
void UpdatePolygonClassStatisticsParameters()
{
std::vector<std::string> vectorFileList = GetParameterStringList( "io.vd" );
GetInternalApplication( "polystat" )->SetParameterString( "vec", vectorFileList[0], false );
UpdateInternalParameters( "polystat" );
}
/**
* Retrieve input vector data if provided otherwise generate a default vector shape file for each image.
* \param output vector file path
* \param fileNames
* \return list of input vector data file names
*/
std::vector<std::string> GetVectorFileList(std::string output, TrainFileNamesHandler &fileNames)
{
std::vector<std::string> vectorFileList;
bool HasInputVector = IsParameterEnabled( "io.vd" ) && HasValue( "io.vd" );
// Retrieve provided input vector data if available.
if( !HasInputVector )
{
FloatVectorImageListType *imageList = GetParameterImageList( "io.il" );
unsigned int nbInputs = static_cast<unsigned int>(imageList->Size());
for( unsigned int i = 0; i < nbInputs; ++i )
{
std::string name = output + "_vector_" + std::to_string( i ) + ".shp";
GenerateVectorDataFile( imageList->GetNthElement( i ), name );
fileNames.tmpVectorFileList.push_back( name );
}
vectorFileList = fileNames.tmpVectorFileList;
SetParameterStringList( "io.vd", vectorFileList, false );
UpdatePolygonClassStatisticsParameters();
GetInternalApplication( "polystat" )->SetParameterString( "field", "fid" );
}
else
{
vectorFileList = GetParameterStringList( "io.vd" );
}
return vectorFileList;
}
void GenerateVectorDataFile(const FloatVectorImageListType::ObjectPointerType &floatVectorImage, std::string name)
{
typedef otb::ImageToEnvelopeVectorDataFilter<FloatVectorImageType, VectorDataType> ImageToEnvelopeFilterType;
typedef ImageToEnvelopeFilterType::OutputVectorDataType OutputVectorData;
typedef otb::VectorDataFileWriter<OutputVectorData> VectorDataWriter;
ImageToEnvelopeFilterType::Pointer imageToEnvelopeVectorData = ImageToEnvelopeFilterType::New();
imageToEnvelopeVectorData->SetInput( floatVectorImage );
imageToEnvelopeVectorData->SetOutputProjectionRef( floatVectorImage->GetProjectionRef().c_str() );
OutputVectorData::Pointer vectorData = imageToEnvelopeVectorData->GetOutput();
// write temporary generated vector file to disk.
VectorDataWriter::Pointer vectorDataFileWriter = VectorDataWriter::New();
vectorDataFileWriter->SetInput( vectorData );
vectorDataFileWriter->SetFileName( name.c_str() );
vectorDataFileWriter->Write();
}
}; };
} }
......
...@@ -57,9 +57,8 @@ private: ...@@ -57,9 +57,8 @@ private:
// Doc example parameter settings // Doc example parameter settings
SetDocExampleParameterValue( "io.vd", "vectorData.shp" ); SetDocExampleParameterValue( "io.vd", "vectorData.shp" );
SetDocExampleParameterValue( "io.stats", "meanVar.xml" ); SetDocExampleParameterValue( "io.out", "kmeansModel.txt" );
SetDocExampleParameterValue( "io.out", "svmModel.svm" ); SetDocExampleParameterValue( "feat", "perimeter width area" );
SetDocExampleParameterValue( "feat", "perimeter area width" );
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment