Commit ad4853ad authored by Marina Bertolino's avatar Marina Bertolino

ENH: add internal parameters

parent 9b8096e2
......@@ -57,9 +57,10 @@ private:
// Perform initialization
ClearApplications();
initKMIO();
InitKMSampling();
InitKMClassification();
// initialisations parameters and synchronizes parameters
initKMParams();
if ( HasValue("vm") ) ConnectKMClassificationMask();
AddParameter(ParameterType_OutputImage, "out", "Output Image");
SetParameterDescription("out", "Output image containing the class indexes.");
......@@ -67,6 +68,8 @@ private:
AddRAMParameter(); // TODO verifier si les RAMParameter sont bien tous cablés
AddRANDParameter();
// Doc example parameter settings
SetDocExampleParameterValue("in", "QB_1_ortho.tif");
SetDocExampleParameterValue("ts", "1000");
......@@ -80,7 +83,6 @@ private:
void DoUpdateParameters() ITK_OVERRIDE
{
//UpdateInternalParameters("");
}
void DoExecute() ITK_OVERRIDE
......@@ -92,37 +94,42 @@ private:
fileNames.CreateTemporaryFileNames(GetParameterString( "out" ));
otbAppLogINFO(" init filename : " << fileNames.tmpVectorFile); //RM
// Create an image envelope
ComputeImageEnvelope(fileNames.tmpVectorFile);
// Add a new field at the ImageEnvelope output file
ComputeAddField(fileNames.tmpVectorFile, fieldName);
// Compute PolygonStatistics app
UpdateKMPolygonClassStatisticsParameters(fileNames.tmpVectorFile);
ComputePolygonStatistics(fileNames.polyStatOutput, fieldName);
const double theoricNBSamplesForKMeans = GetParameterInt("ts");
const double upperThresholdNBSamplesForKMeans = 1000 * 1000;
const double actualNBSamplesForKMeans = std::min(theoricNBSamplesForKMeans,
upperThresholdNBSamplesForKMeans);
// Compute number of sample max for KMeans
const int theoricNBSamplesForKMeans = GetParameterInt("ts");
const int upperThresholdNBSamplesForKMeans = 1000 * 1000;
const int actualNBSamplesForKMeans = std::min(theoricNBSamplesForKMeans,
upperThresholdNBSamplesForKMeans);
otbAppLogINFO(<< actualNBSamplesForKMeans << " is the maximum sample size that will be used." \
<< std::endl);
// Compute SampleSelection and SampleExtraction app
SelectAndExtractSamples(fileNames.sampleSelectOutput, fileNames.polyStatOutput,
fieldName, fileNames.sampleExtractOutput);
fieldName, fileNames.sampleExtractOutput,
actualNBSamplesForKMeans);
// todo RM
//std::vector<std::string> sampleTrainFileName = {fileNames.sampleSelectOutput};
//UpdateTrainKMModelParameters(sampleTrainFileName);
// Compute a train model with TrainVectorClassifier app
TrainKMModel(GetParameterImage("in"), fileNames.sampleExtractOutput,
fileNames.modelFile);
TrainKMModel(GetParameterImage("in"), fileNames.sampleExtractOutput);
// Compute a classification of the input image according to a model file
KMeansClassif();
/* TODO cleanup
// cleanup
// remove all tempory files
if( IsParameterEnabled( "cleanup" ) )
{
otbAppLogINFO( <<"Final clean-up ..." );
fileNames.clear(); // TODO create clear()
fileNames.clear();
}
*/
}
private :
......@@ -132,13 +139,7 @@ private :
GetInternalApplication( "polystats" )->SetParameterString( "vec", vectorFileName, false );
UpdateInternalParameters( "polystats" );
}
/*
void UpdateTrainKMModelParameters(const std::vector<std::string> &sampleTrainFileName)
{
GetInternalApplication( "training" )->SetParameterStringList( "io.vd", sampleTrainFileName, false );
UpdateInternalParameters( "training" );
}
*/
};
......
......@@ -57,7 +57,7 @@ public:
protected:
class KMeansFileNamesHandler;
void initKMIO();
void initKMParams();
void InitKMSampling();
void InitKMClassification();
......@@ -65,6 +65,7 @@ public:
void ShareKMClassificationParams();
void ConnectKMSamplingParams();
void ConnectKMClassificationParams();
void ConnectKMClassificationMask();
/**
* Create a vector file (envelope image)
......@@ -88,7 +89,7 @@ public:
const std::string &fieldName);
/**
* Select samples by class or by geographic strategy
* Select samples by constant strategy
* \param sampleFileName
* \param statisticsFileName
* \param fieldName
......@@ -97,14 +98,22 @@ public:
void SelectAndExtractSamples(std::string sampleFileName,
std::string statisticsFileName,
std::string fieldName,
std::string sampleExtractFileName);
std::string sampleExtractFileName,
int NBSamples);
/**
* Train the model with training
* Train the model
* \param image input image
* \param sampleTrainFileName
*/
void TrainKMModel(FloatVectorImageType *image, std::string sampleTrainFileName);
void TrainKMModel(FloatVectorImageType *image,
std::string sampleTrainFileName,
std::string modelFileName);
/**
* Performs a classification of the input image according to a model file
*/
void KMeansClassif();
/**
* \class KMeansFileNamesHandler
......@@ -112,7 +121,7 @@ public:
* And to clear temporary files generated by the applications
* \ingroup OTBAppClassification
*/
class KMeansFileNamesHandler
class KMeansFileNamesHandler
{
public :
void CreateTemporaryFileNames(std::string outPath)
......@@ -121,13 +130,49 @@ public:
polyStatOutput = outPath + "_polyStats.xml";
sampleSelectOutput = outPath + "_sampleSelect.shp";
sampleExtractOutput = outPath + "_sampleExtract.shp";
modelFile = outPath + "_model.txt";
}
void clear()
{
RemoveFile(tmpVectorFile);
RemoveFile(polyStatOutput);
RemoveFile(sampleSelectOutput);
RemoveFile(sampleExtractOutput);
RemoveFile(modelFile);
}
public:
std::string tmpVectorFile;
std::string polyStatOutput;
std::string sampleSelectOutput;
std::string sampleExtractOutput;
std::string modelFile;
private:
bool RemoveFile(std::string &filePath)
{
bool res = true;
if( itksys::SystemTools::FileExists( filePath.c_str() ) )
{
size_t posExt = filePath.rfind( '.' );
if( posExt != std::string::npos && filePath.compare( posExt, std::string::npos, ".shp" ) == 0 )
{
std::string shxPath = filePath.substr( 0, posExt ) + std::string( ".shx" );
std::string dbfPath = filePath.substr( 0, posExt ) + std::string( ".dbf" );
std::string prjPath = filePath.substr( 0, posExt ) + std::string( ".prj" );
RemoveFile( shxPath );
RemoveFile( dbfPath );
RemoveFile( prjPath );
}
res = itksys::SystemTools::RemoveFile( filePath.c_str() );
if( !res )
{
//otbAppLogINFO( <<"Unable to remove file "<<filePath );
}
}
return res;
}
};
};
......
......@@ -30,17 +30,16 @@ namespace Wrapper
// todo RM ALL std::cout
void ClassKMeansBase::initKMIO()
void ClassKMeansBase::initKMParams()
{
/*
AddParameter(ParameterType_InputImage, "in", "Input Image");
SetParameterDescription("in", "Input image to classify.");
*/
AddParameter( ParameterType_Empty, "cleanup", "Temporary files cleaning" );
EnableParameter( "cleanup" );
SetParameterDescription( "cleanup",
"If activated, the application will try to clean all temporary files it created" );
MandatoryOff( "cleanup" );
InitKMSampling();
InitKMClassification();
}
void ClassKMeansBase::InitKMSampling()
......@@ -84,7 +83,8 @@ void ClassKMeansBase::InitKMSampling()
void ClassKMeansBase::InitKMClassification()
{
AddApplication( "TrainVectorClassifier", "training", "Model training" );
AddApplication("TrainVectorClassifier", "training", "Model training");
AddApplication("ImageClassifier", "classif", "Performs a classification of the input image");
ShareKMClassificationParams();
ConnectKMClassificationParams();
......@@ -94,13 +94,12 @@ void ClassKMeansBase::ShareKMSamplingParameters()
{
ShareParameter("in", "imgenvelop.in");
ShareParameter("vm", "select.mask");
ShareParameter( "ram", "polystats.ram");
ShareParameter("ram", "polystats.ram");
}
void ClassKMeansBase::ShareKMClassificationParams()
{
// TODO
//ShareParameter( "classifier", "training.classifier" );
ShareParameter("out", "classif.out");
}
void ClassKMeansBase::ConnectKMSamplingParams()
......@@ -109,19 +108,25 @@ void ClassKMeansBase::ConnectKMSamplingParams()
Connect("select.in", "polystats.in");
Connect("select.vec", "polystats.vec");
Connect( "select.ram", "polystats.ram" );
Connect("select.ram", "polystats.ram");
Connect("extraction.in", "select.in");
Connect("extraction.field", "select.field");
Connect("extraction.vec", "select.out");
Connect( "extraction.ram", "polystats.ram" );
Connect("extraction.ram", "polystats.ram");
}
void ClassKMeansBase::ConnectKMClassificationParams()
{
Connect("training.cfield", "extraction.field");
Connect("training.io.stats", "polystats.out");
// TODO Connect Classification
Connect("classif.in", "imgenvelop.in");
Connect("classif.model", "training.io.out");
}
void ClassKMeansBase::ConnectKMClassificationMask()
{
Connect("classif.mask", "select.mask");
}
void ClassKMeansBase::ComputeImageEnvelope(const std::string &vectorFileName)
......@@ -160,10 +165,6 @@ void ClassKMeansBase::ComputeAddField(const std::string &vectorFileName,
if (err != OGRERR_NONE)
itkExceptionMacro(<< "Unable to commit transaction for OGR layer " << layer.ogr().GetName() << ".");
ogrDS->SyncToDisk();
/*
// close input data source
source->Clear();
*/
}
void ClassKMeansBase::ComputePolygonStatistics(const std::string &statisticsFileName,
......@@ -181,40 +182,46 @@ void ClassKMeansBase::ComputePolygonStatistics(const std::string &statisticsFile
void ClassKMeansBase::SelectAndExtractSamples(std::string sampleFileName,
std::string statisticsFileName,
std::string fieldName,
std::string sampleExtractFileName)
std::string sampleExtractFileName,
int NBSamples)
{
/* SampleSelection */
std::cout << "Select init ..." << std::endl;
//GetInternalApplication( "select" )->SetParameterInputImage( "in", image );
GetInternalApplication( "select" )->SetParameterString( "out", sampleFileName, false );
GetInternalApplication("select")->SetParameterString("out", sampleFileName, false);
UpdateInternalParameters( "select" );
GetInternalApplication( "select" )->SetParameterString( "instats", statisticsFileName, false );
GetInternalApplication( "select" )->SetParameterString( "field", fieldName, false );
UpdateInternalParameters("select");
GetInternalApplication("select")->SetParameterString("instats", statisticsFileName, false);
GetInternalApplication("select")->SetParameterString("field", fieldName, false);
GetInternalApplication("select" )->SetParameterString("sampler", "random", false);
GetInternalApplication( "select" )->SetParameterString("strategy", "constant", false);
GetInternalApplication( "select" )->SetParameterInt("strategy.constant.nb", GetParameterInt("ts"), false);
GetInternalApplication("select")->SetParameterString("sampler", "random", false);
GetInternalApplication("select")->SetParameterString("strategy", "constant", false);
GetInternalApplication("select")->SetParameterInt("strategy.constant.nb", NBSamples, false);
std::cout << "select.field = " << GetInternalApplication( "select" )->GetParameterString( "field" ) << std::endl;
std::cout << "select.out = " << GetInternalApplication( "select" )->GetParameterString( "out" ) << std::endl;
// TODO if GetParameterInt("rand") is not defined, default value
GetInternalApplication("select")->SetParameterInt("rand", GetParameterInt("rand"), false);
std::cout << "select.field = " << GetInternalApplication("select")->GetParameterString("field") << std::endl;
std::cout << "select.out = " << GetInternalApplication("select")->GetParameterString("out") << std::endl;
// select sample positions
ExecuteInternal( "select" );
ExecuteInternal("select");
UpdateInternalParameters( "extraction" );
std::cout << "extraction.field =" << GetInternalApplication( "extraction" )->GetParameterString( "field") << std::endl;
/* SampleExtraction */
UpdateInternalParameters("extraction");
std::cout << "extraction.field =" << GetInternalApplication("extraction")->GetParameterString("field") << std::endl;
GetInternalApplication( "extraction" )->SetParameterString( "outfield", "prefix", false );
GetInternalApplication( "extraction" )->SetParameterString( "outfield.prefix.name", "value_", false );
GetInternalApplication("extraction")->SetParameterString("outfield", "prefix", false);
GetInternalApplication("extraction")->SetParameterString("outfield.prefix.name", "value_", false);
GetInternalApplication( "extraction" )->SetParameterString( "out", sampleExtractFileName, false);
GetInternalApplication("extraction")->SetParameterString("out", sampleExtractFileName, false);
std::cout << "extraction.out = " << sampleExtractFileName << std::endl;
// extract sample descriptors
GetInternalApplication( "extraction" )->ExecuteAndWriteOutput();
GetInternalApplication("extraction")->ExecuteAndWriteOutput();
}
void ClassKMeansBase::TrainKMModel(FloatVectorImageType *image,
std::string sampleTrainFileName)
std::string sampleTrainFileName,
std::string modelFileName)
{
std::cout << "init train model ..." << std::endl;
......@@ -233,17 +240,28 @@ void ClassKMeansBase::TrainKMModel(FloatVectorImageType *image,
std::cout << "feat : " << std::string(selectPrefix + oss.str()) << std::endl;
selectedNames.push_back( selectPrefix + oss.str() );
}
GetInternalApplication("training")->SetParameterStringList("feat", selectedNames, false);
GetInternalApplication( "training" )->SetParameterStringList("feat", selectedNames, false);
/* TODO test sans, a enlever
GetInternalApplication("training")->SetParameterString("classifier", "sharkkm", false);
GetInternalApplication("training")->SetParameterInt("classifier.sharkkm.maxiter",
GetParameterInt("maxit"), false);
GetInternalApplication("training")->SetParameterInt("classifier.sharkkm.k",
GetParameterInt("nc"), false);
*/
GetInternalApplication("training")->SetParameterString("io.out", modelFileName, false);
// todo RM
std::cout << "training.io.out : " << GetInternalApplication("training")->GetParameterString("io.out") << std::endl;
ExecuteInternal( "training" );
}
void ClassKMeansBase::KMeansClassif()
{
std::cout << "classification ... " << std::endl;
std::cout << "classif.in : " << GetInternalApplication("classif")->GetParameterString("in") << std::endl;
std::cout << "classif.model : " << GetInternalApplication("classif")->GetParameterString("model") << std::endl;
std::cout << "classif.out : " << GetInternalApplication("classif")->GetParameterString("out") << std::endl;
ExecuteInternal( "classif" );
}
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment