Skip to content
Snippets Groups Projects
Commit eb262103 authored by Guillaume Pasero's avatar Guillaume Pasero
Browse files

ENH: expose tree types for regression

parent 9cc50b18
No related branches found
No related tags found
No related merge requests found
......@@ -33,8 +33,15 @@ LearningApplicationBase<TInputValue,TOutputValue>
"classifier.gbt",
"This group of parameters allows to set Gradient Boosted Tree classifier parameters. "
"See complete documentation here \\url{http://docs.opencv.org/modules/ml/doc/gradient_boosted_trees.html}.");
//LossFunctionType : not exposed, as only one type is used for Classification,
// the other three are used for regression.
if (m_RegressionFlag)
{
AddParameter(ParameterType_Choice, "classifier.gbt.t", "Loss Function Type");
SetParameterDescription("classifier.gbt.t","Type of loss functionused for training.");
AddChoice("classifier.gbt.t.sqr","Squared Loss");
AddChoice("classifier.gbt.t.abs","Absolute Loss");
AddChoice("classifier.gbt.t.hub","Huber Loss");
}
//WeakCount
AddParameter(ParameterType_Int, "classifier.gbt.w", "Number of boosting algorithm iterations");
......@@ -87,6 +94,29 @@ LearningApplicationBase<TInputValue,TOutputValue>
classifier->SetSubSamplePortion(GetParameterFloat("classifier.gbt.p"));
classifier->SetMaxDepth(GetParameterInt("classifier.gbt.max"));
if (m_RegressionFlag)
{
switch (GetParameterInt("classifier.gbt.t"))
{
case 0: // SQUARED_LOSS
classifier->SetLossFunctionType(CvGBTrees::SQUARED_LOSS);
break;
case 1: // ABSOLUTE_LOSS
classifier->SetLossFunctionType(CvGBTrees::ABSOLUTE_LOSS);
break;
case 2: // HUBER_LOSS
classifier->SetLossFunctionType(CvGBTrees::HUBER_LOSS);
break;
default:
classifier->SetLossFunctionType(CvGBTrees::SQUARED_LOSS);
break;
}
}
else
{
classifier->SetLossFunctionType(CvGBTrees::DEVIANCE_LOSS);
}
classifier->Train();
classifier->Save(modelPath);
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment