Commit 029a6b28 authored by Guillaume Pasero's avatar Guillaume Pasero

REFAC: adapt SVMMarginSampler for LibSVMMachineLearningModel

parent 3411e23b
......@@ -251,6 +251,19 @@ public:
double CrossValidation(void);
/** Return number of support vectors */
unsigned int GetNumberOfSupportVectors(void) const
{
if (m_Model) return m_Model->l;
return 0;
}
unsigned int GetNumberOfClasses(void) const
{
if (m_Model) return m_Model->nr_class;
return 0;
}
protected:
/** Constructor */
LibSVMMachineLearningModel();
......
......@@ -69,7 +69,6 @@ public:
/** Type definitions for the SVM Model. */
typedef TModel SVMModelType;
typedef typename SVMModelType::Pointer SVMModelPointer;
typedef typename SVMModelType::DistancesVectorType DistancesVectorType;
itkSetMacro(NumberOfCandidates, unsigned int);
itkGetMacro(NumberOfCandidates, unsigned int);
......@@ -95,8 +94,6 @@ protected:
virtual void DoMarginSampling();
private:
/** Output pointer (MembershipSample) */
//typename OutputType::Pointer m_Output;
SVMModelPointer m_Model;
......
......@@ -78,28 +78,32 @@ SVMMarginSampler< TSample, TModel >
typename OutputType::ConstIterator endO = output->End();
typename TSample::MeasurementVectorType measurements;
m_Model->SetConfidenceMode(TModel::CM_HYPER);
int numberOfComponentsPerSample = iter.GetMeasurementVector().Size();
int nbClass = static_cast<int>(m_Model->GetNumberOfClasses());
std::vector<double> hdistances(nbClass * (nbClass - 1) / 2);
otbMsgDevMacro( << "Starting iterations " );
while (iter != end && iterO != endO)
{
int i = 0;
typename SVMModelType::MeasurementType modelMeasurement;
typename SVMModelType::InputSampleType modelMeasurement;
measurements = iter.GetMeasurementVector();
// otbMsgDevMacro( << "Loop on components " << svm_type );
for (i=0; i<numberOfComponentsPerSample; ++i)
{
modelMeasurement.push_back(measurements[i]);
modelMeasurement.PushBack(measurements[i]);
}
// Get distances to the hyperplanes
DistancesVectorType hdistances = m_Model->EvaluateHyperplanesDistances(modelMeasurement);
m_Model->Predict(modelMeasurement, &(hdistances[0]));
double minDistance = vcl_abs(hdistances[0]);
// Compute th min distances
for(unsigned int j = 1; j<hdistances.Size(); ++j)
for(unsigned int j = 1; j<hdistances.size(); ++j)
{
if(vcl_abs(hdistances[j])<minDistance)
{
......@@ -129,8 +133,6 @@ SVMMarginSampler< TSample, TModel >
m_MarginSamples.push_back(idDistVector[i].first);
}
// m_Output->AddInstance(static_cast<unsigned int>(classLabel), iterO.GetInstanceIdentifier());
}
} // end of namespace otb
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment