Skip to content
Snippets Groups Projects
Commit 91254046 authored by Guillaume Pasero's avatar Guillaume Pasero
Browse files

ENH: implement regression in GBTree ML model

parent 62fe025b
Branches
Tags
No related merge requests found
......@@ -130,6 +130,9 @@ protected:
/** Predict values using the model */
virtual TargetSampleType PredictClassification(const InputSampleType& input, ConfidenceValueType *quality=NULL) const;
virtual void TrainRegression();
virtual TargetSampleType PredictRegression(const InputSampleType& input) const;
private:
GradientBoostedTreeMachineLearningModel(const Self &); //purposely not implemented
void operator =(const Self&); //purposely not implemented
......@@ -142,7 +145,6 @@ private:
double m_SubSamplePortion;
int m_MaxDepth;
bool m_UseSurrogates;
bool m_IsRegression;
};
......
......@@ -36,8 +36,7 @@ GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue>
m_Shrinkage(0.01),
m_SubSamplePortion(0.8),
m_MaxDepth(3),
m_UseSurrogates(false),
m_IsRegression(false)
m_UseSurrogates(false)
{
}
......@@ -62,7 +61,6 @@ GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue>
cv::Mat labels;
otb::ListSampleToMat<TargetListSampleType>(this->GetTargetListSample(),labels);
CvGBTreesParams params = CvGBTreesParams(m_LossFunctionType, m_WeakCount, m_Shrinkage, m_SubSamplePortion,
m_MaxDepth, m_UseSurrogates);
......@@ -70,12 +68,20 @@ GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue>
cv::Mat var_type = cv::Mat(this->GetInputListSample()->GetMeasurementVectorSize() + 1, 1, CV_8U );
var_type.setTo(cv::Scalar(CV_VAR_NUMERICAL) ); // all inputs are numerical
if (!m_IsRegression) //Classification
if (!this->m_RegressionMode) //Classification
var_type.at<uchar>(this->GetInputListSample()->GetMeasurementVectorSize(), 0) = CV_VAR_CATEGORICAL;
m_GBTreeModel->train(samples,CV_ROW_SAMPLE,labels,cv::Mat(),cv::Mat(),var_type,cv::Mat(),params, false);
}
template <class TInputValue, class TOutputValue>
void
GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue>
::TrainRegression()
{
this->TrainClassification();
}
template <class TInputValue, class TOutputValue>
typename GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
......@@ -104,6 +110,15 @@ GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue>
return target;
}
template <class TInputValue, class TOutputValue>
typename GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue>
::TargetSampleType
GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue>
::PredictRegression(const InputSampleType & input) const
{
return this->PredictClassification(input, NULL);
}
template <class TInputValue, class TOutputValue>
void
GradientBoostedTreeMachineLearningModel<TInputValue,TOutputValue>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment