Commit fe269fe3 authored by Manuel Grizonnet's avatar Manuel Grizonnet

TEST: add test to use SVM machine learning class in regression mode

parent 376a8ffd
......@@ -776,6 +776,13 @@ add_test(leTvListSampleGenerator1 ${LEARNING_TESTS4}
${TEMP}/svm_model.txt
)
add_test(leTvSVMMachineLearningRegressionModel ${LEARNING_TESTS6}
otbSVMMachineLearningRegressionModel
${INPUTDATA}/abalone.scale
${TEMP}/svm_model_regression.txt
)
add_test(leTuLibSVMMachineLearningModelNew ${LEARNING_TESTS6}
otbLibSVMMachineLearningModelNew)
......
......@@ -28,6 +28,7 @@ void RegisterTests()
REGISTER_TEST(otbLibSVMMachineLearningModel);
REGISTER_TEST(otbSVMMachineLearningModelNew);
REGISTER_TEST(otbSVMMachineLearningModel);
REGISTER_TEST(otbSVMMachineLearningRegressionModel);
REGISTER_TEST(otbKNearestNeighborsMachineLearningModelNew);
REGISTER_TEST(otbKNearestNeighborsMachineLearningModel);
REGISTER_TEST(otbRandomForestsMachineLearningModelNew);
......
......@@ -41,6 +41,15 @@ typedef MachineLearningModelType::InputListSampleType InputListSampleType;
typedef MachineLearningModelType::TargetValueType TargetValueType;
typedef MachineLearningModelType::TargetSampleType TargetSampleType;
typedef MachineLearningModelType::TargetListSampleType TargetListSampleType;
typedef otb::MachineLearningModel<float,float> MachineLearningModelRegressionType;
typedef MachineLearningModelRegressionType::InputValueType InputValueRegressionType;
typedef MachineLearningModelRegressionType::InputSampleType InputSampleRegressionType;
typedef MachineLearningModelRegressionType::InputListSampleType InputListSampleRegressionType;
typedef MachineLearningModelRegressionType::TargetValueType TargetValueRegressionType;
typedef MachineLearningModelRegressionType::TargetSampleType TargetSampleRegressionType;
typedef MachineLearningModelRegressionType::TargetListSampleType TargetListSampleRegressionType;
typedef otb::ConfusionMatrixCalculator<TargetListSampleType, TargetListSampleType> ConfusionMatrixCalculatorType;
bool ReadDataFile(const std::string & infname, InputListSampleType * samples, TargetListSampleType * labels)
......@@ -111,6 +120,74 @@ bool ReadDataFile(const std::string & infname, InputListSampleType * samples, Ta
return true;
}
bool ReadDataRegressionFile(const std::string & infname, InputListSampleRegressionType * samples, TargetListSampleRegressionType * labels)
{
std::ifstream ifs;
ifs.open(infname.c_str());
if(!ifs)
{
std::cout<<"Could not read file "<<infname<<std::endl;
return false;
}
unsigned int nbfeatures = 0;
while (!ifs.eof())
{
std::string line;
std::getline(ifs, line);
if(nbfeatures == 0)
{
nbfeatures = std::count(line.begin(),line.end(),' ')-1;
//std::cout<<"Found "<<nbfeatures<<" features per samples"<<std::endl;
}
if(line.size()>1)
{
InputSampleRegressionType sample(nbfeatures);
sample.Fill(0);
std::string::size_type pos = line.find_first_of(" ", 0);
// Parse label
TargetSampleRegressionType label;
label[0] = atof(line.substr(0, pos).c_str());
bool endOfLine = false;
unsigned int id = 0;
while(!endOfLine)
{
std::string::size_type nextpos = line.find_first_of(" ", pos+1);
if(nextpos == std::string::npos)
{
endOfLine = true;
nextpos = line.size()-1;
}
else
{
std::string feature = line.substr(pos,nextpos-pos);
std::string::size_type semicolonpos = feature.find_first_of(":");
id = atoi(feature.substr(0,semicolonpos).c_str());
sample[id - 1] = atof(feature.substr(semicolonpos+1,feature.size()-semicolonpos).c_str());
pos = nextpos;
}
}
samples->SetMeasurementVectorSize(itk::NumericTraits<InputSampleRegressionType>::GetLength(sample));
samples->PushBack(sample);
labels->PushBack(label);
}
}
//std::cout<<"Retrieved "<<samples->Size()<<" samples"<<std::endl;
ifs.close();
return true;
}
int otbLibSVMMachineLearningModelNew(int itkNotUsed(argc), char * itkNotUsed(argv) [])
{
typedef otb::LibSVMMachineLearningModel<InputValueType, TargetValueType> SVMType;
......@@ -275,6 +352,85 @@ int otbSVMMachineLearningModel(int argc, char * argv[])
}
}
int otbSVMMachineLearningRegressionModel(int argc, char * argv[])
{
if (argc != 3 )
{
std::cout<<"Wrong number of arguments "<<std::endl;
std::cout<<"Usage : sample file, output file "<<std::endl;
return EXIT_FAILURE;
}
typedef otb::SVMMachineLearningModel<InputValueRegressionType, TargetValueRegressionType> SVMType;
InputListSampleRegressionType::Pointer samples = InputListSampleRegressionType::New();
TargetListSampleRegressionType::Pointer labels = TargetListSampleRegressionType::New();
TargetListSampleRegressionType::Pointer predicted = TargetListSampleRegressionType::New();
if(!ReadDataRegressionFile(argv[1],samples,labels))
{
std::cout<<"Failed to read samples file "<<argv[1]<<std::endl;
return EXIT_FAILURE;
}
SVMType::Pointer classifier = SVMType::New();
//Init SVM type in regression mode
//Available mode for regression in openCV are eps_svr and nu_svr
classifier->SetSVMType(CvSVM::EPS_SVR);
//classifier->SetSVMType(CvSVM::NU_SVR);
//P should be >0. Increasing value give better result. Need to investigate why.
classifier->SetP(10);
//IN case you're using nu_svr you should set nu to a positive value between 0
//and 1.
//classifier->SetNu(0.9);
//Use RBF kernel.Don't know what is recommended in case of svm regression
classifier->SetKernelType(CvSVM::RBF);
classifier->SetInputListSample(samples);
classifier->SetTargetListSample(labels);
classifier->Train();
//Predict age using first line of abalone dataset
//1:-1 2:0.027027 3:0.0420168 4:-0.831858 5:-0.63733 6:-0.699395 7:-0.735352
//8:-0.704036
// Input value is 15.
InputListSampleRegressionType::Pointer samplesT = InputListSampleRegressionType::New();
//Init sample list to 8 (size of abalone dataset)
InputSampleRegressionType sample(8);
sample.Fill(0);
sample[0] = -1;
sample[1] = 0.027027;
sample[2] = 0.0420168;
sample[3] = -0.831858;
sample[4] = -0.63733;
sample[5] = -0.699395;
sample[6] = -0.735352;
sample[7] = -0.704036;
samplesT->SetMeasurementVectorSize(itk::NumericTraits<InputSampleRegressionType>::GetLength(sample));
samplesT->PushBack(sample);
classifier->SetInputListSample(samplesT);
classifier->SetTargetListSample(predicted);
classifier->PredictAll();
const float age = 15;
if ( vcl_abs(age - predicted->GetMeasurementVector(0)[0]) <= 0.3 )
{
return EXIT_SUCCESS;
}
else
{
return EXIT_FAILURE;
}
}
int otbKNearestNeighborsMachineLearningModelNew(int itkNotUsed(argc), char * itkNotUsed(argv) [])
{
typedef otb::KNearestNeighborsMachineLearningModel<InputValueType,TargetValueType> KNearestNeighborsType;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment