Commit e459c96c authored by Ludovic Hussonnois's avatar Ludovic Hussonnois

MRG: Merge remote-tracking branch 'remotes/origin/unsupervised_classif' into develop

parents aff350aa 61c04dc3
......@@ -46,7 +46,7 @@ if(NOT OTB_USE_OPENCV)
SET(BANNED_HEADERS "${BANNED_HEADERS} otbDecisionTreeMachineLearningModelFactory.h otbDecisionTreeMachineLearningModel.h otbKNearestNeighborsMachineLearningModelFactory.h otbKNearestNeighborsMachineLearningModel.h otbRandomForestsMachineLearningModelFactory.h otbRandomForestsMachineLearningModel.h otbSVMMachineLearningModelFactory.h otbSVMMachineLearningModel.h otbGradientBoostedTreeMachineLearningModelFactory.h otbGradientBoostedTreeMachineLearningModel.h otbBoostMachineLearningModelFactory.h otbBoostMachineLearningModel.h otbNeuralNetworkMachineLearningModelFactory.h otbNeuralNetworkMachineLearningModel.h otbNormalBayesMachineLearningModelFactory.h otbNormalBayesMachineLearningModel.h otbRequiresOpenCVCheck.h otbOpenCVUtils.h otbCvRTreesWrapper.h")
endif()
if(NOT OTB_USE_SHARK)
SET(BANNED_HEADERS "${BANNED_HEADERS} otbSharkRandomForestsMachineLearningModel.h otbSharkRandomForestsMachineLearningModel.txx otbSharkUtils.h otbRequiresSharkCheck.h otbSharkRandomForestsMachineLearningModelFactory.h")
SET(BANNED_HEADERS "${BANNED_HEADERS} otbSharkRandomForestsMachineLearningModel.h otbSharkRandomForestsMachineLearningModel.txx otbSharkUtils.h otbRequiresSharkCheck.h otbSharkRandomForestsMachineLearningModelFactory.h otbSharkKMeansMachineLearningModel.h otbSharkKMeansMachineLearningModel.txx otbSharkKMeansMachineLearningModelFactory.h otbSharkKMeansMachineLearningModelFactory.txx")
endif()
if(NOT OTB_USE_LIBSVM)
SET(BANNED_HEADERS "${BANNED_HEADERS} otbLibSVMMachineLearningModel.h otbLibSVMMachineLearningModelFactory.h")
......@@ -64,7 +64,7 @@ endif()
macro( otb_module_headertest _name )
if( NOT ${_name}_THIRD_PARTY
if( NOT ${_name}_THIRD_PARTY
AND EXISTS ${${_name}_SOURCE_DIR}/include
AND PYTHON_EXECUTABLE
AND NOT (PYTHON_VERSION_STRING VERSION_LESS 2.6)
......
......@@ -70,7 +70,6 @@ otb_create_application(
SOURCES otbTrainVectorClassifier.cxx
LINK_LIBRARIES ${${otb-module}_LIBRARIES})
otb_create_application(
NAME ComputeConfusionMatrix
SOURCES otbComputeConfusionMatrix.cxx
......
......@@ -102,7 +102,22 @@ public:
typedef typename ModelType::TargetSampleType TargetSampleType;
typedef typename ModelType::TargetListSampleType TargetListSampleType;
typedef typename ModelType::TargetValueType TargetValueType;
itkGetConstReferenceMacro(SupervisedClassifier, std::vector<std::string>);
itkGetConstReferenceMacro(UnsupervisedClassifier, std::vector<std::string>);
enum ClassifierCategory{
Supervised,
Unsupervised
};
/**
* Retrieve the classifier category (supervised or unsupervised)
* based on the algorithm selected in the classifier choice.
* @return ClassifierCategory the classifier category
*/
ClassifierCategory GetClassifierCategory();
protected:
LearningApplicationBase();
......@@ -120,15 +135,23 @@ protected:
std::string modelPath);
/** Init method that creates all the parameters for machine learning models */
void DoInit();
void DoInit() ITK_OVERRIDE;
/** Flag to switch between classification and regression mode.
* False by default, child classes may change it in their constructor */
bool m_RegressionFlag;
private:
/** Specific Init and Train methods for each machine learning model */
/** Init Parameters for Supervised Classifier */
void InitSupervisedClassifierParams();
std::vector<std::string> m_SupervisedClassifier;
/** Init Parameters for Unsupervised Classifier */
void InitUnsupervisedClassifierParams();
std::vector<std::string> m_UnsupervisedClassifier;
//@{
#ifdef OTB_USE_LIBSVM
void InitLibSVMParams();
......@@ -179,6 +202,10 @@ private:
void TrainSharkRandomForests(typename ListSampleType::Pointer trainingListSample,
typename TargetListSampleType::Pointer trainingLabeledListSample,
std::string modelPath);
void InitSharkKMeansParams();
void TrainSharkKMeans(typename ListSampleType::Pointer trainingListSample,
typename TargetListSampleType::Pointer trainingLabeledListSample,
std::string modelPath);
#endif
//@}
};
......@@ -203,6 +230,7 @@ private:
#endif
#ifdef OTB_USE_SHARK
#include "otbTrainSharkRandomForests.txx"
#include "otbTrainSharkKMeans.txx"
#endif
#endif
......
......@@ -54,8 +54,33 @@ LearningApplicationBase<TInputValue,TOutputValue>
AddParameter(ParameterType_Choice, "classifier", "Classifier to use for the training");
SetParameterDescription("classifier", "Choice of the classifier to use for the training.");
InitSupervisedClassifierParams();
m_SupervisedClassifier = GetChoiceKeys("classifier");
InitUnsupervisedClassifierParams();
std::vector<std::string> allClassifier = GetChoiceKeys("classifier");
m_UnsupervisedClassifier.assign(allClassifier.begin() + m_SupervisedClassifier.size(), allClassifier.end());
}
/**
 * Tell whether the algorithm currently selected in the "classifier" choice
 * is one of the keys recorded in m_UnsupervisedClassifier.
 * Any key not found in that list is reported as Supervised.
 */
template <class TInputValue, class TOutputValue>
typename LearningApplicationBase<TInputValue,TOutputValue>::ClassifierCategory
LearningApplicationBase<TInputValue,TOutputValue>
::GetClassifierCategory()
{
  const std::string selectedKey = GetParameterString("classifier");

  typedef std::vector<std::string>::const_iterator KeyIterator;
  for (KeyIterator key = m_UnsupervisedClassifier.begin(); key != m_UnsupervisedClassifier.end(); ++key)
    {
    if (*key == selectedKey)
      {
      return Unsupervised;
      }
    }

  return Supervised;
}
template <class TInputValue, class TOutputValue>
void
LearningApplicationBase<TInputValue,TOutputValue>
::InitSupervisedClassifierParams()
{
//Group LibSVM
#ifdef OTB_USE_LIBSVM
#ifdef OTB_USE_LIBSVM
InitLibSVMParams();
#endif
......@@ -81,7 +106,16 @@ LearningApplicationBase<TInputValue,TOutputValue>
#ifdef OTB_USE_SHARK
InitSharkRandomForestsParams();
#endif
}
/**
 * Init parameters for the unsupervised classifiers.
 * The only model registered here is Shark KMeans, and only when the
 * build defines OTB_USE_SHARK; otherwise this function does nothing.
 */
template <class TInputValue, class TOutputValue>
void
LearningApplicationBase<TInputValue,TOutputValue>
::InitUnsupervisedClassifierParams()
{
#if defined(OTB_USE_SHARK)
  // Group Shark KMeans
  InitSharkKMeansParams();
#endif
}
template <class TInputValue, class TOutputValue>
......@@ -151,6 +185,14 @@ LearningApplicationBase<TInputValue,TOutputValue>
otbAppLogFATAL("Module SharkLearning is not installed. You should consider turning OTB_USE_SHARK on during cmake configuration.");
#endif
}
else if(modelName == "sharkkm")
{
#ifdef OTB_USE_SHARK
TrainSharkKMeans( trainingListSample, trainingLabeledListSample, modelPath );
#else
otbAppLogFATAL("Module SharkLearning is not installed. You should consider turning OTB_USE_SHARK on during cmake configuration.");
#endif
}
else if (modelName == "svm")
{
#ifdef OTB_USE_OPENCV
......
/*
* Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
*
* This file is part of Orfeo Toolbox
*
* https://www.orfeo-toolbox.org/
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef otbTrainImagesBase_h
#define otbTrainImagesBase_h
#include "otbVectorDataFileWriter.h"
#include "otbWrapperCompositeApplication.h"
#include "otbWrapperApplicationFactory.h"
#include "otbStatisticsXMLFileWriter.h"
#include "otbImageToEnvelopeVectorDataFilter.h"
#include "otbSamplingRateCalculator.h"
#include "otbOGRDataToSamplePositionFilter.h"
namespace otb
{
namespace Wrapper
{
/** \class TrainImagesBase
 * \brief Base class for the TrainImagesClassifier
 *
 * This class holds the input/output parameters and the composite
 * application connections that are common to supervised and unsupervised
 * model training.
 *
 * \ingroup OTBAppClassification
 */
class TrainImagesBase : public CompositeApplication
{
public:
  /** Standard class typedefs. */
  typedef TrainImagesBase Self;
  typedef CompositeApplication Superclass;
  typedef itk::SmartPointer<Self> Pointer;
  typedef itk::SmartPointer<const Self> ConstPointer;

  /** Standard macro */
  itkTypeMacro( TrainImagesBase, Superclass )

  /** Filters typedefs */
  typedef otb::OGRDataToSamplePositionFilter<FloatVectorImageType, UInt8ImageType, otb::PeriodicSampler> PeriodicSamplerType;
  typedef otb::SamplingRateCalculator::MapRateType MapRateType;

protected:
  /** Strategy used to select samples: by class or by geometry. */
  enum SamplingStrategy
  {
    CLASS, GEOMETRIC
  };

  struct SamplingRates;
  class TrainFileNamesHandler;

  /**
   * Initialize all the input and output parameters used for the train images
   */
  void InitIO();

  /**
   * Initialize sampling related applications and parameters
   */
  void InitSampling();

  /** Expose parameters of the internal sampling applications as parameters of this application. */
  void ShareSamplingParameters();

  /** Connect parameters of the internal sampling applications to each other. */
  void ConnectSamplingParameters();

  /** Initialize classification related application and parameters. */
  void InitClassification();

  /** Share the classification parameters (definition is not part of this header). */
  void ShareClassificationParams();

  /** Connect the classification parameters (definition is not part of this header). */
  void ConnectClassificationParams();

  /**
   * Compute polygon statistics given provided strategy with PolygonClassStatistics class
   * \param imageList list of input images
   * \param vectorFileNames list of input vector file names
   * \param statisticsFileNames list of output statistics file names
   */
  void ComputePolygonStatistics(FloatVectorImageListType *imageList, const std::vector<std::string> &vectorFileNames,
                                const std::vector<std::string> &statisticsFileNames);

  /**
   * Compute final maximum training and validation
   * \param dedicatedValidation
   * \return SamplingRates final maximum training and final maximum validation
   */
  SamplingRates ComputeFinalMaximumSamplingRates(bool dedicatedValidation);

  /**
   * Compute rates using MultiImageSamplingRate application
   * \param statisticsFileNames
   * \param ratesFileName
   * \param maximum final maximum value computed by ComputeFinalMaximumSamplingRates
   * \sa ComputeFinalMaximumSamplingRates
   */
  void ComputeSamplingRate(const std::vector<std::string> &statisticsFileNames,
                           const std::string &ratesFileName,
                           long maximum);

  /**
   * Train the model with training and optional validation data samples
   * \param imageList list of input images
   * \param sampleTrainFileNames file names of the training samples
   * \param sampleValidationFileNames file names of the validation samples
   */
  void TrainModel(FloatVectorImageListType *imageList, const std::vector<std::string> &sampleTrainFileNames,
                  const std::vector<std::string> &sampleValidationFileNames);

  /**
   * Select samples by class or by geographic strategy
   * \param image
   * \param vectorFileName
   * \param sampleFileName
   * \param statisticsFileName
   * \param ratesFileName
   * \param strategy
   * \param selectedField
   */
  void SelectAndExtractSamples(FloatVectorImageType *image, std::string vectorFileName, std::string sampleFileName,
                               std::string statisticsFileName, std::string ratesFileName, SamplingStrategy strategy,
                               std::string selectedField = "");

  /**
   * Select and extract samples with the SampleSelection and SampleExtraction application.
   * \param fileNames
   * \param imageList
   * \param vectorFileNames
   * \param strategy the strategy used for selection (by class or with geometry)
   * \param selectedFieldName
   */
  void SelectAndExtractTrainSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList,
                                    std::vector<std::string> vectorFileNames, SamplingStrategy strategy,
                                    std::string selectedFieldName = "");

  /**
   * Function used to select validation samples based on a defined strategy (geometric in unsupervised mode)
   * and extract them. With dedicated validation the 'by class' sampling strategy and statistics are used.
   * Otherwise this function splits training samples into validation samples according to the sample.vtr
   * percentage, or does nothing if this percentage is == 0
   * \param fileNames
   * \param imageList
   * \param validationVectorFileList optional validation vector file for each image
   */
  void SelectAndExtractValidationSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList,
                                         const std::vector<std::string> &validationVectorFileList = std::vector<std::string>());

  /**
   * Function used to split all training samples from all images in a set of training and validation.
   * \param fileNames
   * \param imageList
   * \sa SplitTrainingAndValidationSamples
   */
  void SplitTrainingToValidationSamples(const TrainFileNamesHandler &fileNames, FloatVectorImageListType *imageList);

private:
  /**
   * Function used to split training samples in set of training and validation.
   * \param image input image
   * \param sampleFileName the input sample file name
   * \param sampleTrainFileName the input training file name
   * \param sampleValidFileName the input validation file name
   * \param ratesTrainFileName the rates file name
   */
  void SplitTrainingAndValidationSamples(FloatVectorImageType *image, std::string sampleFileName,
                                         std::string sampleTrainFileName, std::string sampleValidFileName,
                                         std::string ratesTrainFileName);

protected:
  /**
   * Pair of values returned by ComputeFinalMaximumSamplingRates
   * (documented there as final maximum training and final maximum validation).
   */
  struct SamplingRates
  {
    long int fmt;
    long int fmv;
  };

  /**
   * \class TrainFileNamesHandler
   * This class is used to store the file names required for the application's input and output,
   * and to remove the temporary files generated by the applications.
   * \ingroup OTBAppClassification
   */
  class TrainFileNamesHandler
  {
  public:
    /**
     * Append, for each of the \p nbInputs inputs, the temporary file names derived from \p outModel.
     * Names are appended to the existing lists; the lists are not emptied first.
     * \param outModel prefix used for every generated file name
     * \param nbInputs number of inputs; indices in file names start at 1
     * \param dedicatedValidation when true, distinct "Train"/"Valid" statistics and rates
     *        file names are generated; otherwise a single set is generated
     */
    void CreateTemporaryFileNames(const std::string &outModel, size_t nbInputs, bool dedicatedValidation)
    {
      rateTrainOut = outModel + ( dedicatedValidation ? "_ratesTrain.csv" : "_rates.csv" );
      rateValidOut = outModel + "_ratesValid.csv";

      // size_t counter: matches the type of nbInputs, so no narrowing comparison.
      for( size_t i = 0; i < nbInputs; ++i )
        {
        std::ostringstream oss;
        oss << i + 1;
        const std::string strIndex( oss.str() );

        if( dedicatedValidation )
          {
          polyStatTrainOutputs.push_back( outModel + "_statsTrain_" + strIndex + ".xml" );
          polyStatValidOutputs.push_back( outModel + "_statsValid_" + strIndex + ".xml" );
          ratesTrainOutputs.push_back( outModel + "_ratesTrain_" + strIndex + ".csv" );
          ratesValidOutputs.push_back( outModel + "_ratesValid_" + strIndex + ".csv" );
          sampleOutputs.push_back( outModel + "_samplesTrain_" + strIndex + ".shp" );
          }
        else
          {
          polyStatTrainOutputs.push_back( outModel + "_stats_" + strIndex + ".xml" );
          ratesTrainOutputs.push_back( outModel + "_rates_" + strIndex + ".csv" );
          sampleOutputs.push_back( outModel + "_samples_" + strIndex + ".shp" );
          }

        sampleTrainOutputs.push_back( outModel + "_samplesTrain_" + strIndex + ".shp" );
        sampleValidOutputs.push_back( outModel + "_samplesValid_" + strIndex + ".shp" );
        }
    }

    /**
     * Remove from disk every file listed in the per-input file name lists.
     * The lists themselves are left untouched, and rateTrainOut / rateValidOut are not removed.
     */
    void clear()
    {
      RemoveFiles( polyStatTrainOutputs );
      RemoveFiles( polyStatValidOutputs );
      RemoveFiles( ratesTrainOutputs );
      RemoveFiles( ratesValidOutputs );
      RemoveFiles( sampleOutputs );
      RemoveFiles( sampleTrainOutputs );
      RemoveFiles( sampleValidOutputs );
      RemoveFiles( tmpVectorFileList );
    }

  public:
    std::vector<std::string> polyStatTrainOutputs;
    std::vector<std::string> polyStatValidOutputs;
    std::vector<std::string> ratesTrainOutputs;
    std::vector<std::string> ratesValidOutputs;
    std::vector<std::string> sampleOutputs;
    std::vector<std::string> sampleTrainOutputs;
    std::vector<std::string> sampleValidOutputs;
    std::vector<std::string> tmpVectorFileList;
    std::string rateValidOut;
    std::string rateTrainOut;

  private:
    /** Remove every file of \p filePaths, in order; individual failures are ignored. */
    void RemoveFiles(const std::vector<std::string> &filePaths)
    {
      for( std::vector<std::string>::const_iterator it = filePaths.begin(); it != filePaths.end(); ++it )
        {
        RemoveFile( *it );
        }
    }

    /**
     * Remove \p filePath if it exists. For a ".shp" file, the ".shx", ".dbf" and ".prj"
     * files sharing the same base name are removed first.
     * \return true when the file does not exist or was removed, false when removal failed.
     *         Failures are not logged here.
     */
    bool RemoveFile(const std::string &filePath)
    {
      if( !itksys::SystemTools::FileExists( filePath.c_str() ) )
        {
        return true;
        }

      const size_t posExt = filePath.rfind( '.' );
      if( posExt != std::string::npos && filePath.compare( posExt, std::string::npos, ".shp" ) == 0 )
        {
        const std::string basePath = filePath.substr( 0, posExt );
        RemoveFile( basePath + std::string( ".shx" ) );
        RemoveFile( basePath + std::string( ".dbf" ) );
        RemoveFile( basePath + std::string( ".prj" ) );
        }

      return itksys::SystemTools::RemoveFile( filePath.c_str() );
    }
  };
};
} // end namespace Wrapper
} // end namespace otb
#ifndef OTB_MANUAL_INSTANTIATION
#include "otbTrainImagesBase.txx"
#endif
#endif //otbTrainImagesBase_h
/*
* Copyright (C) 2005-2017 Centre National d'Etudes Spatiales (CNES)
*
* This file is part of Orfeo Toolbox
*
* https://www.orfeo-toolbox.org/
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef otbTrainImagesBase_txx
#define otbTrainImagesBase_txx
#include "otbTrainImagesBase.h"
namespace otb
{
namespace Wrapper
{
/**
 * Declare the input/output parameters: the "io" group with its image list
 * ("io.il") and its mandatory vector data list ("io.vd"), followed by the
 * non-mandatory "cleanup" parameter.
 */
void TrainImagesBase::InitIO()
{
  const std::string ioGroupKey( "io" );
  const std::string imageListKey( "io.il" );
  const std::string vectorListKey( "io.vd" );
  const std::string cleanupKey( "cleanup" );

  // Group IO
  AddParameter( ParameterType_Group, ioGroupKey, "Input and output data" );
  SetParameterDescription( ioGroupKey, "This group of parameters allows setting input and output data." );

  // Input images
  AddParameter( ParameterType_InputImageList, imageListKey, "Input Image List" );
  SetParameterDescription( imageListKey, "A list of input images." );

  // Input vector data, flagged as mandatory
  AddParameter( ParameterType_InputVectorDataList, vectorListKey, "Input Vector Data List" );
  SetParameterDescription( vectorListKey, "A list of vector data to select the training samples." );
  MandatoryOn( vectorListKey );

  // Temporary files cleaning switch: enabled right after creation, not mandatory
  AddParameter( ParameterType_Empty, cleanupKey, "Temporary files cleaning" );
  EnableParameter( cleanupKey );
  SetParameterDescription( cleanupKey,
                           "If activated, the application will try to clean all temporary files it created" );
  MandatoryOff( cleanupKey );
}
/**
 * Register the internal sampling applications, declare the "sample" parameter
 * group, then share and connect the sampling parameters.
 */
void TrainImagesBase::InitSampling()
{
  // Internal applications used for sampling
  AddApplication( "PolygonClassStatistics", "polystat", "Polygon analysis" );
  AddApplication( "MultiImageSamplingRate", "rates", "Sampling rates" );
  AddApplication( "SampleSelection", "select", "Sample selection" );
  AddApplication( "SampleExtraction", "extraction", "Sample extraction" );

  const std::string sampleGroupKey( "sample" );
  const std::string maxTrainingKey( "sample.mt" );
  const std::string maxValidationKey( "sample.mv" );
  const std::string boundByMinKey( "sample.bm" );
  const std::string ratioKey( "sample.vtr" );

  // Sampling settings
  AddParameter( ParameterType_Group, sampleGroupKey, "Training and validation samples parameters" );
  SetParameterDescription( sampleGroupKey,
                           "This group of parameters allows you to set training and validation sample lists parameters." );

  // Maximum training sample size per class
  AddParameter( ParameterType_Int, maxTrainingKey, "Maximum training sample size per class" );
  SetDefaultParameterInt( maxTrainingKey, 1000 );
  SetParameterDescription( maxTrainingKey, "Maximum size per class (in pixels) of "
    "the training sample list (default = 1000) (no limit = -1). If equal to -1,"
    " then the maximal size of the available training sample list per class "
    "will be equal to the surface area of the smallest class multiplied by the"
    " training sample ratio." );

  // Maximum validation sample size per class
  AddParameter( ParameterType_Int, maxValidationKey, "Maximum validation sample size per class" );
  SetDefaultParameterInt( maxValidationKey, 1000 );
  SetParameterDescription( maxValidationKey, "Maximum size per class (in pixels) of "
    "the validation sample list (default = 1000) (no limit = -1). If equal to -1,"
    " then the maximal size of the available validation sample list per class "
    "will be equal to the surface area of the smallest class multiplied by the "
    "validation sample ratio." );

  // Bound sample number by minimum
  AddParameter( ParameterType_Int, boundByMinKey, "Bound sample number by minimum" );
  SetDefaultParameterInt( boundByMinKey, 1 );
  SetParameterDescription( boundByMinKey, "Bound the number of samples for each "
    "class by the number of available samples by the smaller class. Proportions "
    "between training and validation are respected. Default is true (=1)." );

  // Training / validation ratio, restricted to [0.0, 1.0]
  AddParameter( ParameterType_Float, ratioKey, "Training and validation sample ratio" );
  SetParameterDescription( ratioKey, "Ratio between training and validation samples (0.0 = all training, 1.0 = "
    "all validation) (default = 0.5)." );
  SetParameterFloat( ratioKey, 0.5, false );
  SetMaximumParameterFloatValue( ratioKey, 1.0 );
  SetMinimumParameterFloatValue( ratioKey, 0.0 );

  // Disabled parameter, kept for reference:
  // AddParameter( ParameterType_Float, "sample.percent", "Percentage of sample extract from images" );
  // SetParameterDescription( "sample.percent", "Percentage of sample extract from images for "
  //   "training and validation when only images are provided." );
  // SetDefaultParameterFloat( "sample.percent", 1.0 );
  // SetMinimumParameterFloatValue( "sample.percent", 0.0 );
  // SetMaximumParameterFloatValue( "sample.percent", 1.0 );

  ShareSamplingParameters();
  ConnectSamplingParameters();
}
/**
 * Share "ram", "elev" and "sample.vfn" with the matching parameters of the
 * internal "polystat" application.
 */
void TrainImagesBase::ShareSamplingParameters()
{
  // Not shared (calls disabled in the original code):
  //   "sample.strategy" <-> "rates.strategy"
  //   "sample.mim"      <-> "rates.mim"

  // { key in this application, key in the internal application }
  const char * const sharedKeys[][2] = {
    { "ram", "polystat.ram" },
    { "elev", "polystat.elev" },
    { "sample.vfn", "polystat.field" }
  };
  const unsigned int nbShared = sizeof( sharedKeys ) / sizeof( sharedKeys[0] );

  for( unsigned int k = 0; k < nbShared; ++k )
    {
    ShareParameter( sharedKeys[k][0], sharedKeys[k][1] );
    }
}
/**
 * Connect the parameters of the internal "select" and "extraction"
 * applications to those of "polystat", and chain the extraction input and
 * vector parameters to the selection ones.
 */
void TrainImagesBase::ConnectSamplingParameters()
{
  // Each row is passed to Connect() as (first, second), in this exact order.
  const char * const links[][2] = {
    { "extraction.field", "polystat.field" },
    { "extraction.layer", "polystat.layer" },
    { "select.ram", "polystat.ram" },
    { "extraction.ram", "polystat.ram" },
    { "select.field", "polystat.field" },
    { "select.layer", "polystat.layer" },
    { "select.elev", "polystat.elev" },
    { "extraction.in", "select.in" },
    { "extraction.vec", "select.out" }
  };
  const unsigned int nbLinks = sizeof( links ) / sizeof( links[0] );

  for( unsigned int k = 0; k < nbLinks; ++k )
    {
    Connect( links[k][0], links[k][1] );
    }
}
void TrainImagesBase::InitClassification()
{
AddApplication( "TrainVectorClassifier", "training", "Model training" );
AddParameter( ParameterType_InputVectorDataList, "io.valid", "Validation Vector Data List" );
SetParameterDescription( "io.valid", "A list of vector data to select the training samples." );
MandatoryOff( "io.valid" );