diff --git a/Modules/Applications/AppClassification/app/otbSampleSelection.cxx b/Modules/Applications/AppClassification/app/otbSampleSelection.cxx index ef81a64d5fd8564197b0ffe21741f0b3b17ff20c..5ff2926888f5a97b2e735fd8ac3e895dae114829 100644 --- a/Modules/Applications/AppClassification/app/otbSampleSelection.cxx +++ b/Modules/Applications/AppClassification/app/otbSampleSelection.cxx @@ -73,8 +73,6 @@ public: private: SampleSelection() { - m_Periodic = PeriodicSamplerType::New(); - m_Random = RandomSamplerType::New(); m_ReaderStat = XMLReaderType::New(); m_RateCalculator = RateCalculatorType::New(); } @@ -258,15 +256,9 @@ private: { // Clear state m_RateCalculator->ClearRates(); - m_Periodic->GetFilter()->ClearOutputs(); - m_Random->GetFilter()->ClearOutputs(); otb::Wrapper::ElevationParametersHandler::SetupDEMHandlerFromElevationParameters(this,"elev"); - // Setup ram - m_Periodic->GetStreamer()->SetAutomaticAdaptativeStreaming(GetParameterInt("ram")); - m_Random->GetStreamer()->SetAutomaticAdaptativeStreaming(GetParameterInt("ram")); - // Get field name std::vector<int> selectedCFieldIdx = GetSelectedItems("field"); @@ -415,37 +407,41 @@ private: PeriodicSamplerType::SamplerParameterType param; param.Offset = 0; param.MaxJitter = this->GetParameterInt("sampler.periodic.jitter"); - - m_Periodic->SetInput(this->GetParameterImage("in")); - m_Periodic->SetOGRData(reprojVector); - m_Periodic->SetOutputPositionContainerAndRates(outputSamples, rates); - m_Periodic->SetFieldName(fieldName); - m_Periodic->SetLayerIndex(this->GetParameterInt("layer")); - m_Periodic->SetSamplerParameters(param); + param.MaxBufferSize = 100000000UL; + PeriodicSamplerType::Pointer periodicFilt = PeriodicSamplerType::New(); + periodicFilt->SetInput(this->GetParameterImage("in")); + periodicFilt->SetOGRData(reprojVector); + periodicFilt->SetOutputPositionContainerAndRates(outputSamples, rates); + periodicFilt->SetFieldName(fieldName); + periodicFilt->SetLayerIndex(this->GetParameterInt("layer")); + periodicFilt->SetSamplerParameters(param); if (IsParameterEnabled("mask") && HasValue("mask")) { - m_Periodic->SetMask(this->GetParameterImage<UInt8ImageType>("mask")); + periodicFilt->SetMask(this->GetParameterImage<UInt8ImageType>("mask")); } - m_Periodic->GetStreamer()->SetAutomaticTiledStreaming(this->GetParameterInt("ram")); - AddProcess(m_Periodic->GetStreamer(),"Selecting positions with periodic sampler..."); - m_Periodic->Update(); + periodicFilt->GetStreamer()->SetAutomaticTiledStreaming(this->GetParameterInt("ram")); + AddProcess(periodicFilt->GetStreamer(),"Selecting positions with periodic sampler..."); + periodicFilt->Update(); } break; // random case 1: { - m_Random->SetInput(this->GetParameterImage("in")); - m_Random->SetOGRData(reprojVector); - m_Random->SetOutputPositionContainerAndRates(outputSamples, rates); - m_Random->SetFieldName(fieldName); - m_Random->SetLayerIndex(this->GetParameterInt("layer")); + RandomSamplerType::Pointer randomFilt = RandomSamplerType::New(); + randomFilt->SetInput(this->GetParameterImage("in")); + randomFilt->SetOGRData(reprojVector); + randomFilt->SetOutputPositionContainerAndRates(outputSamples, rates); + randomFilt->SetFieldName(fieldName); + randomFilt->SetLayerIndex(this->GetParameterInt("layer")); if (IsParameterEnabled("mask") && HasValue("mask")) { - m_Random->SetMask(this->GetParameterImage<UInt8ImageType>("mask")); + randomFilt->SetMask(this->GetParameterImage<UInt8ImageType>("mask")); } - m_Random->GetStreamer()->SetAutomaticTiledStreaming(this->GetParameterInt("ram")); - AddProcess(m_Random->GetStreamer(),"Selecting positions with random sampler..."); - m_Random->Update(); + randomFilt->GetStreamer()->SetAutomaticTiledStreaming(this->GetParameterInt("ram")); + AddProcess(randomFilt->GetStreamer(),"Selecting positions with random sampler..."); + randomFilt->Update(); + + randomFilt = RandomSamplerType::New(); } break; default: @@ -455,10 +451,6 @@ private: } RateCalculatorType::Pointer m_RateCalculator; - - PeriodicSamplerType::Pointer m_Periodic; - RandomSamplerType::Pointer m_Random; - XMLReaderType::Pointer m_ReaderStat; }; diff --git a/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx b/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx index 932454f0b5ebe89bd84c9367b6c582b9cd3c0970..201dbfd705bb51fe152d9118fd9d39d8f6e5aa4a 100644 --- a/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainImagesClassifier.cxx @@ -14,94 +14,68 @@ PURPOSE. See the above copyright notices for more information. =========================================================================*/ -#include "otbLearningApplicationBase.h" +#include "otbWrapperCompositeApplication.h" #include "otbWrapperApplicationFactory.h" -#include "otbListSampleGenerator.h" - -// Statistic XML Reader -#include "otbStatisticsXMLFileReader.h" - -// Validation -#include "otbConfusionMatrixCalculator.h" - -#include "itkTimeProbe.h" -#include "otbStandardFilterWatcher.h" - -// Normalize the samples -#include "otbShiftScaleSampleListFilter.h" - -// List sample concatenation -#include "otbConcatenateSampleListFilter.h" - -// Balancing ListSample -#include "otbListSampleToBalancedListSampleFilter.h" - -// VectorData projection filter - -// Extract a ROI of the vectordata -#include "otbVectorDataIntoImageProjectionFilter.h" - -// Elevation handler -#include "otbWrapperElevationParametersHandler.h" +#include "otbOGRDataToSamplePositionFilter.h" +#include "otbSamplingRateCalculator.h" namespace otb { namespace Wrapper { -class TrainImagesClassifier: public LearningApplicationBase<float,int> +class TrainImagesClassifier: public CompositeApplication { public: /** Standard class typedefs. */ typedef TrainImagesClassifier Self; - typedef LearningApplicationBase<float,int> Superclass; + typedef CompositeApplication Superclass; typedef itk::SmartPointer<Self> Pointer; typedef itk::SmartPointer<const Self> ConstPointer; /** Standard macro */ itkNewMacro(Self) - itkTypeMacro(TrainImagesClassifier, otb::Wrapper::LearningApplicationBase) - - typedef Superclass::SampleType SampleType; - typedef Superclass::ListSampleType ListSampleType; - typedef Superclass::TargetSampleType TargetSampleType; - typedef Superclass::TargetListSampleType TargetListSampleType; + itkTypeMacro(TrainImagesClassifier, otb::Wrapper::CompositeApplication) - typedef Superclass::SampleImageType SampleImageType; - typedef SampleImageType::PixelType PixelType; + /** filters typedefs*/ + typedef otb::OGRDataToSamplePositionFilter< + FloatVectorImageType, + UInt8ImageType, + otb::PeriodicSampler> PeriodicSamplerType; - // SampleList manipulation - typedef otb::ListSampleGenerator<SampleImageType, VectorDataType> ListSampleGeneratorType; - - typedef otb::Statistics::ConcatenateSampleListFilter<ListSampleType> ConcatenateListSampleFilterType; - typedef otb::Statistics::ConcatenateSampleListFilter<TargetListSampleType> ConcatenateLabelListSampleFilterType; - - // Statistic XML file Reader - typedef otb::StatisticsXMLFileReader<SampleType> StatisticsReader; - - // Enhance List Sample typedef otb::Statistics::ListSampleToBalancedListSampleFilter<ListSampleType, LabelListSampleType> BalancingListSampleFilterType; - typedef otb::Statistics::ShiftScaleSampleListFilter<ListSampleType, ListSampleType> ShiftScaleFilterType; - - // Estimate performance on validation sample - typedef otb::ConfusionMatrixCalculator<TargetListSampleType, TargetListSampleType> ConfusionMatrixCalculatorType; - typedef ConfusionMatrixCalculatorType::ConfusionMatrixType ConfusionMatrixType; - typedef ConfusionMatrixCalculatorType::MapOfIndicesType MapOfIndicesType; - typedef ConfusionMatrixCalculatorType::ClassLabelType ClassLabelType; - - // VectorData projection filter - typedef otb::VectorDataProjectionFilter<VectorDataType, VectorDataType> VectorDataProjectionFilterType; - - // Extract ROI - typedef otb::VectorDataIntoImageProjectionFilter<VectorDataType, SampleImageType> VectorDataReprojectionType; + typedef otb::SamplingRateCalculator::MapRateType MapRateType; protected: - //using Superclass::AddParameter; - //friend void InitSVMParams(TrainImagesClassifier & app); private: +bool RemoveFile(std::string &filePath) +{ + bool res = true; + if(itksys::SystemTools::FileExists(filePath.c_str())) + { + size_t posExt = filePath.rfind('.'); + if (posExt != std::string::npos && + filePath.compare(posExt,std::string::npos,".shp") == 0) + { + std::string shxPath = filePath.substr(0,posExt) + std::string(".shx"); + std::string dbfPath = filePath.substr(0,posExt) + std::string(".dbf"); + std::string prjPath = filePath.substr(0,posExt) + std::string(".prj"); + RemoveFile(shxPath); + RemoveFile(dbfPath); + RemoveFile(prjPath); + } + res = itksys::SystemTools::RemoveFile(filePath.c_str()); + if (!res) + { + otbAppLogINFO(<<"Unable to remove file "<<filePath); + } + } + return res; +} + void DoInit() ITK_OVERRIDE { SetName("TrainImagesClassifier"); @@ -126,62 +100,99 @@ void DoInit() ITK_OVERRIDE SetDocAuthors("OTB-Team"); SetDocSeeAlso("OpenCV documentation for machine learning http://docs.opencv.org/modules/ml/doc/ml.html "); + AddDocTag(Tags::Learning); + + ClearApplications(); + AddApplication("PolygonClassStatistics", "polystat","Polygon analysis"); + AddApplication("MultiImageSamplingRate", "rates", "Sampling rates"); + AddApplication("SampleSelection", "select", "Sample selection"); + AddApplication("SampleExtraction","extraction", "Sample extraction"); + AddApplication("TrainVectorClassifier", "training", "Model training"); + //Group IO AddParameter(ParameterType_Group, "io", "Input and output data"); SetParameterDescription("io", "This group of parameters allows setting input and output data."); + AddParameter(ParameterType_InputImageList, "io.il", "Input Image List"); SetParameterDescription("io.il", "A list of input images."); AddParameter(ParameterType_InputVectorDataList, "io.vd", "Input Vector Data List"); SetParameterDescription("io.vd", "A list of vector data to select the training samples."); - AddParameter(ParameterType_InputFilename, "io.imstat", "Input XML image statistics file"); - MandatoryOff("io.imstat"); - SetParameterDescription("io.imstat", - "Input XML file containing the mean and the standard deviation of the input images."); - AddParameter(ParameterType_OutputFilename, "io.confmatout", "Output confusion matrix"); - SetParameterDescription("io.confmatout", "Output file containing the confusion matrix (.csv format)."); - MandatoryOff("io.confmatout"); - AddParameter(ParameterType_OutputFilename, "io.out", "Output model"); - SetParameterDescription("io.out", "Output file containing the model estimated (.txt format)."); - - // Elevation - ElevationParametersHandler::AddElevationParameters(this, "elev"); - - //Group Sample list + + AddParameter(ParameterType_InputVectorDataList, "io.valid", "Validation Vector Data List"); + SetParameterDescription("io.valid", "A list of vector data to select the training samples."); + MandatoryOff("io.valid"); + + ShareParameter("io.imstat","training.io.stats"); + ShareParameter("io.confmatout","training.io.confmatout"); + ShareParameter("io.out","training.io.out"); + + ShareParameter("elev","polystat.elev"); + + // Sampling settings AddParameter(ParameterType_Group, "sample", "Training and validation samples parameters"); SetParameterDescription("sample", - "This group of parameters allows you to set training and validation sample lists parameters."); - + "This group of parameters allows you to set training and validation sample lists parameters."); AddParameter(ParameterType_Int, "sample.mt", "Maximum training sample size per class"); - //MandatoryOff("mt"); SetDefaultParameterInt("sample.mt", 1000); - SetParameterDescription("sample.mt", "Maximum size per class (in pixels) of the training sample list (default = 1000) (no limit = -1). If equal to -1, then the maximal size of the available training sample list per class will be equal to the surface area of the smallest class multiplied by the training sample ratio."); + SetParameterDescription("sample.mt", "Maximum size per class (in pixels) of " + "the training sample list (default = 1000) (no limit = -1). If equal to -1," + " then the maximal size of the available training sample list per class " + "will be equal to the surface area of the smallest class multiplied by the" + " training sample ratio."); AddParameter(ParameterType_Int, "sample.mv", "Maximum validation sample size per class"); - // MandatoryOff("mv"); SetDefaultParameterInt("sample.mv", 1000); - SetParameterDescription("sample.mv", "Maximum size per class (in pixels) of the validation sample list (default = 1000) (no limit = -1). If equal to -1, then the maximal size of the available validation sample list per class will be equal to the surface area of the smallest class multiplied by the validation sample ratio."); - + SetParameterDescription("sample.mv", "Maximum size per class (in pixels) of " + "the validation sample list (default = 1000) (no limit = -1). If equal to -1," + " then the maximal size of the available validation sample list per class " + "will be equal to the surface area of the smallest class multiplied by the " + "validation sample ratio."); AddParameter(ParameterType_Int, "sample.bm", "Bound sample number by minimum"); SetDefaultParameterInt("sample.bm", 1); - SetParameterDescription("sample.bm", "Bound the number of samples for each class by the number of available samples by the smaller class. Proportions between training and validation are respected. Default is true (=1)."); - - - AddParameter(ParameterType_Empty, "sample.edg", "On edge pixel inclusion"); - SetParameterDescription("sample.edg", - "Takes pixels on polygon edge into consideration when building training and validation samples."); - MandatoryOff("sample.edg"); - + SetParameterDescription("sample.bm", "Bound the number of samples for each " + "class by the number of available samples by the smaller class. Proportions " + "between training and validation are respected. Default is true (=1)."); AddParameter(ParameterType_Float, "sample.vtr", "Training and validation sample ratio"); SetParameterDescription("sample.vtr", - "Ratio between training and validation samples (0.0 = all training, 1.0 = all validation) (default = 0.5)."); + "Ratio between training and validation samples (0.0 = all training, 1.0 = " + "all validation) (default = 0.5)."); SetParameterFloat("sample.vtr", 0.5); + SetMaximumParameterFloatValue("sample.vtr",1.0); + SetMinimumParameterFloatValue("sample.vtr",0.0); + + ShareParameter("sample.vfn","polystat.field"); - AddParameter(ParameterType_String, "sample.vfn", "Name of the discrimination field"); - SetParameterDescription("sample.vfn", "Name of the field used to discriminate class labels in the input vector data files."); - SetParameterString("sample.vfn", "Class"); + // hide sampling parameters + //ShareParameter("sample.strategy","rates.strategy"); + //ShareParameter("sample.mim","rates.mim"); - Superclass::DoInit(); + // Classifier settings + ShareParameter("classifier","training.classifier"); + + ShareParameter("rand","training.rand"); + + // Synchronization between applications + Connect("select.field", "polystat.field"); + Connect("select.layer", "polystat.layer"); + Connect("select.elev", "polystat.elev"); + + Connect("extraction.in", "select.in"); + Connect("extraction.vec", "select.out"); + Connect("extraction.field", "polystat.field"); + Connect("extraction.layer", "polystat.layer"); + + Connect("training.cfield", "polystat.field"); + + ShareParameter("ram","polystat.ram"); + Connect("select.ram", "polystat.ram"); + Connect("extraction.ram", "polystat.ram"); + + Connect("select.rand", "training.rand"); + + AddParameter(ParameterType_Empty,"cleanup","Temporary files cleaning"); + EnableParameter("cleanup"); + SetParameterDescription("cleanup","If activated, the application will try to clean all temporary files it created"); + MandatoryOff("cleanup"); - AddRANDParameter(); // Doc example parameter settings SetDocExampleParameterValue("io.il", "QB_1_ortho.tif"); SetDocExampleParameterValue("io.vd", "VectorData_QB1.shp"); @@ -201,396 +212,305 @@ void DoInit() ITK_OVERRIDE void DoUpdateParameters() ITK_OVERRIDE { - // Nothing to do here : all parameters are independent + if ( HasValue("io.vd") ) + { + std::vector<std::string> vectorFileList = GetParameterStringList("io.vd"); + GetInternalApplication("polystat")->SetParameterString("vec",vectorFileList[0]); + UpdateInternalParameters("polystat"); + } } -void LogConfusionMatrix(ConfusionMatrixCalculatorType* confMatCalc) +void DoExecute() ITK_OVERRIDE { - ConfusionMatrixCalculatorType::ConfusionMatrixType matrix = confMatCalc->GetConfusionMatrix(); - - // Compute minimal width - size_t minwidth = 0; - - for (unsigned int i = 0; i < matrix.Rows(); i++) + FloatVectorImageListType* imageList = GetParameterImageList("io.il"); + std::vector<std::string> vectorFileList = GetParameterStringList("io.vd"); + unsigned int nbInputs = imageList->Size(); + if (nbInputs > vectorFileList.size()) { - for (unsigned int j = 0; j < matrix.Cols(); j++) - { - std::ostringstream os; - os << matrix(i, j); - size_t size = os.str().size(); - - if (size > minwidth) - { - minwidth = size; - } - } + otbAppLogFATAL("Missing input vector data files to match number of images ("<<nbInputs<<")."); } - MapOfIndicesType mapOfIndices = confMatCalc->GetMapOfIndices(); - - MapOfIndicesType::const_iterator it = mapOfIndices.begin(); - MapOfIndicesType::const_iterator end = mapOfIndices.end(); - - for (; it != end; ++it) + // check if validation vectors are given + std::vector<std::string> validationVectorFileList; + bool dedicatedValidation = false; + if (IsParameterEnabled("io.valid") && HasValue("io.valid")) { - std::ostringstream os; - os << "[" << it->second << "]"; - - size_t size = os.str().size(); - if (size > minwidth) + dedicatedValidation = true; + validationVectorFileList = GetParameterStringList("io.valid"); + if (nbInputs > validationVectorFileList.size()) { - minwidth = size; + otbAppLogFATAL("Missing validation vector data files to match number of images ("<<nbInputs<<")."); } } - // Generate matrix string, with 'minwidth' as size specifier - std::ostringstream os; - - // Header line - for (size_t i = 0; i < minwidth; ++i) - os << " "; - os << " "; - - it = mapOfIndices.begin(); - end = mapOfIndices.end(); - for (; it != end; ++it) + // Prepare temporary file names + std::string outModel(GetParameterString("io.out")); + std::vector<std::string> polyStatTrainOutputs; + std::vector<std::string> polyStatValidOutputs; + std::vector<std::string> ratesTrainOutputs; + std::vector<std::string> ratesValidOutputs; + std::vector<std::string> sampleOutputs; + std::vector<std::string> sampleTrainOutputs; + std::vector<std::string> sampleValidOutputs; + std::string rateTrainOut; + if (dedicatedValidation) { - os << "[" << it->second << "]" << " "; + rateTrainOut = outModel + "_ratesTrain.csv"; } - - os << std::endl; - - // Each line of confusion matrix - for (unsigned int i = 0; i < matrix.Rows(); i++) + else { - ConfusionMatrixCalculatorType::ClassLabelType label = mapOfIndices[i]; - os << "[" << std::setw(minwidth - 2) << label << "]" << " "; - for (unsigned int j = 0; j < matrix.Cols(); j++) - { - os << std::setw(minwidth) << matrix(i, j) << " "; - } - os << std::endl; + rateTrainOut = outModel + "_rates.csv"; } - - otbAppLogINFO("Confusion matrix (rows = reference labels, columns = produced labels):\n" << os.str()); -} - -void DoExecute() ITK_OVERRIDE -{ - //Create training and validation for list samples and label list samples - ConcatenateLabelListSampleFilterType::Pointer concatenateTrainingLabels = - ConcatenateLabelListSampleFilterType::New(); - ConcatenateListSampleFilterType::Pointer concatenateTrainingSamples = ConcatenateListSampleFilterType::New(); - ConcatenateLabelListSampleFilterType::Pointer concatenateValidationLabels = - ConcatenateLabelListSampleFilterType::New(); - ConcatenateListSampleFilterType::Pointer concatenateValidationSamples = ConcatenateListSampleFilterType::New(); - - SampleType meanMeasurementVector; - SampleType stddevMeasurementVector; - - // Setup the DEM Handler - otb::Wrapper::ElevationParametersHandler::SetupDEMHandlerFromElevationParameters(this, "elev"); - - //-------------------------- - // Load measurements from images - unsigned int nbBands = 0; - //Iterate over all input images - - FloatVectorImageListType* imageList = GetParameterImageList("io.il"); - VectorDataListType* vectorDataList = GetParameterVectorDataList("io.vd"); - - vdreproj = VectorDataReprojectionType::New(); - - //Iterate over all input images - for (unsigned int imgIndex = 0; imgIndex < imageList->Size(); ++imgIndex) + std::string rateValidOut(outModel + "_ratesValid.csv"); + for (unsigned int i=0 ; i<nbInputs ; i++) { - std::ostringstream oss1, oss2; - oss1 << "Reproject polygons for image " << (imgIndex+1) << " ..."; - oss2 << "Extract samples from image " << (imgIndex+1) << " ..."; - - FloatVectorImageType::Pointer image = imageList->GetNthElement(imgIndex); - image->UpdateOutputInformation(); - - if (imgIndex == 0) + std::ostringstream oss; + oss <<i+1; + std::string strIndex(oss.str()); + if (dedicatedValidation) { - nbBands = image->GetNumberOfComponentsPerPixel(); + polyStatTrainOutputs.push_back(outModel + "_statsTrain_" + strIndex + ".xml"); + polyStatValidOutputs.push_back(outModel + "_statsValid_" + strIndex + ".xml"); + ratesTrainOutputs.push_back(outModel + "_ratesTrain_" + strIndex + ".csv"); + ratesValidOutputs.push_back(outModel + "_ratesValid_" + strIndex + ".csv"); + sampleOutputs.push_back(outModel + "_samplesTrain_" + strIndex + ".shp"); } - - // read the Vectordata - vdreproj->SetInputImage(image); - vdreproj->SetInput(vectorDataList->GetNthElement(imgIndex)); - vdreproj->SetUseOutputSpacingAndOriginFromImage(false); - - AddProcess(vdreproj, oss1.str()); - vdreproj->Update(); - - //Sample list generator - ListSampleGeneratorType::Pointer sampleGenerator = ListSampleGeneratorType::New(); - - sampleGenerator->SetInput(image); - sampleGenerator->SetInputVectorData(vdreproj->GetOutput()); - - sampleGenerator->SetClassKey(GetParameterString("sample.vfn")); - sampleGenerator->SetMaxTrainingSize(GetParameterInt("sample.mt")); - sampleGenerator->SetMaxValidationSize(GetParameterInt("sample.mv")); - sampleGenerator->SetValidationTrainingProportion(GetParameterFloat("sample.vtr")); - sampleGenerator->SetBoundByMin(GetParameterInt("sample.bm")!=0); - - // take pixel located on polygon edge into consideration - if (IsParameterEnabled("sample.edg")) + else { - sampleGenerator->SetPolygonEdgeInclusion(true); + polyStatTrainOutputs.push_back(outModel + "_stats_" + strIndex + ".xml"); + ratesTrainOutputs.push_back(outModel + "_rates_" + strIndex + ".csv"); + sampleOutputs.push_back(outModel + "_samples_" + strIndex + ".shp"); } - - AddProcess(sampleGenerator, oss2.str()); - sampleGenerator->Update(); - - TargetListSampleType::Pointer trainLabels = sampleGenerator->GetTrainingListLabel(); - ListSampleType::Pointer trainSamples = sampleGenerator->GetTrainingListSample(); - TargetListSampleType::Pointer validLabels = sampleGenerator->GetValidationListLabel(); - ListSampleType::Pointer validSamples = sampleGenerator->GetValidationListSample(); - - trainLabels->DisconnectPipeline(); - trainSamples->DisconnectPipeline(); - validLabels->DisconnectPipeline(); - validSamples->DisconnectPipeline(); - - //Concatenate training and validation samples from the image - concatenateTrainingLabels->AddInput(trainLabels); - concatenateTrainingSamples->AddInput(trainSamples); - concatenateValidationLabels->AddInput(validLabels); - concatenateValidationSamples->AddInput(validSamples); + sampleTrainOutputs.push_back(outModel + "_samplesTrain_" + strIndex + ".shp"); + sampleValidOutputs.push_back(outModel + "_samplesValid_" + strIndex + ".shp"); } - // Update - AddProcess(concatenateValidationLabels, "Concatenate samples ..."); - concatenateTrainingSamples->Update(); - concatenateTrainingLabels->Update(); - concatenateValidationSamples->Update(); - concatenateValidationLabels->Update(); - - if (concatenateTrainingSamples->GetOutput()->Size() == 0) + + // --------------------------------------------------------------------------- + // Polygons stats + for (unsigned int i=0 ; i<nbInputs ; i++) { - otbAppLogFATAL("No training samples, cannot perform SVM training."); + GetInternalApplication("polystat")->SetParameterInputImage("in",imageList->GetNthElement(i)); + GetInternalApplication("polystat")->SetParameterString("vec",vectorFileList[i]); + GetInternalApplication("polystat")->SetParameterString("out",polyStatTrainOutputs[i]); + ExecuteInternal("polystat"); + // analyse polygons given for validation + if (dedicatedValidation) + { + GetInternalApplication("polystat")->SetParameterString("vec",validationVectorFileList[i]); + GetInternalApplication("polystat")->SetParameterString("out",polyStatValidOutputs[i]); + ExecuteInternal("polystat"); + } } - if (concatenateValidationSamples->GetOutput()->Size() == 0) + // --------------------------------------------------------------------------- + // Compute sampling rates + GetInternalApplication("rates")->SetParameterString("mim","proportional"); + double vtr = GetParameterFloat("sample.vtr"); + long mt = GetParameterInt("sample.mt"); + long mv = GetParameterInt("sample.mv"); + // compute final maximum training and final maximum validation + // By default take all samples (-1 means all samples) + long fmt = -1; + long fmv = -1; + if (GetParameterInt("sample.bm") == 0) { - otbAppLogWARNING("No validation samples."); + if (dedicatedValidation) + { + // fmt and fmv will be used separately + fmt = mt; + fmv = mv; + if (mt > -1 && mv <= -1 && vtr < 0.99999) + { + fmv = static_cast<long>((double) mt * vtr / (1.0 - vtr)); + } + if (mt <= -1 && mv > -1 && vtr > 0.00001) + { + fmt = static_cast<long>((double) mv * (1.0 - vtr) / vtr); + } + } + else + { + // only fmt will be used for both training and validation samples + // So we try to compute the total number of samples given input + // parameters mt, mv and vtr. + if (mt > -1 && mv > -1) + { + fmt = mt + mv; + } + if (mt > -1 && mv <= -1 && vtr < 0.99999) + { + fmt = static_cast<long>((double) mt / (1.0 - vtr)); + } + if (mt <= -1 && mv > -1 && vtr > 0.00001) + { + fmt = static_cast<long>((double) mv / vtr); + } + } } - if (IsParameterEnabled("io.imstat")) + // Sampling rates for training + GetInternalApplication("rates")->SetParameterStringList("il",polyStatTrainOutputs); + GetInternalApplication("rates")->SetParameterString("out",rateTrainOut); + if (GetParameterInt("sample.bm") != 0) { - StatisticsReader::Pointer statisticsReader = StatisticsReader::New(); - statisticsReader->SetFileName(GetParameterString("io.imstat")); - meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean"); - stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev"); + GetInternalApplication("rates")->SetParameterString("strategy","smallest"); } else { - meanMeasurementVector.SetSize(nbBands); - meanMeasurementVector.Fill(0.); - stddevMeasurementVector.SetSize(nbBands); - stddevMeasurementVector.Fill(1.); + if (fmt > -1) + { + GetInternalApplication("rates")->SetParameterString("strategy","constant"); + GetInternalApplication("rates")->SetParameterInt("strategy.constant.nb",fmt); + } + else + { + GetInternalApplication("rates")->SetParameterString("strategy","all"); + } } - - // Shift scale the samples - ShiftScaleFilterType::Pointer trainingShiftScaleFilter = ShiftScaleFilterType::New(); - trainingShiftScaleFilter->SetInput(concatenateTrainingSamples->GetOutput()); - trainingShiftScaleFilter->SetShifts(meanMeasurementVector); - trainingShiftScaleFilter->SetScales(stddevMeasurementVector); - AddProcess(trainingShiftScaleFilter, "Normalize training samples ..."); - trainingShiftScaleFilter->Update(); - - ListSampleType::Pointer validationListSample=ListSampleType::New(); - - //Test if the validation test is empty - if ( concatenateValidationSamples->GetOutput()->Size() != 0 ) + ExecuteInternal("rates"); + // Sampling rates for validation + if (dedicatedValidation) { - ShiftScaleFilterType::Pointer validationShiftScaleFilter = ShiftScaleFilterType::New(); - validationShiftScaleFilter->SetInput(concatenateValidationSamples->GetOutput()); - validationShiftScaleFilter->SetShifts(meanMeasurementVector); - validationShiftScaleFilter->SetScales(stddevMeasurementVector); - AddProcess(validationShiftScaleFilter, "Normalize validation samples ..."); - validationShiftScaleFilter->Update(); - validationListSample = validationShiftScaleFilter->GetOutput(); + GetInternalApplication("rates")->SetParameterStringList("il",polyStatValidOutputs); + GetInternalApplication("rates")->SetParameterString("out",rateValidOut); + if (GetParameterInt("sample.bm") != 0) + { + GetInternalApplication("rates")->SetParameterString("strategy","smallest"); + } + else + { + if (fmv > -1) + { + GetInternalApplication("rates")->SetParameterString("strategy","constant"); + GetInternalApplication("rates")->SetParameterInt("strategy.constant.nb",fmv); + } + else + { + GetInternalApplication("rates")->SetParameterString("strategy","all"); + } + } + ExecuteInternal("rates"); } - ListSampleType::Pointer listSample; - TargetListSampleType::Pointer labelListSample; - //-------------------------- - // Balancing training sample (if needed) - // if (IsParameterEnabled("sample.b")) - // { - // // Balance the list sample. - // otbAppLogINFO("Number of training samples before balancing: " << concatenateTrainingSamples->GetOutput()->Size()) - // BalancingListSampleFilterType::Pointer balancingFilter = BalancingListSampleFilterType::New(); - // balancingFilter->SetInput(trainingShiftScaleFilter->GetOutput()); - // balancingFilter->SetInputLabel(concatenateTrainingLabels->GetOutput()); - // balancingFilter->SetBalancingFactor(GetParameterInt("sample.b")); - // balancingFilter->Update(); - // listSample = balancingFilter->GetOutput(); - // labelListSample = balancingFilter->GetOutputLabelSampleList(); - // otbAppLogINFO("Number of samples after balancing: " << balancingFilter->GetOutput()->Size()); - - // } - // else - // { - listSample = trainingShiftScaleFilter->GetOutput(); - labelListSample = concatenateTrainingLabels->GetOutput(); - otbAppLogINFO("Number of training samples: " << concatenateTrainingSamples->GetOutput()->Size()); - // } - //-------------------------- - // Split the data set into training/validation set - ListSampleType::Pointer trainingListSample = listSample; - TargetListSampleType::Pointer trainingLabeledListSample = labelListSample; - - TargetListSampleType::Pointer validationLabeledListSample = concatenateValidationLabels->GetOutput(); - otbAppLogINFO("Size of training set: " << trainingListSample->Size()); - otbAppLogINFO("Size of validation set: " << validationListSample->Size()); - otbAppLogINFO("Size of labeled training set: " << trainingLabeledListSample->Size()); - otbAppLogINFO("Size of labeled validation set: " << validationLabeledListSample->Size()); - - //-------------------------- - // Estimate model - //-------------------------- - this->Train(trainingListSample,trainingLabeledListSample,GetParameterString("io.out")); - - //-------------------------- - // Performances estimation - //-------------------------- - TargetListSampleType::Pointer predictedList = TargetListSampleType::New(); - ListSampleType::Pointer performanceListSample=ListSampleType::New(); - TargetListSampleType::Pointer performanceLabeledListSample=TargetListSampleType::New(); - - //Test the input validation set size - if(validationLabeledListSample->Size() != 0) + // --------------------------------------------------------------------------- + // Select & extract samples + GetInternalApplication("select")->SetParameterString("sampler", "periodic"); + GetInternalApplication("select")->SetParameterInt("sampler.periodic.jitter",50); + GetInternalApplication("select")->SetParameterString("strategy","byclass"); + GetInternalApplication("extraction")->SetParameterString("outfield", "prefix"); + GetInternalApplication("extraction")->SetParameterString("outfield.prefix.name","value_"); + for (unsigned int i=0 ; i<nbInputs ; i++) { - performanceListSample = validationListSample; - performanceLabeledListSample = validationLabeledListSample; - } - else - { - otbAppLogWARNING("The validation set is empty. The performance estimation is done using the input training set in this case."); - performanceListSample = trainingListSample; - performanceLabeledListSample = trainingLabeledListSample; - } - - this->Classify(performanceListSample, predictedList, GetParameterString("io.out")); - - ConfusionMatrixCalculatorType::Pointer confMatCalc = ConfusionMatrixCalculatorType::New(); - - otbAppLogINFO("Predicted list size : " << predictedList->Size()); - otbAppLogINFO("ValidationLabeledListSample size : " << performanceLabeledListSample->Size()); - confMatCalc->SetReferenceLabels(performanceLabeledListSample); - confMatCalc->SetProducedLabels(predictedList); - confMatCalc->Compute(); - - otbAppLogINFO("training performances"); - LogConfusionMatrix(confMatCalc); - - for (unsigned int itClasses = 0; itClasses < confMatCalc->GetNumberOfClasses(); itClasses++) + GetInternalApplication("select")->SetParameterInputImage("in",imageList->GetNthElement(i)); + GetInternalApplication("select")->SetParameterString("vec",vectorFileList[i]); + GetInternalApplication("select")->SetParameterString("out",sampleOutputs[i]); + GetInternalApplication("select")->SetParameterString("instats",polyStatTrainOutputs[i]); + GetInternalApplication("select")->SetParameterString("strategy.byclass.in",ratesTrainOutputs[i]); + // select sample positions + ExecuteInternal("select"); + // extract sample descriptors + ExecuteInternal("extraction"); + + if (dedicatedValidation) { - ConfusionMatrixCalculatorType::ClassLabelType classLabel = confMatCalc->GetMapOfIndices()[itClasses]; - - otbAppLogINFO("Precision of class [" << classLabel << "] vs all: " << confMatCalc->GetPrecisions()[itClasses]); - otbAppLogINFO("Recall of class [" << classLabel << "] vs all: " << confMatCalc->GetRecalls()[itClasses]); - otbAppLogINFO( - "F-score of class [" << classLabel << "] vs all: " << confMatCalc->GetFScores()[itClasses] << "\n"); + GetInternalApplication("select")->SetParameterString("vec",validationVectorFileList[i]); + GetInternalApplication("select")->SetParameterString("out",sampleValidOutputs[i]); + GetInternalApplication("select")->SetParameterString("instats",polyStatValidOutputs[i]); + GetInternalApplication("select")->SetParameterString("strategy.byclass.in",ratesValidOutputs[i]); + // select sample positions + ExecuteInternal("select"); + // extract sample descriptors + ExecuteInternal("extraction"); } - otbAppLogINFO("Global performance, Kappa index: " << confMatCalc->GetKappaIndex()); - - - if (this->HasValue("io.confmatout")) + else { - // Writing the confusion matrix in the output .CSV file - - MapOfIndicesType::iterator itMapOfIndicesValid, itMapOfIndicesPred; - ClassLabelType labelValid = 0; - - ConfusionMatrixType confusionMatrix = confMatCalc->GetConfusionMatrix(); - MapOfIndicesType mapOfIndicesValid = confMatCalc->GetMapOfIndices(); - - unsigned int nbClassesPred = mapOfIndicesValid.size(); - - ///////////////////////////////////////////// - // Filling the 2 headers for the output file - const std::string commentValidStr = "#Reference labels (rows):"; - const std::string commentPredStr = "#Produced labels (columns):"; - const char separatorChar = ','; - std::ostringstream ossHeaderValidLabels, ossHeaderPredLabels; - - // Filling ossHeaderValidLabels and ossHeaderPredLabels for the output file - ossHeaderValidLabels << commentValidStr; - ossHeaderPredLabels << commentPredStr; - - itMapOfIndicesValid = mapOfIndicesValid.begin(); - - while (itMapOfIndicesValid != mapOfIndicesValid.end()) + // Split between training and validation + ogr::DataSource::Pointer source = ogr::DataSource::New(sampleOutputs[i], ogr::DataSource::Modes::Read); + ogr::DataSource::Pointer destTrain = ogr::DataSource::New(sampleTrainOutputs[i], ogr::DataSource::Modes::Overwrite); + ogr::DataSource::Pointer destValid = ogr::DataSource::New(sampleValidOutputs[i], ogr::DataSource::Modes::Overwrite); + // read sampling rates from ratesTrainOutputs[i] + SamplingRateCalculator::Pointer rateCalculator = SamplingRateCalculator::New(); + rateCalculator->Read(ratesTrainOutputs[i]); + // Compute sampling rates for train and valid + const MapRateType &inputRates = rateCalculator->GetRatesByClass(); + MapRateType trainRates; + MapRateType validRates; + otb::SamplingRateCalculator::TripletType tpt; + for (MapRateType::const_iterator it = inputRates.begin() ; + it != inputRates.end() ; + ++it) { - // labels labelValid of mapOfIndicesValid are already sorted in otbConfusionMatrixCalculator - labelValid = itMapOfIndicesValid->second; - - otbAppLogINFO("mapOfIndicesValid[" << itMapOfIndicesValid->first << "] = " << labelValid); - - ossHeaderValidLabels << labelValid; - ossHeaderPredLabels << labelValid; - - ++itMapOfIndicesValid; - - if (itMapOfIndicesValid != mapOfIndicesValid.end()) - { - ossHeaderValidLabels << separatorChar; - ossHeaderPredLabels << separatorChar; - } - else - { - ossHeaderValidLabels << std::endl; - ossHeaderPredLabels << std::endl; - } + unsigned long total = std::min(it->second.Required,it->second.Tot ); + unsigned long neededValid = static_cast<unsigned long>((double) total * vtr ); + unsigned long neededTrain = total - neededValid; + tpt.Tot = total; + tpt.Required = neededTrain; + tpt.Rate = (1.0 - vtr); + trainRates[it->first] = tpt; + tpt.Tot = neededValid; + tpt.Required = neededValid; + tpt.Rate = 1.0; + validRates[it->first] = tpt; } - std::ofstream outFile; - outFile.open(this->GetParameterString("io.confmatout").c_str()); - outFile << std::fixed; - outFile.precision(10); - - ///////////////////////////////////// - // Writing the 2 headers - outFile << ossHeaderValidLabels.str(); - outFile << ossHeaderPredLabels.str(); - ///////////////////////////////////// - - unsigned int indexLabelValid = 0, indexLabelPred = 0; - - for (itMapOfIndicesValid = mapOfIndicesValid.begin(); itMapOfIndicesValid != mapOfIndicesValid.end(); ++itMapOfIndicesValid) - { - indexLabelPred = 0; - - for (itMapOfIndicesPred = mapOfIndicesValid.begin(); itMapOfIndicesPred != mapOfIndicesValid.end(); ++itMapOfIndicesPred) - { - // Writing the confusion matrix (sorted in otbConfusionMatrixCalculator) in the output file - outFile << confusionMatrix(indexLabelValid, indexLabelPred); - if (indexLabelPred < (nbClassesPred - 1)) - { - outFile << separatorChar; - } - else - { - outFile << std::endl; - } - ++indexLabelPred; - } - - ++indexLabelValid; - } + // Use an otb::OGRDataToSamplePositionFilter with 2 outputs + PeriodicSamplerType::SamplerParameterType param; + param.Offset = 0; + param.MaxJitter = 0; + PeriodicSamplerType::Pointer splitter = PeriodicSamplerType::New(); + splitter->SetInput(imageList->GetNthElement(i)); + splitter->SetOGRData(source); + splitter->SetOutputPositionContainerAndRates(destTrain, trainRates, 0); + splitter->SetOutputPositionContainerAndRates(destValid, validRates, 1); + splitter->SetFieldName(this->GetParameterStringList("sample.vfn")[0]); + splitter->SetLayerIndex(0); + splitter->SetOriginFieldName(std::string("")); + splitter->SetSamplerParameters(param); + splitter->GetStreamer()->SetAutomaticTiledStreaming(this->GetParameterInt("ram")); + AddProcess(splitter->GetStreamer(),"Split samples between training and validation..."); + splitter->Update(); + } + } - outFile.close(); - } // END if (this->HasValue("io.confmatout")) + // --------------------------------------------------------------------------- + // Train model + GetInternalApplication("training")->SetParameterStringList("io.vd",sampleTrainOutputs); + GetInternalApplication("training")->SetParameterStringList("valid.vd",sampleValidOutputs); + UpdateInternalParameters("training"); + // set field names + FloatVectorImageType::Pointer image = imageList->GetNthElement(0); + unsigned int nbBands = image->GetNumberOfComponentsPerPixel(); + std::vector<std::string> selectedNames; + for (unsigned int i=0 ; i<nbBands ; i++) + { + std::ostringstream oss; + oss << i; + selectedNames.push_back("value_"+oss.str()); + } + GetInternalApplication("training")->SetParameterStringList("feat",selectedNames); + ExecuteInternal("training"); - // TODO: implement hyperplane distance classifier and performance validation (cf. object detection) ? + // cleanup + if(IsParameterEnabled("cleanup")) + { + otbAppLogINFO(<<"Final clean-up ..."); + for(unsigned int i=0 ; i<polyStatTrainOutputs.size() ; i++) + RemoveFile(polyStatTrainOutputs[i]); + for(unsigned int i=0 ; i<polyStatValidOutputs.size() ; i++) + RemoveFile(polyStatValidOutputs[i]); + for(unsigned int i=0 ; i<ratesTrainOutputs.size() ; i++) + RemoveFile(ratesTrainOutputs[i]); + for(unsigned int i=0 ; i<ratesValidOutputs.size() ; i++) + RemoveFile(ratesValidOutputs[i]); + for(unsigned int i=0 ; i<sampleOutputs.size() ; i++) + RemoveFile(sampleOutputs[i]); + for(unsigned int i=0 ; i<sampleTrainOutputs.size() ; i++) + RemoveFile(sampleTrainOutputs[i]); + for(unsigned int i=0 ; i<sampleValidOutputs.size() ; i++) + RemoveFile(sampleValidOutputs[i]); + } } - VectorDataReprojectionType::Pointer vdreproj; }; } // end namespace Wrapper diff --git a/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx b/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx index 7f92b72cc42b464394af17c548bc954710e85d83..d9aa85f98d059649753be036a1d039af54f1a5ed 100644 --- a/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx +++ b/Modules/Applications/AppClassification/app/otbTrainVectorClassifier.cxx @@ -410,9 +410,9 @@ void DoExecute() // Check all needed fields are present : // - check class field - cFieldIndex = feature.ogr().GetFieldIndex(GetParameterString("cfield").c_str()); + cFieldIndex = feature.ogr().GetFieldIndex(selectedCFieldName.c_str()); if (cFieldIndex < 0) - otbAppLogFATAL("The field name for class label ("<<GetParameterString("cfield") + otbAppLogFATAL("The field name for class label ("<<selectedCFieldName <<") has not been found in the input vector file! Choices are "<< availableFields); // - check feature fields for (unsigned int i=0 ; i<nbFeatures ; i++) diff --git a/Modules/Applications/AppClassification/test/CMakeLists.txt b/Modules/Applications/AppClassification/test/CMakeLists.txt index 76715f8d20cd4a35dd144fbdd71f471bb149bed4..aaa2f9cfe113a49279b2247b018b651f5ff05780 100644 --- a/Modules/Applications/AppClassification/test/CMakeLists.txt +++ b/Modules/Applications/AppClassification/test/CMakeLists.txt @@ -141,6 +141,7 @@ foreach(classifier ${classifierList}) -classifier ${lclassifier} ${${lclassifier}_parameters} -io.out ${TEMP}/${OUTMODELFILE} + -sample.vfn Class -rand 121212 VALID ${valid} @@ -162,6 +163,7 @@ foreach(classifier ${classifierList}) ${${lclassifier}_parameters} -io.out ${TEMP}/OutXML1_${OUTMODELFILE} -rand 121212 + -sample.vfn Class -outxml ${TEMP}/cl${classifier}_OutXML1.xml VALID ${valid} @@ -181,7 +183,7 @@ foreach(classifier ${classifierList}) -io.vd ${INPUTDATA}/Classification/VectorData_${${lclassifier}_input}QB1${vector_input_format} -io.imstat ${INPUTDATA}/Classification/clImageStatisticsQB1${stat_input_format} -io.out ${TEMP}/OutXML2_${OUTMODELFILE} - + -sample.vfn Class VALID ${valid} ) @@ -647,6 +649,7 @@ otb_test_application(NAME apTvClTrainSVMImagesClassifierQB1_allOpt_InXML -sample.mv 100 -sample.mt 100 -sample.vtr 0.5 + -sample.vfn Class -classifier.libsvm.opt true -rand 121212 -io.out ${TEMP}/clsvmModelQB1_allOpt_InXML.svm @@ -659,6 +662,7 @@ otb_test_application(NAME apTvClTrainSVMImagesClassifierQB1_OutXML OPTIONS -io.il ${INPUTDATA}/Classification/QB_1_ortho.tif -io.vd ${INPUTDATA}/Classification/VectorData_QB1.shp -io.imstat ${INPUTDATA}/Classification/clImageStatisticsQB1.xml + -sample.vfn Class -classifier libsvm -classifier.libsvm.opt true -io.out ${TEMP}/clsvmModelQB1_OutXML.svm @@ -677,6 +681,7 @@ otb_test_application(NAME apTvClTrainSVMImagesClassifierQB1_OutXML ${INPUTDATA}/Classification/VectorData_QB2.shp ${INPUTDATA}/Classification/VectorData_QB3.shp -io.imstat ${INPUTDATA}/Classification/clImageStatisticsQB123.xml + -sample.vfn Class -classifier libsvm -classifier.libsvm.opt true -io.out ${TEMP}/clsvmModelQB123.svm @@ -693,6 +698,7 @@ otb_test_application(NAME apTvClTrainSVMImagesClassifierQB1 -io.imstat ${INPUTDATA}/Classification/clImageStatisticsQB1.xml -classifier libsvm -classifier.libsvm.opt true + -sample.vfn Class -io.out ${TEMP}/clsvmModelQB1.svm -rand 121212 VALID ${ascii_comparison} @@ -708,6 +714,7 @@ otb_test_application(NAME apTvClTrainSVMImagesClassifierQB1 ${INPUTDATA}/Classification/VectorData_QB5.shp ${INPUTDATA}/Classification/VectorData_QB6.shp -io.imstat ${INPUTDATA}/Classification/clImageStatisticsQB456.xml + -sample.vfn Class -classifier libsvm -classifier.libsvm.opt true -io.out ${TEMP}/clsvmModelQB456.svm @@ -727,6 +734,7 @@ otb_test_application(NAME apTvClTrainSVMImagesClassifierQB1_allOpt -sample.mv 100 -sample.mt 100 -sample.vtr 0.5 + -sample.vfn Class -classifier.libsvm.opt true -rand 121212 -io.out ${TEMP}/clsvmModelQB1_allOpt.svm diff --git a/Modules/Filtering/Statistics/include/otbPeriodicSampler.h b/Modules/Filtering/Statistics/include/otbPeriodicSampler.h index 3940bcd796b022d9d9ff7d9dbff42bacd9255fc4..3cc91743e735c0302436c078a842f265d17184f0 100644 --- a/Modules/Filtering/Statistics/include/otbPeriodicSampler.h +++ b/Modules/Filtering/Statistics/include/otbPeriodicSampler.h @@ -51,6 +51,9 @@ public: /** Maximum jitter to introduce (0 means no jitter) */ unsigned long MaxJitter; + + /** Maximum buffer size for internal jitter values */ + unsigned long MaxBufferSize; bool operator!=(const struct Parameter & param) const; } ParameterType; @@ -109,6 +112,9 @@ private: /** Internal current offset value * (either fixed, or reset each time a sample is taken)*/ double m_OffsetValue; + + /** jitter offsets computed up to MaxBufferSize */ + std::vector<double> m_JitterValues; }; } // namespace otb diff --git a/Modules/Filtering/Statistics/src/otbPeriodicSampler.cxx b/Modules/Filtering/Statistics/src/otbPeriodicSampler.cxx index ad4a6d88df1ed46ace6da5af59ac6f5d4b1894b5..1aba9367ffb3fe2c3e6ef2ab373c631a66fcbb82 100644 --- a/Modules/Filtering/Statistics/src/otbPeriodicSampler.cxx +++ b/Modules/Filtering/Statistics/src/otbPeriodicSampler.cxx @@ -27,7 +27,8 @@ bool PeriodicSampler::ParameterType::operator!=(const PeriodicSampler::ParameterType & param) const { return bool((Offset != param.Offset)|| - (MaxJitter != param.MaxJitter)); + (MaxJitter != param.MaxJitter) || + (MaxBufferSize != param.MaxBufferSize)); } void @@ -39,9 +40,20 @@ PeriodicSampler::Reset(void) if (m_JitterSize > 0.0) { // Using jitter : compute random offset value - m_OffsetValue = - itk::Statistics::MersenneTwisterRandomVariateGenerator::GetInstance() + m_JitterValues.resize(std::min(this->GetNeededElements(), this->m_Parameters.MaxBufferSize)); + for (unsigned long i=0UL ; i<m_JitterValues.size() ; i++) + { + m_JitterValues[i] = itk::Statistics::MersenneTwisterRandomVariateGenerator::GetInstance() ->GetUniformVariate(0.0,m_JitterSize); + } + if (m_JitterValues.empty()) + { + m_OffsetValue = 0.0; + } + else + { + m_OffsetValue = m_JitterValues[0]; + } } else { @@ -71,9 +83,7 @@ PeriodicSampler::TakeSample(void) if (m_JitterSize > 0.0) { // Using jitter : compute random offset value - m_OffsetValue = - itk::Statistics::MersenneTwisterRandomVariateGenerator::GetInstance() - ->GetUniformVariate(0.0,m_JitterSize); + m_OffsetValue = m_JitterValues[this->m_ChosenElements%m_JitterValues.size()]; } ret = true; } @@ -84,6 +94,7 @@ PeriodicSampler::PeriodicSampler() { this->m_Parameters.Offset = 0UL; this->m_Parameters.MaxJitter = 0UL; + this->m_Parameters.MaxBufferSize = 100000000UL; m_JitterSize = 0.0; m_OffsetValue = 0.0; } diff --git a/Modules/Filtering/Statistics/test/otbSamplerTest.cxx b/Modules/Filtering/Statistics/test/otbSamplerTest.cxx index d50684d5410abb2f28c4269cf66d6c1580a2e1ea..9d29cb5869bd8915e2832c1c060d57879119bcde 100644 --- a/Modules/Filtering/Statistics/test/otbSamplerTest.cxx +++ b/Modules/Filtering/Statistics/test/otbSamplerTest.cxx @@ -72,6 +72,7 @@ int otbPeriodicSamplerTest(int, char *[]) param.Offset = 0; param.MaxJitter = 10; + param.MaxBufferSize = 1000000UL; sampler->SetRate(0.2,50); sampler->SetParameters(param); std::string test2 = RunSampler<otb::PeriodicSampler>(sampler,50); diff --git a/Modules/Learning/Sampling/include/otbOGRDataToSamplePositionFilter.h b/Modules/Learning/Sampling/include/otbOGRDataToSamplePositionFilter.h index a654b79d75b633632c07d569ff3e8cf75cc0a3ff..c602f1276a7b1131ee07d70af5f192c35bee8073 100644 --- a/Modules/Learning/Sampling/include/otbOGRDataToSamplePositionFilter.h +++ b/Modules/Learning/Sampling/include/otbOGRDataToSamplePositionFilter.h @@ -148,8 +148,12 @@ private: /** Internal samplers*/ std::vector<SamplerMapType> m_Samplers; - /** Field name to store the FID of the geometry each sample comes from */ + /** Field name to store the FID of the geometry each sample comes from. + * When this name is empty, no FID is stored. */ std::string m_OriginFieldName; + + /** Flag to enable/disable origin FID in outputs */ + bool m_UseOriginField; }; /** diff --git a/Modules/Learning/Sampling/include/otbOGRDataToSamplePositionFilter.txx b/Modules/Learning/Sampling/include/otbOGRDataToSamplePositionFilter.txx index d2e6847e1f64bf95a19d6ba9c12d1023632c4fa2..d881897996d7f79e4dce1fddd4b2add79c04db6d 100644 --- a/Modules/Learning/Sampling/include/otbOGRDataToSamplePositionFilter.txx +++ b/Modules/Learning/Sampling/include/otbOGRDataToSamplePositionFilter.txx @@ -31,6 +31,7 @@ PersistentOGRDataToSamplePositionFilter<TInputImage,TMaskImage,TSampler> { this->SetNumberOfRequiredOutputs(2); m_OriginFieldName = std::string("originfid"); + m_UseOriginField = true; } template<class TInputImage, class TMaskImage, class TSampler> @@ -51,7 +52,11 @@ PersistentOGRDataToSamplePositionFilter<TInputImage,TMaskImage,TSampler> // Add an extra field for the original FID this->ClearAdditionalFields(); - this->CreateAdditionalField(this->GetOriginFieldName(),OFTInteger,12); + m_UseOriginField = (this->GetOriginFieldName().size() > 0); + if (m_UseOriginField) + { + this->CreateAdditionalField(this->GetOriginFieldName(),OFTInteger,12); + } // compute label mapping this->ComputeClassPartition(); @@ -179,7 +184,10 @@ PersistentOGRDataToSamplePositionFilter<TInputImage,TMaskImage,TSampler> ogr::Layer outputLayer = this->GetInMemoryOutput(threadid,i); ogr::Feature feat(outputLayer.GetLayerDefn()); feat.SetFrom(feature); - feat[this->GetOriginFieldName()].SetValue(static_cast<int>(feature.GetFID())); + if (m_UseOriginField) + { + feat[this->GetOriginFieldName()].SetValue(static_cast<int>(feature.GetFID())); + } feat.SetGeometry(&ogrTmpPoint); outputLayer.CreateFeature(feat); break; diff --git a/Modules/Wrappers/ApplicationEngine/include/otbWrapperCompositeApplication.h b/Modules/Wrappers/ApplicationEngine/include/otbWrapperCompositeApplication.h index 5af08d98241f6795147832f9f566642ee325d448..caa6b6754cfcf4cb9949ef025873ee53be4f39db 100644 --- a/Modules/Wrappers/ApplicationEngine/include/otbWrapperCompositeApplication.h +++ b/Modules/Wrappers/ApplicationEngine/include/otbWrapperCompositeApplication.h @@ -84,6 +84,13 @@ protected: */ bool AddApplication(std::string appType, std::string key, std::string desc); + /** + * Method to remove all internal applications. Application deriving from + * CompositeApplication should call this method at the begining of their + * DoInit(). + */ + void ClearApplications(); + /** * Connect two existing parameters together. The first parameter will point to * the second parameter. diff --git a/Modules/Wrappers/ApplicationEngine/src/otbWrapperCompositeApplication.cxx b/Modules/Wrappers/ApplicationEngine/src/otbWrapperCompositeApplication.cxx index a3a264a02dcbbe290f2c59b6c8981b2fcf0c06c8..09f9e0fc3034ddcd1d26723f41e1c48dc281d844 100644 --- a/Modules/Wrappers/ApplicationEngine/src/otbWrapperCompositeApplication.cxx +++ b/Modules/Wrappers/ApplicationEngine/src/otbWrapperCompositeApplication.cxx @@ -69,6 +69,13 @@ CompositeApplication return true; } +void +CompositeApplication +::ClearApplications() +{ + m_AppContainer.clear(); +} + bool CompositeApplication ::Connect(std::string fromKey, std::string toKey) @@ -169,16 +176,40 @@ CompositeApplication ::ExecuteInternal(std::string key) { otbAppLogINFO(<< GetInternalAppDescription(key) <<"..."); - GetInternalApplication(key)->Execute(); - otbAppLogINFO(<< "\n" << m_Oss.str()); - m_Oss.str(std::string("")); + try + { + GetInternalApplication(key)->Execute(); + } + catch(...) + { + this->GetLogger()->Write( itk::LoggerBase::FATAL, std::string("\n") + m_Oss.str() ); + throw; + } + if(!m_Oss.str().empty()) + { + otbAppLogINFO(<< "\n" << m_Oss.str()); + m_Oss.str(std::string("")); + } } void CompositeApplication ::UpdateInternalParameters(std::string key) { - GetInternalApplication(key)->UpdateParameters(); + try + { + GetInternalApplication(key)->UpdateParameters(); + } + catch(...) + { + this->GetLogger()->Write( itk::LoggerBase::FATAL, std::string("\n") + m_Oss.str() ); + throw; + } + if(!m_Oss.str().empty()) + { + otbAppLogINFO(<< "\n" << m_Oss.str()); + m_Oss.str(std::string("")); + } } } // end namespace Wrapper