本系列文章由 @YhL_Leo 出品,转载请注明出处。
文章链接: http://blog.csdn.net/yhl_leo/article/details/50179779
在 LibSVM 库的 svm.h 头文件中定义了四个主要结构体:
1 训练模型的结构体
// Training set description: l samples, each with one label and one feature vector.
// readProblem() in this article fills y and x with arrays of l entries.
struct svm_problem
{
int l; // total number of samples
double *y; // label of each sample (y[i] belongs to sample i)
struct svm_node **x; // feature vector of each sample (x[i] points at the first node of sample i)
};
样本的类别通常使用 +1 与 -1 进行标识。如果样本的类别未知,则分类的准确率也就无法计算。
2 数据节点的结构体
// One index:value pair of a feature vector.
// readProblem() in this article closes every sample's vector with a node whose index is -1.
struct svm_node
{
int index; // feature index of this entry; -1 marks the end of a vector
double value; // feature value stored under that index
};
数据组织结构如图1所示:
3 模型参数结构体
/* Kernel and training settings handed to the trainer.
   svm_type and kernel_type hold values of the svm_type / kernel_type enums. */
struct svm_parameter
{
int svm_type; /* one of C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR */
int kernel_type; /* one of LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED */
int degree; /* for poly */
double gamma; /* for poly/rbf/sigmoid; readProblem() replaces 0 with 1/max_index */
double coef0; /* for poly/sigmoid */
/* these are for training only */
double cache_size; /* in MB */
double eps; /* stopping criteria */
double C; /* for C_SVC, EPSILON_SVR and NU_SVR */
int nr_weight; /* for C_SVC */
int *weight_label; /* for C_SVC */
double* weight; /* for C_SVC */
double nu; /* for NU_SVC, ONE_CLASS, and NU_SVR */
double p; /* for EPSILON_SVR */
int shrinking; /* use the shrinking heuristics */
int probability; /* do probability estimates */
};
其中,各个参数的含义为:
-s svm_type : set type of SVM (default 0)
0 -- C-SVC
1 -- nu-SVC
2 -- one-class SVM
3 -- epsilon-SVR
4 -- nu-SVR
-t kernel_type : set type of kernel function (default 2)
0 -- linear: u'*v
1 -- polynomial: (gamma*u'*v + coef0)^degree
2 -- radial basis function: exp(-gamma*|u-v|^2)
3 -- sigmoid: tanh(gamma*u'*v + coef0)
-d degree : set degree in kernel function (default 3)
-g gamma : set gamma in kernel function (default 1/num_features)
-r coef0 : set coef0 in kernel function (default 0)
-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)
-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)
-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)
-m cachesize : set cache memory size in MB (default 100)
-e epsilon : set tolerance of termination criterion (default 0.001)
-h shrinking: whether to use the shrinking heuristics, 0 or 1 (default 1)
-b probability_estimates: whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)
-wi weight: set the parameter C of class i to weight*C, for C-SVC (default 1)
SVM模型类型和核函数类型:
/* Values for svm_parameter::svm_type; enumerators are numbered 0..4 in the order listed. */
enum { C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR }; /* svm_type */
/* Values for svm_parameter::kernel_type; enumerators are numbered 0..4 in the order listed. */
enum { LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED }; /* kernel_type */
4 训练输出模型结构体
/* Trained model: the parameters it was built with plus the support vectors and
   decision-function constants. Produced by svm_train or svm_load_model (see free_sv). */
struct svm_model
{
struct svm_parameter param; /* parameter */
int nr_class; /* number of classes, = 2 in regression/one class svm */
int l; /* total #SV */
struct svm_node **SV; /* SVs (SV[l]) */
double **sv_coef; /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */
double *rho; /* constants in decision functions (rho[k*(k-1)/2]) */
double *probA; /* pairwise probability information */
double *probB; /* second array of pairwise probability information */
int *sv_indices; /* sv_indices[0,...,nSV-1] are values in [1,...,num_training_data] to indicate SVs in the training set */
/* for classification only */
int *label; /* label of each class (label[k]) */
int *nSV; /* number of SVs for each class (nSV[k]) */
/* nSV[0] + nSV[1] + ... + nSV[k-1] = l */
/* XXX */
int free_sv; /* 1 if svm_model is created by svm_load_model*/
/* 0 if svm_model is created by svm_train */
};
5 使用方法
以 LibSVM 提供的样本特征集 heart_scale 为例,首先需要读取样本特征数据,可以利用 svm-train.c 文件中的 read_problem 函数;为了方便使用,这里对其进行了改写:
// TrainingDataLoad.h
/*
Load training data from svm format file.
- Editor: Yahui Liu.
- Date: 2015-11-30
- Email: yahui.cvrs@gmail.com
- Address: Computer Vision and Remote Sensing(CVRS), Lab.
**/
#ifndef TRAINING_DATA_LOAD_H
#define TRAINING_DATA_LOAD_H
#pragma once
#include <stdio.h>
#include <stdlib.h>
#include <ctype.h>
#include <iostream>
#include <vector>
#include <string>
#include <fstream>
#include <errno.h>
#include "svm.h"
//#include "svm-scale.c"
using namespace std;
#define MAX_LINE_LEN 1024
/*!
 Helper that reads LibSVM-format sample files and model files.
 The object carries a single piece of state: the text buffer of the line
 that is currently being parsed.
**/
class TrainingDateLoad
{
public:
    /*! Start without a line buffer. */
    TrainingDateLoad() : line(NULL) {}

    /*! Only resets the pointer; the buffer it referred to is not released here. */
    ~TrainingDateLoad() { line = NULL; }

    /*! Buffer holding the line being parsed; readline() reads into it. */
    char* line;

    // static struct svm_parameter _paramInit;

    /*! Load an svm model from filename into model. */
    void loadModel( std::string filename, struct svm_model*& model);

    /*! Advance p past the target (label) field. */
    void svmSkipTarget( char*& p);

    /*! Advance p past one element. */
    void svmSkipElement( char*& p);

    /*! Fill param with the default settings. */
    void initialParams( struct svm_parameter& param );

    /*! Load the training data stored in filename into prob. */
    void readProblem( std::string filename, struct svm_problem& prob, struct svm_parameter& param );

    /*! Read the next line of input into the line buffer. */
    char* readline(FILE *input);

    /*! Report a malformed input line and terminate the program with status 1. */
    void exit_input_error(int line_num)
    {
        std::cout << "Wrong input format at line: " << line_num << std::endl;
        exit(1);
    }
};
#endif // TRAINING_DATA_LOAD_H
// TrainingDataLoad.cpp
#include "TrainingDataLoad.h"
/*!
 Load a LibSVM model file.
 - filename: path of the model file.
 - model: receives the pointer returned by svm_load_model for that path.
**/
void TrainingDateLoad::loadModel(std::string filename, struct svm_model*& model)
{
    const char* modelPath = filename.c_str();
    struct svm_model* loaded = svm_load_model(modelPath);
    model = loaded;
}
/*!
 Advance p past leading whitespace and then past the target (label) token.
 - p: cursor into a NUL-terminated line; on return it points at the first
      whitespace character after the target, or at the terminating NUL when
      the line ends right after the target.
 A NULL cursor is left untouched.
**/
void TrainingDateLoad::svmSkipTarget(char*& p)
{
    if (p == NULL)
        return;

    // isspace() requires a value representable as unsigned char, hence the casts.
    while (*p && isspace(static_cast<unsigned char>(*p)))
        ++p;

    // Stop at the NUL as well: a line that holds only a target has no
    // trailing whitespace to stop on, and scanning on would leave the buffer.
    while (*p && !isspace(static_cast<unsigned char>(*p)))
        ++p;
}
/*!
 Advance p past one "index:value" element.
 - p: cursor into a NUL-terminated line; on return it points just behind the
      value token, or at the terminating NUL when the line ends first
      (including the case where no ':' is left in the line).
 A NULL cursor is left untouched.
**/
void TrainingDateLoad::svmSkipElement(char*& p)
{
    if (p == NULL)
        return;

    // Find the separator, but never run past the end of the line.
    while (*p && *p != ':')
        ++p;
    if (*p == ':')
        ++p;

    // isspace() requires a value representable as unsigned char, hence the casts.
    while (*p && isspace(static_cast<unsigned char>(*p)))
        ++p;
    while (*p && !isspace(static_cast<unsigned char>(*p)))
        ++p;
}
/*!
 Reset param to the default settings used by this article's tools.
 - param: parameter block that is overwritten field by field.
**/
void TrainingDateLoad::initialParams( struct svm_parameter& param )
{
    // model and kernel selection
    param.svm_type    = C_SVC;
    param.kernel_type = RBF;

    // kernel coefficients
    param.degree = 3;
    param.gamma  = 0;      // 0 is turned into 1/max_index by readProblem()
    param.coef0  = 0;

    // solver settings
    param.C          = 1;
    param.nu         = 0.5;
    param.p          = 0.1;
    param.eps        = 1e-3;
    param.cache_size = 100;

    // switches
    param.shrinking   = 1;
    param.probability = 0;

    // no per-class weights
    param.nr_weight    = 0;
    param.weight_label = NULL;
    param.weight       = NULL;
}
/*!
 Load a LibSVM-format sample file ("label index:value index:value ...", one
 sample per line) into prob.
 - filename: path of the sample file.
 - prob: receives the sample count (l), the labels (y) and the feature
         vectors (x). y, x and the node array that x[0] points into are
         allocated with new[] and stay with the caller.
 - param: gamma is set to 1/max_index when it is 0 and the file contains a
          positive feature index; kernel_type selects the extra check applied
          to precomputed-kernel input.
 The program is terminated with exit(1) when the file cannot be opened, when
 memory for the line buffer is missing, or when a line is malformed.
**/
void TrainingDateLoad::readProblem( std::string filename,
    struct svm_problem& prob, struct svm_parameter& param )
{
    FILE *fp = fopen(filename.c_str(),"r");
    if(fp == NULL)
    {
        // %s expects a C string; handing it the std::string object is undefined behaviour.
        fprintf(stderr,"can't open input file %s\n",filename.c_str());
        exit(1);
    }

    // readline() enlarges this buffer with realloc(), so it has to come from malloc().
    line = (char *) malloc(MAX_LINE_LEN);
    if(line == NULL)
    {
        fprintf(stderr,"can't allocate memory for the line buffer\n");
        fclose(fp);
        exit(1);
    }

    // Pass 1: count the samples and the nodes needed to store them.
    prob.l = 0;
    size_t elements = 0;
    while(readline(fp) != NULL)
    {
        char *p = strtok(line," \t"); // label
        // features
        while(1)
        {
            p = strtok(NULL," \t");
            if(p == NULL || *p == '\n') // check '\n' as ' ' may be after the last feature
                break;
            ++elements;
        }
        ++elements; // room for the index == -1 terminator of this sample
        ++prob.l;
    }
    rewind(fp);

    prob.y = new double[prob.l];
    prob.x = new struct svm_node *[prob.l];
    struct svm_node *x_space = new struct svm_node[elements];

    // Pass 2: parse every line into its label and its nodes.
    int max_index = 0;
    size_t j = 0;
    for(int i=0;i<prob.l;i++)
    {
        int inst_max_index = -1; // strtol gives 0 if wrong format, and precomputed kernel has <index> start from 0
        if(readline(fp) == NULL) // fewer lines than counted in pass 1
            exit_input_error(i+1);
        prob.x[i] = &x_space[j];

        char *label = strtok(line," \t\n");
        if(label == NULL) // empty line
            exit_input_error(i+1);

        char *endptr;
        prob.y[i] = strtod(label,&endptr);
        if(endptr == label || *endptr != '\0')
            exit_input_error(i+1);

        while(1)
        {
            char *idx = strtok(NULL,":");
            char *val = strtok(NULL," \t");
            if(val == NULL)
                break;

            errno = 0;
            x_space[j].index = (int) strtol(idx,&endptr,10);
            // indices must be integers in strictly increasing order within a sample
            if(endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index)
                exit_input_error(i+1);
            else
                inst_max_index = x_space[j].index;

            errno = 0;
            x_space[j].value = strtod(val,&endptr);
            if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
                exit_input_error(i+1);

            ++j;
        }

        if(inst_max_index > max_index)
            max_index = inst_max_index;
        x_space[j++].index = -1; // end-of-vector marker
    }

    if(param.gamma == 0 && max_index > 0)
        param.gamma = 1.0/max_index;

    if(param.kernel_type == PRECOMPUTED)
        for(int i=0;i<prob.l;i++)
        {
            if (prob.x[i][0].index != 0)
            {
                fprintf(stderr,"Wrong input format: first column must be 0:sample_serial_number\n");
                exit(1);
            }
            if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
            {
                fprintf(stderr,"Wrong input format: sample_serial_number out of range\n");
                exit(1);
            }
        }

    fclose(fp);

    // The line buffer is scratch space for parsing only; release it so that
    // repeated calls do not leak one buffer per file.
    free(line);
    line = NULL;
}
/*!
 Read the next line of input into the member buffer line, enlarging the
 buffer with realloc() until the whole line (up to and including '\n', or up
 to end of file) fits.
 - input: stream to read from.
 - returns: line, or NULL when input is NULL or nothing more could be read.
 When no buffer is attached yet, one of MAX_LINE_LEN bytes is obtained with
 malloc(). Running out of memory terminates the program with exit(1).
**/
char* TrainingDateLoad::readline(FILE *input)
{
    if(input == NULL)
        return NULL;

    if(line == NULL)
    {
        line = (char *) malloc(MAX_LINE_LEN);
        if(line == NULL)
        {
            fprintf(stderr,"can't allocate memory for the line buffer\n");
            exit(1);
        }
    }

    if(fgets(line,MAX_LINE_LEN,input) == NULL)
        return NULL;

    int max_line_len = MAX_LINE_LEN;
    while(strrchr(line,'\n') == NULL)
    {
        max_line_len *= 2;

        // Keep the old pointer until realloc() succeeded; assigning its result
        // directly would lose the buffer and dereference NULL on failure.
        char *grown = (char *) realloc(line,max_line_len);
        if(grown == NULL)
        {
            fprintf(stderr,"can't allocate memory for the line buffer\n");
            exit(1);
        }
        line = grown;

        const int len = (int) strlen(line);
        if(fgets(line+len,max_line_len-len,input) == NULL)
            break; // end of file: the last line has no trailing '\n'
    }
    return line;
}
将样本训练与预测进行改写:
// LibSVMTools.h
/*
LibSVM train and predict tools.
- Editor: Yahui Liu.
- Date: 2015-12-3
- Email: yahui.cvrs@gmail.com
- Address: Computer Vision and Remote Sensing(CVRS), Lab.
**/
#ifndef LIBSVM_TOOL_H
#define LIBSVM_TOOL_H
#pragma once
#include <iostream>
#include <string>
#include "svm.h"
#include "TrainingDataLoad.h"
/*!
 Entry points for training a LibSVM model from a feature file and for
 running predictions with a saved model. The class holds no state.
**/
class LibSVMTools
{
public:
    LibSVMTools() {}
    ~LibSVMTools() {}

    /*!
     Train a model.
     - featureFile: features of images saved in libsvm format.
     - saveModelFile: file the trained model is saved to.
    **/
    void libSvmTrain(std::string featureFile, std::string saveModelFile);

