Skip to content
Snippets Groups Projects
Commit 2b976962 authored by Cédric Traizet's avatar Cédric Traizet
Browse files

BUG: input centroids should not be mandatory for all classifiers (and even for sharkkmean)

parent 62529436
No related branches found
No related tags found
No related merge requests found
...@@ -106,7 +106,7 @@ public: ...@@ -106,7 +106,7 @@ public:
{ {
ShareParameter("ram", "polystats.ram"); ShareParameter("ram", "polystats.ram");
ShareParameter("sampler", "select.sampler"); ShareParameter("sampler", "select.sampler");
ShareParameter("centroids.out", "training.classifier.sharkkm.centroids.out"); ShareParameter("centroids.out", "training.classifier.sharkkm.outcentroids");
ShareParameter("vm", "polystats.mask", "Validity Mask", "Validity mask, only non-zero pixels will be used to estimate KMeans modes."); ShareParameter("vm", "polystats.mask", "Validity Mask", "Validity mask, only non-zero pixels will be used to estimate KMeans modes.");
} }
...@@ -248,10 +248,10 @@ public: ...@@ -248,10 +248,10 @@ public:
GetInternalApplication("training")->SetParameterInt("classifier.sharkkm.k", GetParameterInt("nc")); GetInternalApplication("training")->SetParameterInt("classifier.sharkkm.k", GetParameterInt("nc"));
if (IsParameterEnabled("centroids.in") && HasValue("centroids.in")) if (IsParameterEnabled("centroids.in") && HasValue("centroids.in"))
{ {
GetInternalApplication("training")->SetParameterString("classifier.sharkkm.centroids.in", GetParameterString("centroids.in")); GetInternalApplication("training")->SetParameterString("classifier.sharkkm.incentroids", GetParameterString("centroids.in"));
GetInternalApplication("training") GetInternalApplication("training")
->SetParameterString("classifier.sharkkm.centroids.stats", GetInternalApplication("imgstats")->GetParameterString("out")); ->SetParameterString("classifier.sharkkm.cstats", GetInternalApplication("imgstats")->GetParameterString("out"));
} }
......
...@@ -46,30 +46,26 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams() ...@@ -46,30 +46,26 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams()
SetParameterDescription("classifier.sharkkm.k", "The number of classes used for the kmeans algorithm. Default set to 2 class"); SetParameterDescription("classifier.sharkkm.k", "The number of classes used for the kmeans algorithm. Default set to 2 class");
SetMinimumParameterIntValue("classifier.sharkkm.k", 2); SetMinimumParameterIntValue("classifier.sharkkm.k", 2);
// Centroid IO
AddParameter(ParameterType_Group, "classifier.sharkkm.centroids", "Centroids IO parameters");
SetParameterDescription("classifier.sharkkm.centroids", "Group of parameters for centroids IO.");
// Input centroids // Input centroids
AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids.in", "User definied input centroids"); AddParameter(ParameterType_InputFilename, "classifier.sharkkm.incentroids", "User defined input centroids");
SetParameterDescription("classifier.sharkkm.centroids.in", SetParameterDescription("classifier.sharkkm.incentroids",
"Input text file containing centroid posistions used to initialize the algorithm. " "Input text file containing centroid posistions used to initialize the algorithm. "
"Each centroid must be described by p parameters, p being the number of features in " "Each centroid must be described by p parameters, p being the number of features in "
"the input vector data, and the number of centroids must be equal to the number of classes " "the input vector data, and the number of centroids must be equal to the number of classes "
"(one centroid per line with values separated by spaces)."); "(one centroid per line with values separated by spaces).");
MandatoryOff("classifier.sharkkm.centroids"); MandatoryOff("classifier.sharkkm.incentroids");
// Centroid statistics // Centroid statistics
AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids.stats", "Statistics file"); AddParameter(ParameterType_InputFilename, "classifier.sharkkm.cstats", "Statistics file");
SetParameterDescription("classifier.sharkkm.centroids.stats", SetParameterDescription("classifier.sharkkm.cstats",
"A XML file containing mean and standard deviation to center" "A XML file containing mean and standard deviation to center"
"and reduce the centroids before the KMeans algorithm, produced by ComputeImagesStatistics application."); "and reduce the input centroids before the KMeans algorithm, produced by ComputeImagesStatistics application.");
MandatoryOff("classifier.sharkkm.centroids.stats"); MandatoryOff("classifier.sharkkm.cstats");
// Output centroids // Output centroids
AddParameter(ParameterType_OutputFilename, "classifier.sharkkm.centroids.out", "Output centroids text file"); AddParameter(ParameterType_OutputFilename, "classifier.sharkkm.outcentroids", "Output centroids text file");
SetParameterDescription("classifier.sharkkm.centroids.out", "Output text file containing centroids after the kmean algorithm."); SetParameterDescription("classifier.sharkkm.outcentroids", "Output text file containing centroids after the kmean algorithm.");
MandatoryOff("classifier.sharkkm.centroids.out"); MandatoryOff("classifier.sharkkm.outcentroids");
} }
template <class TInputValue, class TOutputValue> template <class TInputValue, class TOutputValue>
...@@ -88,14 +84,14 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans(typena ...@@ -88,14 +84,14 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans(typena
classifier->SetK(k); classifier->SetK(k);
// Initialize centroids from file // Initialize centroids from file
if (IsParameterEnabled("classifier.sharkkm.centroids.in") && HasValue("classifier.sharkkm.centroids.in")) if (IsParameterEnabled("classifier.sharkkm.incentroids") && HasValue("classifier.sharkkm.incentroids"))
{ {
shark::Data<shark::RealVector> centroidData; shark::Data<shark::RealVector> centroidData;
shark::importCSV(centroidData, GetParameterString("classifier.sharkkm.centroids.in"), ' '); shark::importCSV(centroidData, GetParameterString("classifier.sharkkm.incentroids"), ' ');
if (HasValue("classifier.sharkkm.centroids.stats")) if (HasValue("classifier.sharkkm.cstats"))
{ {
auto statisticsReader = otb::StatisticsXMLFileReader<itk::VariableLengthVector<float>>::New(); auto statisticsReader = otb::StatisticsXMLFileReader<itk::VariableLengthVector<float>>::New();
statisticsReader->SetFileName(GetParameterString("classifier.sharkkm.centroids.stats")); statisticsReader->SetFileName(GetParameterString("classifier.sharkkm.cstats"));
auto meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean"); auto meanMeasurementVector = statisticsReader->GetStatisticVectorByName("mean");
auto stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev"); auto stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
...@@ -126,8 +122,8 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans(typena ...@@ -126,8 +122,8 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans(typena
classifier->Train(); classifier->Train();
classifier->Save(modelPath); classifier->Save(modelPath);
if (HasValue("classifier.sharkkm.centroids.out")) if (HasValue("classifier.sharkkm.outcentroids"))
classifier->ExportCentroids(GetParameterString("classifier.sharkkm.centroids.out")); classifier->ExportCentroids(GetParameterString("classifier.sharkkm.outcentroids"));
} }
} // end namespace wrapper } // end namespace wrapper
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment