LibSVM C/C++
阅读原文时间:2023年07月12日阅读:2

本系列文章由 @YhL_Leo 出品,转载请注明出处。

文章链接: http://blog.csdn.net/yhl_leo/article/details/50179779


LibSVM库的svm.h头文件中定义了四个主要结构体:

1 训练模型的结构体

/* Data set passed to svm_check_parameter() and svm_train(): l samples, their labels and their feature vectors. */
struct svm_problem
{
    int l;                // total number of samples
    double *y;            // label of each sample (y[i] belongs to x[i])
    struct svm_node **x;  // feature vector of each sample; x[i] points at the first node of sample i
};

样本的类别通常使用 +1 和 -1 进行标识。如果样本的类别未知,则分类的准确率也就无法计算。

2 数据节点的结构体

/* One index/value pair of a sample's feature vector. */
struct svm_node
{
    int index;     // feature index; the loader in this file appends a node with index -1 after the last feature of each sample
    double value;  // feature value stored for that index
};

数据组织结构如图1所示:

3 模型参数结构体

/* Training configuration; the command-line option that corresponds to each field is given in parentheses. */
struct svm_parameter
{
    int svm_type;       /* one of the svm_type enum values (-s) */
    int kernel_type;    /* one of the kernel_type enum values (-t) */
    int degree; /* for poly (-d) */
    double gamma;   /* for poly/rbf/sigmoid (-g) */
    double coef0;   /* for poly/sigmoid (-r) */

    /* these are for training only */
    double cache_size; /* in MB (-m) */
    double eps; /* stopping criteria: tolerance of the termination criterion (-e) */
    double C;   /* for C_SVC, EPSILON_SVR and NU_SVR (-c) */
    int nr_weight;      /* for C_SVC */
    int *weight_label;  /* for C_SVC */
    double* weight;     /* for C_SVC (-wi) */
    double nu;  /* for NU_SVC, ONE_CLASS, and NU_SVR (-n) */
    double p;   /* for EPSILON_SVR: epsilon of the loss function (-p) */
    int shrinking;  /* use the shrinking heuristics, 0 or 1 (-h) */
    int probability; /* do probability estimates, 0 or 1 (-b) */
};

其中,各个参数的含义为:

-s svm_type : set type of SVM (default 0)
    0 -- C-SVC
    1 -- nu-SVC
    2 -- one-class SVM
    3 -- epsilon-SVR
    4 -- nu-SVR
-t kernel_type : set type of kernel function (default 2)
    0 -- linear: u'*v
    1 -- polynomial: (gamma*u'*v + coef0)^degree
    2 -- radial basis function: exp(-gamma*|u-v|^2)
    3 -- sigmoid: tanh(gamma*u'*v + coef0)
-d degree : set degree in kernel function (default 3)
-g gamma : set gamma in kernel function (default 1/num_features)
-r coef0 : set coef0 in kernel function (default 0)
-c cost : set the parameter C of C-SVC, epsilon-SVR, and nu-SVR (default 1)
-n nu : set the parameter nu of nu-SVC, one-class SVM, and nu-SVR (default 0.5)
-p epsilon : set the epsilon in loss function of epsilon-SVR (default 0.1)
-m cachesize : set cache memory size in MB (default 100)
-e epsilon : set tolerance of termination criterion (default 0.001)
-h shrinking: whether to use the shrinking heuristics, 0 or 1 (default 1)
-b probability_estimates: whether to train a SVC or SVR model for probability estimates, 0 or 1 (default 0)
-wi weight: set the parameter C of class i to weight*C, for C-SVC (default 1)

SVM模型类型和核函数类型:

/* Enumerators are listed in the same order as the numeric -s / -t option values (first enumerator = 0). */
enum { C_SVC, NU_SVC, ONE_CLASS, EPSILON_SVR, NU_SVR }; /* svm_type */
enum { LINEAR, POLY, RBF, SIGMOID, PRECOMPUTED }; /* kernel_type */

4 训练输出模型结构体

/* Trained model as produced by svm_train() or svm_load_model(). */
struct svm_model
{
    struct svm_parameter param; /* parameter */
    int nr_class;       /* number of classes, = 2 in regression/one class svm */
    int l;          /* total #SV */
    struct svm_node **SV;       /* SVs (SV[l]) */
    double **sv_coef;   /* coefficients for SVs in decision functions (sv_coef[k-1][l]) */
    double *rho;        /* constants in decision functions (rho[k*(k-1)/2]) */
    double *probA;      /* pairwise probability information */
    double *probB;
    int *sv_indices;        /* sv_indices[0,...,nSV-1] are values in [1,...,num_training_data] to indicate SVs in the training set */

    /* for classification only */

    int *label;     /* label of each class (label[k]) */
    int *nSV;       /* number of SVs for each class (nSV[k]) */
                /* nSV[0] + nSV[1] + ... + nSV[k-1] = l */
    /* XXX */
    int free_sv;        /* 1 if svm_model is created by svm_load_model*/
                /* 0 if svm_model is created by svm_train */
};

5 使用方法

以LibSVM提供的样本特征集heart_scale为例,首先需要读取样本特征数据,可以利用svm-train.c文件中的read_problem函数;为了方便使用,对其进行了改写:

// TrainingDataLoad.h
/*
    Load training data from svm format file.

    - Editor: Yahui Liu.
    - Date:   2015-11-30
    - Email:  yahui.cvrs@gmail.com
    - Address: Computer Vision and Remote Sensing(CVRS), Lab.
**/

#ifndef TRAINING_DATA_LOAD_H
#define TRAINING_DATA_LOAD_H
#pragma once

#include <stdio.h>
#include <stdlib.h>
#include <ctype.h>
#include <iostream>
#include <vector>
#include <string>
#include <fstream>
#include <errno.h>

#include "svm.h"
//#include "svm-scale.c"

// NOTE(review): a using-directive in a header leaks into every file that includes it;
// it is kept because the code in this file uses unqualified cout/endl/string/ofstream.
using namespace std;

// Initial size in bytes of the line buffer; readline() starts from this size and doubles it
// while a line does not fit.
#define MAX_LINE_LEN 1024

/*!
    Reads LibSVM-format text files and fills the libsvm problem/parameter structures.

    The object keeps one scratch buffer (line) that readProblem() allocates and
    readline() fills and grows while parsing.
**/
class TrainingDateLoad
{
public:
    /*! Start without a line buffer. */
    TrainingDateLoad() : line(NULL) {}

    /*! Only resets the pointer; the buffer itself is not released here. */
    ~TrainingDateLoad() { line = NULL; }

    /*! load svm model */
    void loadModel( std::string filename, struct svm_model*& model );

    /*! skip the target */
    void svmSkipTarget( char*& p );

    /*! skip the element */
    void svmSkipElement( char*& p );

    /*! fill param with the default training settings */
    void initialParams( struct svm_parameter& param );

    /*! load training data */
    void readProblem( std::string filename, struct svm_problem& prob, struct svm_parameter& param );

    /*! read the next line of input into the line buffer; NULL when nothing could be read */
    char* readline( FILE *input );

    /*! report the 1-based number of the malformed line and terminate the process */
    void exit_input_error( int line_num )
    {
        std::cout << "Wrong input format at line: " << line_num << std::endl;
        exit(1);
    }

    /*! scratch buffer holding the line currently being parsed */
    char* line;

// public:
//  static struct svm_parameter _paramInit;
};

#endif // TRAINING_DATA_LOAD_H

// TrainingDataLoad.cpp
#include "TrainingDataLoad.h"

/*!
    Load a model file through svm_load_model().

    - filename: path of the model file.
    - model: receives whatever svm_load_model() returns for that path.
**/
void TrainingDateLoad::loadModel(std::string filename, struct svm_model*& model)
{
    const char* modelPath = filename.c_str();
    struct svm_model* loaded = svm_load_model(modelPath);
    model = loaded;
}

/*!
    Advance p past the first token (the target/label) of a data line.

    - p: cursor into a NUL-terminated line; on return it points at the first
      whitespace character after the token, or at the terminating NUL when the
      line ends with the token.
**/
void TrainingDateLoad::svmSkipTarget(char*& p)
{
    // isspace() requires a value representable as unsigned char; a negative
    // plain char is undefined behaviour, hence the casts.
    while(isspace(static_cast<unsigned char>(*p))) ++p;

    // Also stop at the terminator: a line that holds nothing after the label
    // would otherwise be scanned past the end of the buffer.
    while(*p && !isspace(static_cast<unsigned char>(*p))) ++p;
}

/*!
    Advance p past one "index:value" element of a data line.

    - p: cursor into a NUL-terminated line; on return it points just behind the
      value token, or at the terminating NUL when no ':' is left on the line.
**/
void TrainingDateLoad::svmSkipElement(char*& p)
{
    // Look for the separator, but never walk past the end of the string.
    while(*p && *p!=':') ++p;
    if(*p == '\0')
        return; // no further element on this line

    ++p; // step over ':'

    // isspace() requires a value representable as unsigned char.
    while(isspace(static_cast<unsigned char>(*p))) ++p;
    while(*p && !isspace(static_cast<unsigned char>(*p))) ++p;
}

/*!
    Reset param to the default training configuration: C-SVC with an RBF kernel.

    - param: every field is overwritten.
**/
void TrainingDateLoad::initialParams( struct svm_parameter& param )
{
    // model and kernel selection
    param.svm_type    = C_SVC;
    param.kernel_type = RBF;

    // kernel coefficients
    param.degree = 3;
    param.gamma  = 0.0;   // 0 means "not set": readProblem() replaces it with 1/max_index
    param.coef0  = 0.0;

    // solver settings
    param.cache_size = 100.0;
    param.eps        = 1e-3;
    param.shrinking  = 1;

    // cost / loss settings
    param.C  = 1.0;
    param.nu = 0.5;
    param.p  = 0.1;

    // no probability estimates, no per-class weights
    param.probability  = 0;
    param.nr_weight    = 0;
    param.weight_label = NULL;
    param.weight       = NULL;
}

/*!
    Load a data set in LibSVM text format: one sample per line, written as
    "<label> <index>:<value> <index>:<value> ...".

    - filename: path of the data file.
    - prob: receives the sample count (l), the labels (y) and one node pointer per
      sample (x). y, x and the node pool that x[0] points into are allocated with
      new[] and are left for the caller to release.
    - param: if gamma is still 0 it is set to 1/max_index, where max_index is the
      largest feature index seen; kernel_type is only read.

    Terminates the process with exit(1) when the file cannot be opened, memory for
    the line buffer is unavailable, or a line is malformed.
**/
void TrainingDateLoad::readProblem( std::string filename,
    struct svm_problem& prob, struct svm_parameter& param )
{
    int max_index, inst_max_index, i;
    size_t elements, j;
    FILE *fp = fopen(filename.c_str(),"r");
    char *endptr;
    char *idx, *val, *label;

    if(fp == NULL)
    {
        // %s expects a C string; handing the std::string object to a varargs
        // function is undefined behaviour.
        fprintf(stderr,"can't open input file %s\n",filename.c_str());
        exit(1);
    }

    prob.l = 0;
    elements = 0;

    // readline() grows this buffer with realloc(), so it has to come from
    // malloc() rather than new[].
    line = (char *) malloc(MAX_LINE_LEN);
    if(line == NULL)
    {
        fprintf(stderr,"can't allocate the line buffer\n");
        exit(1);
    }

    // Pass 1: count the samples and the nodes needed (features plus one
    // terminator node per sample).
    while(readline(fp)!=NULL)
    {
        char *p = strtok(line," \t"); // label

        // features
        while(1)
        {
            p = strtok(NULL," \t");
            if(p == NULL || *p == '\n') // check '\n' as ' ' may be after the last feature
                break;
            ++elements;
        }
        ++elements;
        ++prob.l;
    }
    rewind(fp);

    prob.y = new double[prob.l];
    prob.x = new struct svm_node *[prob.l];
    struct svm_node *x_space = new struct svm_node[elements];

    // Pass 2: parse every line into the node pool.
    max_index = 0;
    j=0;
    for(i=0;i<prob.l;i++)
    {
        inst_max_index = -1; // strtol gives 0 if wrong format, and precomputed kernel has <index> start from 0
        readline(fp);
        prob.x[i] = &x_space[j];
        label = strtok(line," \t\n");
        if(label == NULL) // empty line
            exit_input_error(i+1);

        prob.y[i] = strtod(label,&endptr);
        if(endptr == label || *endptr != '\0')
            exit_input_error(i+1);

        while(1)
        {
            idx = strtok(NULL,":");
            val = strtok(NULL," \t");

            if(val == NULL)
                break;

            // indices must parse completely and be strictly increasing within a sample
            errno = 0;
            x_space[j].index = (int) strtol(idx,&endptr,10);
            if(endptr == idx || errno != 0 || *endptr != '\0' || x_space[j].index <= inst_max_index)
                exit_input_error(i+1);
            else
                inst_max_index = x_space[j].index;

            errno = 0;
            x_space[j].value = strtod(val,&endptr);
            if(endptr == val || errno != 0 || (*endptr != '\0' && !isspace(*endptr)))
                exit_input_error(i+1);

            ++j;
        }

        if(inst_max_index > max_index)
            max_index = inst_max_index;
        x_space[j++].index = -1; // end-of-sample marker
    }

    if(param.gamma == 0 && max_index > 0)
        param.gamma = 1.0/max_index;

    if(param.kernel_type == PRECOMPUTED)
        for(i=0;i<prob.l;i++)
        {
            if (prob.x[i][0].index != 0)
            {
                fprintf(stderr,"Wrong input format: first column must be 0:sample_serial_number\n");
                exit(1);
            }
            if ((int)prob.x[i][0].value <= 0 || (int)prob.x[i][0].value > max_index)
            {
                fprintf(stderr,"Wrong input format: sample_serial_number out of range\n");
                exit(1);
            }
        }

    fclose(fp);

    // The parsed values were copied into x_space / prob.y, so the scratch
    // buffer is no longer needed; release it instead of leaking it.
    free(line);
    line = NULL;
}

/*!
    Read the next line of input into the member buffer line, growing the buffer
    until the whole line (including its '\n', if present) fits.

    - input: stream to read from.
    - returns: line, or NULL when nothing could be read (end of file or error).

    line must already point to a buffer of at least MAX_LINE_LEN bytes that
    realloc() may resize. Terminates the process with exit(1) if the buffer
    cannot be grown.
**/
char* TrainingDateLoad::readline(FILE *input)
{
    if(fgets(line,MAX_LINE_LEN,input) == NULL)
        return NULL;

    int max_line_len = MAX_LINE_LEN;
    while(strrchr(line,'\n') == NULL)
    {
        max_line_len *= 2;

        // Keep the old pointer until realloc() has succeeded; assigning the
        // result directly would lose the buffer and dereference NULL on failure.
        char *grown = (char *) realloc(line,max_line_len);
        if(grown == NULL)
        {
            fprintf(stderr,"out of memory while reading a long input line\n");
            exit(1);
        }
        line = grown;

        int len = (int) strlen(line);
        if(fgets(line+len,max_line_len-len,input) == NULL)
            break; // last line of the file has no trailing newline
    }
    return line;
}

将样本训练与预测进行改写:

// LibSVMTools.h
/*
    LibSVM train and predict tools.

    - Editor: Yahui Liu.
    - Date:   2015-12-3
    - Email:  yahui.cvrs@gmail.com
    - Address: Computer Vision and Remote Sensing(CVRS), Lab.
**/

#ifndef LIBSVM_TOOL_H
#define LIBSVM_TOOL_H
#pragma once

#include <iostream>
#include <string>

#include "svm.h"
#include "TrainingDataLoad.h"

/*!
    Thin wrapper around LibSVM: train a model from a feature file, or run a
    saved model over a feature file.
**/
class LibSVMTools
{
public:
    LibSVMTools() {}
    ~LibSVMTools() {}

    /*!
        Train a model and write it to disk.

        - featureFile: features of images saved in libsvm format.
        - saveModelFile: save the trained model file.
    **/
    void libSvmTrain(std::string featureFile, std::string saveModelFile);

    /*!
        Predict every sample of a feature file with a trained model.

        - featureFile: features of images saved in libsvm format.
        - modelFile: libsvm trained model.
        - savePredictFile: save the predicting results.
    **/
    void libSvmPredict(std::string featureFile, std::string modelFile, std::string savePredictFile);
};

#endif // LIBSVM_TOOL_H

// LibSVMTools.cpp
#include "LibSVMTools.h"

/*!
    Train a model with the default parameters and save it.

    - featureFile: training samples in libsvm format.
    - saveModelFile: path the trained model is written to.

    Parameter-check and save failures are reported on standard output; the
    function itself returns nothing.
**/
void LibSVMTools::libSvmTrain(std::string featureFile, std::string saveModelFile)
{
    struct svm_parameter param;
    struct svm_problem prob;

    // Stack object: the loader used to be heap-allocated and was never deleted.
    TrainingDateLoad trainData;
    trainData.initialParams( param );
    trainData.readProblem(featureFile, prob, param);

    const char* errorMsg = svm_check_parameter(&prob, &param);
    if ( errorMsg )
    {
        cout << errorMsg << endl;
    }
    else
    {
        struct svm_model *model = svm_train(&prob, &param);

#if 1
        cout << "svm_type: " << model->param.svm_type << endl <<
            "kernel_type: " << model->param.kernel_type << endl <<
            "gamma: " << model->param.gamma << endl <<
            "nr_class: " << model->nr_class << endl <<
            "total_sv: " << model->l << endl;

        // rho holds nr_class*(nr_class-1)/2 entries, so it is empty for a single class.
        if ( model->nr_class > 1 )
            cout << "rho: " << model->rho[0] << endl;

        // label / nSV have nr_class entries and exist for classification only;
        // indexing [1] unconditionally reads out of bounds for one class.
        if ( model->label )
        {
            cout << "label:";
            for ( int k=0; k<model->nr_class; k++ )
                cout << " " << model->label[k];
            cout << endl;
        }
        if ( model->nSV )
        {
            cout << "nr_sv:";
            for ( int k=0; k<model->nr_class; k++ )
                cout << " " << model->nSV[k];
            cout << endl;
        }
#endif

        if ( svm_save_model( saveModelFile.c_str(), model ) != 0 )
        {
            cout << "can't save model to file " << saveModelFile << endl;
        }

        // The model refers to the training vectors, so it is released before them.
        svm_free_and_destroy_model(&model);
    }

    svm_destroy_param(&param);

    // readProblem() allocates these with new[]; x[0] is the start of the node pool.
    if ( prob.l > 0 )
        delete[] prob.x[0];
    delete[] prob.x;
    delete[] prob.y;
}

void LibSVMTools::libSvmPredict(std::string featureFile,
    std::string modelFile, std::string savePredictFile)
{
    struct svm_parameter param;
    struct svm_problem prob;

    TrainingDateLoad * trainData = new TrainingDateLoad;
    trainData->initialParams( param );
    trainData->readProblem(featureFile, prob, param);

    struct svm_model* model;
    trainData->loadModel(modelFile.c_str(), model);

    float correct(0.0);     // all correct
    float uncorrect_1(0.0); // pos to neg
    float uncorrect_2(0.0); // neg to pos
    if ( prob.l )
    {
        const int nCount = prob.l;;

        ofstream outfile( savePredictFile, ios::out );
        for( int i=0; i<nCount; i++ )
        {
            double label = svm_predict(model, prob.x[i]);
            if ( label == prob.y[i] )
            {
                correct ++;
            }
            else if ( label == -1.0 )
            {
                uncorrect_1 ++;
            }
            else
            {
                uncorrect_2 ++;
            }
            outfile << label << endl;
        }
#if 1
        cout << "total data count: " << nCount << endl <<
            "classification correct: " << correct << endl <<
            "pos to neg count: " << uncorrect_1 << endl <<
            "neg to pos count: " << uncorrect_2 << endl;

        cout << "Accuracy: " << static_cast<float>(correct/nCount)
            << "(" << correct << "/" << nCount << ")" << endl;
#endif
        outfile.close();
    }
}

用例Demo:

// train
#include "LibSVMTools.h"

/* Training demo: train a model on the heart_scale sample set and save it as train.model. */
int main() // standard C++ requires main to return int
{
    std::cout <<
        "************************************************************" << endl <<
        "**          PROGRAM: LibSVM model training.               **" << endl <<
        "**                                                        **" << endl <<
        "**           Author: Yahui Liu.                           **" << endl <<
        "**                   School of Remote Sensing & Inf. Eng. **" << endl <<
        "**                   Wuhan University, Hubei, P.R. China  **" << endl <<
        "**            Email: yahui.cvrs@gmail.com                 **" << endl <<
        "**      Create time: Dec. 1, 2015                         **" << endl <<
        "************************************************************" << endl;

    // The sample set is named heart_scale (the predict demo reads the same file);
    // "heat_scale" made the loader fail to open its input.
    string filename = "..\\..\\..\\Data\\heart_scale";
    std::string savefielname = "..\\..\\..\\Data\\train.model";

    LibSVMTools libsvm;
    libsvm.libSvmTrain(filename, savefielname);

    return 0;
}

/*------------------------------------------------------------------------------------*/

// predict
#include "LibSVMTools.h"

/* Prediction demo: run train.model over the heart_scale sample set and write predict.out. */
int main() // standard C++ requires main to return int
{
    std::cout <<
        "************************************************************" << endl <<
        "**          PROGRAM: LibSVM predict.                      **" << endl <<
        "**                                                        **" << endl <<
        "**           Author: Yahui Liu.                           **" << endl <<
        "**                   School of Remote Sensing & Inf. Eng. **" << endl <<
        "**                   Wuhan University, Hubei, P.R. China  **" << endl <<
        "**            Email: yahui.cvrs@gmail.com                 **" << endl <<
        "**      Create time: Dec. 1, 2015                         **" << endl <<
        "************************************************************" << endl;

    std::string featureFile = "..\\..\\..\\Data\\heart_scale";
    std::string modelFile = "..\\..\\..\\Data\\train.model";
    std::string savePredictFile = "..\\..\\..\\Data\\predict.out";

    LibSVMTools libsvm;
    libsvm.libSvmPredict(featureFile, modelFile, savePredictFile);

    return 0;
}

手机扫一扫

移动阅读更方便

阿里云服务器
腾讯云服务器
七牛云服务器

你可能感兴趣的文章