    /*!
     Predict with a trained model.
     - featureFile: features of images saved in libsvm format.
     - modelFile: libsvm trained model.
     - savePredictFile: file the predicting results are saved to.
    **/
    void libSvmPredict(std::string featureFile, std::string modelFile, std::string savePredictFile);
};
#endif // LIBSVM_TOOL_H
// LibSVMTools.cpp
#include "LibSVMTools.h"
void LibSVMTools::libSvmTrain(std::string featureFile, std::string saveModelFile)
{
struct svm_parameter param;
struct svm_problem prob;
TrainingDateLoad* trainData = new TrainingDateLoad;
trainData->initialParams( param );
trainData->readProblem(featureFile, prob, param);
const char*errorMsg = svm_check_parameter(&prob, ¶m);
if ( errorMsg )
{
cout << errorMsg << endl;
return;
}
struct svm_model *model = svm_train(&prob, ¶m);
#if 1
cout << "svm_type: " << model->param.svm_type << endl <<
"kernel_type: " << model->param.kernel_type << endl <<
"gamma: " << model->param.gamma << endl <<
"nr_class: " << model->nr_class << endl <<
"total_sv: " << model->l << endl <<
"rho: " << model->rho[0] << endl <<
"label: " << model->label[0] << " " << model->label[1] << endl <<
"nr_sv: " << model->nSV[0] << " " << model->nSV[1] << endl;
#endif
int saveModel = svm_save_model( saveModelFile.c_str(), model );
}
void LibSVMTools::libSvmPredict(std::string featureFile,
std::string modelFile, std::string savePredictFile)
{
struct svm_parameter param;
struct svm_problem prob;
TrainingDateLoad * trainData = new TrainingDateLoad;
trainData->initialParams( param );
trainData->readProblem(featureFile, prob, param);
struct svm_model* model;
trainData->loadModel(modelFile.c_str(), model);
float correct(0.0); // all correct
float uncorrect_1(0.0); // pos to neg
float uncorrect_2(0.0); // neg to pos
if ( prob.l )
{
const int nCount = prob.l;;
ofstream outfile( savePredictFile, ios::out );
for( int i=0; i<nCount; i++ )
{
double label = svm_predict(model, prob.x[i]);
if ( label == prob.y[i] )
{
correct ++;
}
else if ( label == -1.0 )
{
uncorrect_1 ++;
}
else
{
uncorrect_2 ++;
}
outfile << label << endl;
}
#if 1
cout << "total data count: " << nCount << endl <<
"classification correct: " << correct << endl <<
"pos to neg count: " << uncorrect_1 << endl <<
"neg to pos count: " << uncorrect_2 << endl;
cout << "Accuracy: " << static_cast<float>(correct/nCount)
<< "(" << correct << "/" << nCount << ")" << endl;
#endif
outfile.close();
}
}
用例Demo:
// train
#include "LibSVMTools.h"
// Training demo: prints the banner, then trains a model on the heart_scale
// sample set and saves it as train.model.
int main()
{
    std::cout <<
        "************************************************************" << std::endl <<
        "** PROGRAM: LibSVM model training. **" << std::endl <<
        "** **" << std::endl <<
        "** Author: Yahui Liu. **" << std::endl <<
        "** School of Remote Sensing & Inf. Eng. **" << std::endl <<
        "** Wuhan University, Hubei, P.R. China **" << std::endl <<
        "** Email: yahui.cvrs@gmail.com **" << std::endl <<
        "** Create time: Dec. 1, 2015 **" << std::endl <<
        "************************************************************" << std::endl;

    // The sample set is named heart_scale (the prediction demo reads the
    // same file); "heat_scale" was a misspelling.
    const std::string filename = "..\\..\\..\\Data\\heart_scale";
    const std::string savefielname = "..\\..\\..\\Data\\train.model";

    LibSVMTools libsvm;
    libsvm.libSvmTrain(filename, savefielname);
    return 0;
}
/*------------------------------------------------------------------------------------*/
// predict
#include "LibSVMTools.h"
// Prediction demo: prints the banner, then predicts the heart_scale samples
// with train.model and writes the predicted labels to predict.out.
int main()
{
    std::cout <<
        "************************************************************" << std::endl <<
        "** PROGRAM: LibSVM predict. **" << std::endl <<
        "** **" << std::endl <<
        "** Author: Yahui Liu. **" << std::endl <<
        "** School of Remote Sensing & Inf. Eng. **" << std::endl <<
        "** Wuhan University, Hubei, P.R. China **" << std::endl <<
        "** Email: yahui.cvrs@gmail.com **" << std::endl <<
        "** Create time: Dec. 1, 2015 **" << std::endl <<
        "************************************************************" << std::endl;

    const std::string featureFile = "..\\..\\..\\Data\\heart_scale";
    const std::string modelFile = "..\\..\\..\\Data\\train.model";
    const std::string savePredictFile = "..\\..\\..\\Data\\predict.out";

    LibSVMTools libsvm;
    libsvm.libSvmPredict(featureFile, modelFile, savePredictFile);
    return 0;
}
手机扫一扫
移动阅读更方便
你可能感兴趣的文章