Commit 2b976962 authored by Cédric Traizet's avatar Cédric Traizet

BUG: input centroids should not be mandatory for all classifiers (and even for sharkkmean)

parent 62529436
Pipeline #3260 passed with stages
in 64 minutes and 22 seconds
......@@ -106,7 +106,7 @@ public:
{
ShareParameter("ram", "polystats.ram");
ShareParameter("sampler", "select.sampler");
ShareParameter("centroids.out", "training.classifier.sharkkm.centroids.out");
ShareParameter("centroids.out", "training.classifier.sharkkm.outcentroids");
ShareParameter("vm", "polystats.mask", "Validity Mask", "Validity mask, only non-zero pixels will be used to estimate KMeans modes.");
}
......@@ -248,10 +248,10 @@ public:
GetInternalApplication("training")->SetParameterInt("classifier.sharkkm.k", GetParameterInt("nc"));
if (IsParameterEnabled("centroids.in") && HasValue("centroids.in"))
{
GetInternalApplication("training")->SetParameterString("classifier.sharkkm.centroids.in", GetParameterString("centroids.in"));
GetInternalApplication("training")->SetParameterString("classifier.sharkkm.incentroids", GetParameterString("centroids.in"));
GetInternalApplication("training")
->SetParameterString("classifier.sharkkm.centroids.stats", GetInternalApplication("imgstats")->GetParameterString("out"));
->SetParameterString("classifier.sharkkm.cstats", GetInternalApplication("imgstats")->GetParameterString("out"));
}
......
......@@ -46,30 +46,26 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams()
SetParameterDescription("classifier.sharkkm.k", "The number of classes used for the kmeans algorithm. Default set to 2 class");
SetMinimumParameterIntValue("classifier.sharkkm.k", 2);
// Centroid IO
AddParameter(ParameterType_Group, "classifier.sharkkm.centroids", "Centroids IO parameters");
SetParameterDescription("classifier.sharkkm.centroids", "Group of parameters for centroids IO.");
// Input centroids
AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids.in", "User definied input centroids");
SetParameterDescription("classifier.sharkkm.centroids.in",
AddParameter(ParameterType_InputFilename, "classifier.sharkkm.incentroids", "User defined input centroids");
SetParameterDescription("classifier.sharkkm.incentroids",
"Input text file containing centroid posistions used to initialize the algorithm. "
"Each centroid must be described by p parameters, p being the number of features in "
"the input vector data, and the number of centroids must be equal to the number of classes "
"(one centroid per line with values separated by spaces).");
MandatoryOff("classifier.sharkkm.centroids");
MandatoryOff("classifier.sharkkm.incentroids");
// Centroid statistics
AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids.stats", "Statistics file");
SetParameterDescription("classifier.sharkkm.centroids.stats",
AddParameter(ParameterType_InputFilename, "classifier.sharkkm.cstats", "Statistics file");
SetParameterDescription("classifier.sharkkm.cstats",
"A XML file containing mean and standard deviation to center"
"and reduce the centroids before the KMeans algorithm, produced by ComputeImagesStatistics application.");
MandatoryOff("classifier.sharkkm.centroids.stats");
"and reduce the input centroids before the KMeans algorithm, produced by ComputeImagesStatistics application.");
MandatoryOff("classifier.sharkkm.cstats");
// Output centroids
AddParameter(ParameterType_OutputFilename, "classifier.sharkkm.centroids.out", "Output centroids text file");
SetParameterDescription("classifier.sharkkm.centroids.out", "Output text file containing centroids after the kmean algorithm.");
MandatoryOff("classifier.sharkkm.centroids.out");
AddParameter(ParameterType_OutputFilename, "classifier.sharkkm.outcentroids", "Output centroids text file");
SetParameterDescription("classifier.sharkkm.outcentroids", "Output text file containing centroids after the kmean algorithm.");
MandatoryOff("classifier.sharkkm.outcentroids");
}
template <class TInputValue, class TOutputValue>
......@@ -88,14 +84,14 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans(typena
classifier->SetK(k);
// Initialize centroids from file
if (IsParameterEnabled("classifier.sharkkm.centroids.in") && HasValue("classifier.sharkkm.centroids.in"))
if (IsParameterEnabled("classifier.sharkkm.incentroids") && HasValue("classifier.sharkkm.incentroids"))
{
shark::Data<shark::RealVector> centroidData;
shark::importCSV(centroidData, GetParameterString("classifier.sharkkm.centroids.in"), ' ');
if (HasValue("classifier.sharkkm.centroids.stats"))
shark::importCSV(centroidData, GetParameterString("classifier.sharkkm.incentroids"), ' ');
if (HasValue("classifier.sharkkm.cstats"))
{
auto statisticsReader = otb::StatisticsXMLFileReader<itk::VariableLengthVector<float>>::New();
statisticsReader->SetFileName(GetParameterString("classifier.sharkkm.centroids.stats"));
statisticsReader->SetFileName(GetParameterString("classifier.sharkkm.cstats"));
auto meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
auto stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
......@@ -126,8 +122,8 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans(typena
classifier->Train();
classifier->Save(modelPath);
if (HasValue("classifier.sharkkm.centroids.out"))
classifier->ExportCentroids(GetParameterString("classifier.sharkkm.centroids.out"));
if (HasValue("classifier.sharkkm.outcentroids"))
classifier->ExportCentroids(GetParameterString("classifier.sharkkm.outcentroids"));
}
} // end namespace wrapper
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment