...
 
Commits (5)
......@@ -81,9 +81,13 @@ protected:
SetDefaultParameterInt("maxit", 1000);
MandatoryOff("maxit");
AddParameter(ParameterType_String, "incentroid", "Maximum number of iterations");
SetParameterDescription("incentroid", "Maximum number of iterations for the learning step.");
MandatoryOff("incentroid");
AddParameter(ParameterType_String, "inmeans", "Maximum number of iterations");
SetParameterDescription("inmeans", "Maximum number of iterations for the learning step.");
MandatoryOff("inmeans");
AddParameter(ParameterType_Bool, "normalizeinmeans", "Number of classes");
SetParameterDescription("normalizeinmeans", "Number of modes, which will be used to generate class membership.");
SetDefaultParameterInt("normalizeinmeans", true);
AddParameter(ParameterType_OutputFilename, "outmeans", "Centroid filename");
SetParameterDescription("outmeans", "Output text file containing centroid positions");
......@@ -252,8 +256,15 @@ protected:
GetParameterInt("maxit"));
GetInternalApplication("training")->SetParameterInt("classifier.sharkkm.k",
GetParameterInt("nc"));
GetInternalApplication("training")->SetParameterString("classifier.sharkkm.incentroid",
GetParameterString("incentroid"));
if(IsParameterEnabled("inmeans") && HasValue("inmeans"))
{
GetInternalApplication("training")->SetParameterString("classifier.sharkkm.centroids",
GetParameterString("inmeans"));
if(GetParameterInt("normalizeinmeans"))
GetInternalApplication("training")->SetParameterString("classifier.sharkkm.centroidstats",
GetInternalApplication("imgstats")->GetParameterString("out"));
}
if( IsParameterEnabled("rand"))
GetInternalApplication("training")->SetParameterInt("rand", GetParameterInt("rand"));
......
......@@ -52,9 +52,9 @@ void LearningApplicationBase<TInputValue, TOutputValue>::InitSharkKMeansParams()
MandatoryOff("classifier.sharkkm.centroidstats");
// Number of classes
AddParameter(ParameterType_String, "classifier.sharkkm.centroids", "Number of classes for the kmeans algorithm");
SetParameterDescription("classifier.sharkkm.incentroid", "The number of classes used for the kmeans algorithm. Default set to 2 class");
MandatoryOff("classifier.sharkkm.incentroid");
AddParameter(ParameterType_InputFilename, "classifier.sharkkm.centroids", "Number of classes for the kmeans algorithm");
SetParameterDescription("classifier.sharkkm.centroids", "The number of classes used for the kmeans algorithm. Default set to 2 class");
MandatoryOff("classifier.sharkkm.centroids");
}
template<class TInputValue, class TOutputValue>
......@@ -73,11 +73,11 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans(
classifier->SetK( k );
// Initialize centroids from file
if(HasValue("classifier.sharkkm.centroids"))
if(IsParameterEnabled("classifier.sharkkm.centroids") && HasValue("classifier.sharkkm.centroids"))
{
shark::Data<shark::RealVector> centroidData;
shark::importCSV(centroidData, GetParameterString( "classifier.sharkkm.centroidstats"), ' ');
if( HasValue( "classifier.sharkkm.centroids" ) )
shark::importCSV(centroidData, GetParameterString( "classifier.sharkkm.centroids"), ' ');
if( HasValue( "classifier.sharkkm.centroidstats" ) )
{
auto statisticsReader = otb::StatisticsXMLFileReader< itk::VariableLengthVector<float> >::New();
statisticsReader->SetFileName(GetParameterString( "classifier.sharkkm.centroidstats" ));
......@@ -85,21 +85,18 @@ void LearningApplicationBase<TInputValue, TOutputValue>::TrainSharkKMeans(
auto stddevMeasurementVector = statisticsReader->GetStatisticVectorByName("stddev");
// Convert itk Variable Length Vector to shark Real Vector
shark::RealVector meanMeasurementRV(meanMeasurementVector.Size());
for (unsigned int i = 0; i<meanMeasurementVector.Size(); ++i)
{
// Substract the mean
meanMeasurementRV[i] = - meanMeasurementVector[i];
}
shark::RealVector offsetRV(meanMeasurementVector.Size());
shark::RealVector stddevMeasurementRV(stddevMeasurementVector.Size());
for (unsigned int i = 0; i<stddevMeasurementVector.Size(); ++i)
assert(meanMeasurementVector.Size()==stddevMeasurementVector.Size());
for (unsigned int i = 0; i<meanMeasurementVector.Size(); ++i)
{
stddevMeasurementRV[i] = stddevMeasurementVector[i];
stddevMeasurementRV[i] = 1/stddevMeasurementVector[i];
// Substract the normalized mean
offsetRV[i] = - meanMeasurementVector[i]/stddevMeasurementVector[i];
}
shark::Normalizer<> normalizer(stddevMeasurementRV, meanMeasurementRV);
shark::Normalizer<> normalizer(stddevMeasurementRV, offsetRV);
centroidData = normalizer(centroidData);
}
......
......@@ -129,10 +129,6 @@ public:
itkGetMacro( Normalized, bool );
itkSetMacro( Normalized, bool );
/** If true, normalized input data sample list */
itkGetMacro( CentroidFilename, std::string );
itkSetMacro( CentroidFilename, std::string );
/** Initialize the centroids for the kmeans algorithm */
void SetCentroidsFromData(const shark::Data<shark::RealVector> & data)
{
......@@ -167,22 +163,15 @@ private:
SharkKMeansMachineLearningModel(const Self &) = delete;
void operator=(const Self &) = delete;
bool InitializeCentroids();
// Parameters set by the user
bool m_Normalized;
unsigned int m_K;
unsigned int m_MaximumNumberOfIterations;
bool m_CanRead;
/** Centroids results form kMeans */
shark::Centroids m_Centroids;
/** Input centroid filename */
std::string m_CentroidFilename;
/** shark Model could be SoftClusteringModel or HardClusteringModel */
boost::shared_ptr<ClusteringModelType> m_ClusteringModel;
......