CART Classification and Regression Trees: Principles and Implementation
Posted: 2021-04-22

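Algorithm Overview

CART (Classification And Regression Tree) is a decision-tree method that can be used for both classification and regression.

It uses binary recursive partitioning. At each step, a Gini-index-based impurity function is used to split the current sample set into two subsets, so every non-leaf node of the resulting tree has exactly two branches. The tree CART produces is therefore a compact binary tree.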

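Classification Tree

If the target variable is discrete, the tree is a classification tree.

A classification tree uses a tree-structured algorithm to partition the data into discrete classes.

Regression Tree

If the target variable is continuous, the tree is a regression tree.

Because CART trees are binary, they do not fragment the data as much as multiway trees do.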

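Two key points of classification trees

(1) Build the tree by recursively partitioning the predictor space with the training samples.

(2) Prune the tree with validation data.

a. Discrete variable X (x1…xn)

Try each way of assigning the values of X to the left or right branch, evaluate the tree that each grouping produces, and keep the best grouping. With only two values this is easy, because the two values define the split directly. With more than two values it takes more work. For example, if the variable age takes the values "youth", "middle-aged", and "old", the candidate groupings are {youth, middle-aged} vs {old}, {youth, old} vs {middle-aged}, and {middle-aged, old} vs {youth}. The grouping that best separates the target is chosen. Because CART splits are binary, when the training data has more than two target classes CART may also merge the classes into two superclasses; this is called twoing. In general, a variable with n distinct values has (2^n − 2)/2 = 2^(n−1) − 1 possible binary groupings.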

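b. Continuous variable X (x1…xn)

First sort the values. Then take the midpoint of each pair of adjacent values as a candidate split point that divides the node into a left branch and a right branch. Scanning all the candidates gives the best split point. Records whose feature value is below the split point go to the left subtree, and the rest go to the right subtree.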

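A natural question is whether an attribute chosen for this split can be chosen again further down the tree. Take a discrete variable XD with only two values. After splitting on XD, every record in the left subset has the same XD value, and so does every record in the right subset, so XD is useless below this node and need not be considered again. The cases where XD takes 3, 4, ... values follow the same reasoning: XD stays a candidate in a child only if that child still contains more than one of its values. A continuous variable XC, once discretized, behaves like a discrete variable with n values, so the same analysis applies. Unless all values of XC in a child are identical, XC must still be considered for later splits after it has been used here.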

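Choosing the split variable and the best split point

The overall principle of tree growth is to make each child purer than its parent, with purity measured by an impurity index. Classification trees use the Gini index, the Twoing criterion, Ordered Twoing, and similar measures. Regression trees use least squared residuals, least absolute residuals, and similar measures.

(1) Gini index, for discrete targets: Gini(t) = 1 − Σ_k p(k|t)^2, where p(k|t) is the proportion of class k at node t. The smaller the Gini index, the purer the data.

(2) Least squared residuals, for continuous targets

The idea is to minimize the within-group variance, which correspondingly maximizes the between-group variance. The two groups, that is, the left and right branches of the split, then differ as much as possible.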

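Using these impurity measures, first evaluate every split point or value grouping of each variable to find that variable's best one. Then compare the best splits across variables to find the best variable together with its best grouping or split point.

Tree growth is a recursive process that continues until a stopping condition is met.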

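Stopping conditions

(1) The node is pure, i.e., all of its records have the same target value.

(2) The tree depth has reached a preset maximum.

(3) The maximum impurity decrease is smaller than a preset value.

(4) The number of records in the node is smaller than a preset minimum node size.

(5) All records in the node have identical predictor values.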

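The intuitive case is to stop splitting when all records in a node belong to the same class, but that is only a special case. More generally, one computes a χ² statistic to measure how strongly the split condition is associated with the class. A very small χ² means the split condition and the class are (nearly) independent, so splitting on that condition is unjustified and the node stops splitting. Here the "split condition" is the one chosen by the Gini criterion, i.e., the split with the smallest weighted Gini index after splitting (the largest Gini gain).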

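Under stopping condition (3), a branch stops splitting once the maximum impurity decrease falls below a preset value, and the tree model is complete when every branch has stopped. You can also skip this condition and grow the tree to its maximum size, because CART has a pruning step.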

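Tree-building process

The misclassification costs and the prior probabilities are parameters that must be set in advance. Node labeling also has to account for unbalanced data. For example, if the training set contains 100 positive samples and only 1 negative sample, simply labeling each node with its majority class is not appropriate. Labeling each node relative to the class priors is more robust to such imbalance. The implementation below (labelNode) assigns the class i that maximizes p(i|t)/π(i), where p(i|t) is the share of class i at node t and π(i) is its share at the root.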

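Note that every node must be labeled, whether or not it is a leaf, because pruning may later turn a non-leaf node into a leaf.

After the tree is grown it is pruned, and pruning is essential. Its purpose is to keep the tree from overfitting the training samples. On typical datasets, an overfitted tree has a higher error rate on unseen data than a simplified tree does.

Pruning algorithm: CCP (Cost-Complexity Pruning)

This part is based on http://blog.csdn.net/u010159842/article/details/46458973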

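Cost-Complexity Pruning (CCP)
CCP has two steps:
1: Starting from the original tree T0, generate a sequence of subtrees {T0, T1, T2, …, Tn}, where T(i+1) is derived from Ti and Tn is the tree consisting of the root alone.
2: From this sequence, select the best tree according to an estimate of each tree's true error.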

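For each non-leaf node t of the tree, compute its surface error-rate gain α:

$$\alpha = \frac{R(t) - R(T_t)}{|N(T_t)| - 1}, \qquad R(t) = r(t)\,p(t)$$

$|N(T_t)|$ is the number of leaves in the subtree $T_t$ rooted at t;

$R(t)$ is the error cost of node t if it is pruned (t becomes a leaf);

$r(t)$ is the error rate of node t;

$p(t)$ is the fraction of all data that falls into node t;

$R(T_t)$ is the error cost of the subtree $T_t$ if t is not pruned. It equals the sum of the error costs of all leaves of $T_t$.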

For example, consider the non-leaf node t4 shown in the figure:

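There are 60 records in total, so the node error cost of t4 is $R(t4) = r(t4)\,p(t4)$, where $p(t4)$ is the number of records at t4 divided by 60.

The subtree error cost $R(T_{t4})$ is the sum of the error costs of the leaves of the subtree rooted at t4.

The subtree rooted at t4 has 3 leaves, so $\alpha(t4) = (R(t4) - R(T_{t4})) / (3 - 1)$.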

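Find the non-leaf node with the smallest α and set its left and right children to NULL. When several non-leaf nodes share the smallest α, prune the one with the largest subtree (the most leaves).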

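Pruning plays a central role in producing the optimal tree. Some studies indicate that pruning matters more than tree growing. Maximum trees grown with different splitting criteria all keep the most important splits after pruning and end up differing little, whereas the choice of pruning method has a larger effect on the final tree.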

Here is another example:

That should make it clear.

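A figure illustrates the selection:

The trees represented by the points below the horizontal line between 29 and 30 all satisfy the 1-SE rule. Their validation error is within one standard error of the minimum: $R(T) \le R_{min} + SE$, with $SE = \sqrt{R_{min}(1 - R_{min})/N}$ for $N$ validation records.

The tree indicated by the arrow has the fewest leaves among them, so it is chosen as the best tree.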

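Handling missing values

Some samples may lack values for certain attributes. The usual remedy is to fill in the attribute's most common value or its mean. A better approach is to give each possible value a probability, i.e., to fill in the attribute probabilistically. For example, for a Boolean attribute B, suppose the data contain 12 instances with B=0 and 88 with B=1. A missing B is then assigned B(0)=0.12 and B(1)=0.88, so a sample with a missing B goes down the False branch with probability 12% and down the True branch with probability 88%. This lets samples with missing values still take part in computing the split criterion (e.g., information gain).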

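Final prediction

(1) Classification tree: the most probable class in the leaf
(2) Regression tree: the mean or median of the target values in the leaf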

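Advantages

1) Very flexible: it allows unequal misclassification costs and a user-specified prior distribution, and it can use automatic cost-complexity pruning to obtain trees that generalize better.
2) CART is robust to problems such as missing values and large numbers of variables.

I have long wanted to read Leo Breiman's original book, Classification and Regression Trees, but have not been able to find it. If anyone has a copy, please share it.

Below is my implementation. The regression part is not written yet; I will update this post when it is.

The dataset is the UCI Adult dataset, which you can find by searching for it.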

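// cart.cpp : Defines the entry point for the console application.
//

#include <cassert>
#include <cmath>//sqrt, log, pow
//_ASSERTE comes from MSVC's <crtdbg.h>; fall back to the standard assert elsewhere
#ifndef _ASSERTE
#define _ASSERTE(expr) assert(expr)
#endif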
#include<vector>
#include<set>
#include<algorithm>
#include<iostream>
#include<iterator>
#include<fstream>
#include<string>
#include<map>
/*******************************************/
/************author Marshall****************/
/**********date 2015.10.3*******************/
/**************version 1.0******************/
/************copyright reserved*************/
/*******************************************/
using namespace std;



class cart
{
private:
    vector<int>nums_of_value_each_discreteAttri;
    int num_of_continuousAttri;
    int ContinuousAttriNums;
    int labelNums;//how many kinds of label
    unsigned int CL_max_height;
    //double miniumginigain;//not need,we have prune method

    //define the record
    class Record
    {
    public:
        vector<int>discrete_attri;//for each discrete attribute, its value is 0,1,2,... (consecutive integers)
        vector<double>continuous_attti;
        int label;//0,1,2...
    };

    //define the node
    struct CartNode
    {
        vector<int>remianDiscreteAttriID;
        int selectedAttriID;
        vector<int>selectedDiscreteAttriValues;
        bool isSelectedAttriIDDiscrete;
        double continuousAttriPartitionValue;//
        int label;//label assigned to records that fall into this node
        int height;//current node's height
        vector<int>labelcount;//a counter for the records' label that current node holds
        double alpha;//for nonleaf,for prune
        int record_number;//number of records covered by this node
        CartNode*lnode, *rnode;
        CartNode()
        {
            label = -1;
            selectedAttriID = -1;
            isSelectedAttriIDDiscrete = true;
            lnode = rnode = NULL;
            record_number = 0;
        }
    };
    CartNode*root;


    //double threshold;

private:
    //calculate gini index,for classify
    double calGiniIndex(vector<int>&subdatasetbyID, const vector<Record>*dataset, CartNode*node = NULL);
    double calSquaredresiduals();//calculate squaredresiduals,for regression
    void CL_split_dataset();
    void RE_split_dataset();
    void CL_trim(const vector<Record>*validationdataset);
    void RE_trim();
    //void make_discrete();
    //returns the label shared by every record in the subset, or -1 if the labels differ
    int allthesame(vector<int>&subdatasetbyID, const vector<Record>*dataset);
    /*If an attribute has 3 values there are 3 binary groupings; 4 values give 7 and 5 values give 15*/
    vector<pair<vector<int>, vector<int>>>make_two_heap(const int kk);
    pair<vector<int>, vector<int>>split_dataset(const int&selectedDiscreteAttriID,
        vector<int>&selected, const vector<int>&subdatasetbyID, const vector<Record>*dataset);
    pair<vector<int>, vector<int>>split_dataset(const int&selectedContiuousAttriID,
        const double partition, const vector<int>&subdatasetbyID, const vector<Record>*dataset);
    CartNode* copytree(CartNode*src, CartNode*dst);//deep copy of a tree; dst must be NULL
    void copynode(CartNode*src, CartNode*dst);
    void cal_alpha(CartNode*node);
    vector<CartNode*>getLeaf(CartNode*node);
    void destroyTree(CartNode*node);
    int labelNode(CartNode*node);
    void create_root();
    void build_tree_classify(vector<int>&subdatasetbyID,
        CartNode*node, const vector<Record>*dataset);
    void build_tree_regression();
public:
    void load_adult_dataset();
    int CART_classify(const Record dataset, CartNode*root = NULL);
    void CART_regression();
    void CART_trian(const vector<Record>*dataset, const vector<Record>*validationdataset);
    void CART_trian()
    {
        CART_trian(traindataset, validatedataset);
    }
    void set_paras();
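    //members must start as NULL: create_root() and the destructor test them
    cart() : root(NULL), traindataset(NULL), validatedataset(NULL), testdataset(NULL)
    {
    }
    ~cart()
    {
        if (root != NULL)
            destroyTree(root);
        if (traindataset != NULL)
            delete traindataset;
        if (validatedataset != NULL)
            delete validatedataset;
        if (testdataset != NULL)
            delete testdataset;
    }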
    vector<Record>*traindataset;//as it's name
    vector<Record>*validatedataset;
    vector<Record>*testdataset;
    void test(CartNode*node);
    void test();
};
void cart::test(CartNode*node)
{
    int errorNum = 0;
    for (int j = 0; j < testdataset->size(); j++)
    {
        errorNum += CART_classify((*testdataset)[j], node) == (*testdataset)[j].label ? 0 : 1;
    }
    cout << "Error rate on the test set: " << double(errorNum) / testdataset->size() << endl;

}

void cart::test()
{
    test(this->root);

}



void cart::set_paras()
{
    CL_max_height = 6;


}
void cart::CART_trian(const vector<Record>*dataset, const vector<Record>*validationdataset)
{
    create_root();
    set_paras();
    vector<int>subset;
    for (int i = 0; i < dataset->size(); i++)
        subset.push_back(i);
    build_tree_classify(subset, root, dataset);
    CL_trim(validationdataset);
}


void cart::destroyTree(CartNode*treeroot)
{
    if (treeroot == NULL)
        return;
    vector<CartNode*>pool, que;
    que.push_back(treeroot);
    while (!que.empty())
    {
        CartNode*node = que.back();
        que.pop_back();
        pool.push_back(node);
        if (node->lnode != NULL)
        {
            _ASSERTE(node->rnode != NULL);
            //push the children onto the work queue so that the whole subtree is visited
            que.push_back(node->lnode);
            que.push_back(node->rnode);
        }
    }
    for (int i = 0; i < pool.size(); i++)
        delete pool[i];
}

void cart::copynode(CartNode*src, CartNode*dst)
{
    _ASSERTE(dst != NULL);
    _ASSERTE(src != NULL);
    dst->alpha = src->alpha;
    dst->continuousAttriPartitionValue = src->continuousAttriPartitionValue;
    dst->height = src->height;
    dst->isSelectedAttriIDDiscrete = src->isSelectedAttriIDDiscrete;
    dst->label = src->label;
    dst->labelcount = src->labelcount;
    dst->record_number = src->record_number;
    dst->remianDiscreteAttriID = src->remianDiscreteAttriID;
    dst->selectedAttriID = src->selectedAttriID;
    dst->selectedDiscreteAttriValues = src->selectedDiscreteAttriValues;

}

//implementation of tree copy
cart::CartNode* cart::copytree(CartNode*Srctreeroot, CartNode*Dsttreeroot)
{
    _ASSERTE(Dsttreeroot == NULL);
    _ASSERTE(Srctreeroot != NULL);

    vector<CartNode*>pool, parentpool;
    Dsttreeroot = new CartNode;
    copynode(Srctreeroot, Dsttreeroot);
    if (Srctreeroot->lnode == NULL)
    {
        _ASSERTE(Srctreeroot->rnode == NULL);
        return Dsttreeroot;
    }
    pool.push_back(Srctreeroot->lnode);
    pool.push_back(Srctreeroot->rnode);
    parentpool.push_back(Dsttreeroot);

    bool lnodeflag = false;
    while (!pool.empty())
    {
        CartNode*node = pool.back();
        pool.pop_back();
        CartNode*newnode = new CartNode;
        copynode(node, newnode);
        if (!lnodeflag)
            parentpool.back()->rnode = newnode;
        else
            parentpool.back()->lnode = newnode;
        if (node->lnode != NULL)
        {
            _ASSERTE(node->rnode != NULL);
            if (lnodeflag)
                parentpool.pop_back();
            lnodeflag = false;
            pool.push_back(node->lnode);
            pool.push_back(node->rnode);

            parentpool.push_back(newnode);
        }
        else
        {
            if (lnodeflag)
                parentpool.pop_back();
            else
                lnodeflag = !lnodeflag;
        }
    }
    _ASSERTE(parentpool.empty());
    _ASSERTE(Dsttreeroot);
    return Dsttreeroot;
}

int cart::CART_classify(const Record rd, CartNode*treeroot)
{
    if (treeroot == NULL)
        treeroot = this->root;
    CartNode*node = treeroot;
    while (true)
    {
        if (node->lnode == NULL)
        {
            _ASSERTE(node->rnode == NULL);
            return node->label;
        }
        if (node->isSelectedAttriIDDiscrete)
        {
            if (find(node->selectedDiscreteAttriValues.begin(),
                node->selectedDiscreteAttriValues.end(),
                rd.discrete_attri[node->selectedAttriID])
                == node->selectedDiscreteAttriValues.end())
            {
                node = node->rnode;
            }
            else
            {
                node = node->lnode;
            }
        }
        else
        {
            if (rd.continuous_attti[node->selectedAttriID] >= node->continuousAttriPartitionValue)
            {
                node = node->rnode;
            }
            else
            {
                node = node->lnode;
            }
        }
    }
    //should not run here
    _ASSERTE(false);
}


void cart::CL_trim(const vector<Record>*validationdataset)
{
    vector<CartNode*>candidateBestTree;
    CartNode*curretroot = root;
    //T0, the unpruned tree, is itself a candidate
    candidateBestTree.push_back(curretroot);
    while (curretroot->lnode != NULL)//&&root->rnode!=NULL
    {
        //prune a deep copy so that every tree T0,T1,...,Tn of the sequence is kept
        CartNode*nextroot = NULL;
        nextroot = copytree(curretroot, nextroot);
        _ASSERTE(nextroot);
        curretroot = nextroot;

        //find the non-leaf node with the smallest alpha (the weakest link)
        vector<CartNode*>pool;
        pool.push_back(curretroot);
        double min_alpha = 10000000;
        CartNode*tobecut = NULL;
        while (!pool.empty())
        {
            CartNode*node = pool.back();
            pool.pop_back();
            if (node->lnode != NULL)
            {
                _ASSERTE(node->rnode != NULL);
                cal_alpha(node);
                if (node->alpha < min_alpha)
                {
                    min_alpha = node->alpha;
                    tobecut = node;
                }
                pool.push_back(node->rnode);
                pool.push_back(node->lnode);
            }
        }
        _ASSERTE(tobecut != NULL);
        //delete all descendants of tobecut so that it becomes a leaf
        vector<CartNode*>alltodel, temppool;
        temppool.push_back(tobecut);
        while (!temppool.empty())
        {
            CartNode*nn = temppool.back();
            temppool.pop_back();
            alltodel.push_back(nn);
            if (nn->lnode != NULL)
            {
                _ASSERTE(nn->rnode != NULL);
                temppool.push_back(nn->lnode);
                temppool.push_back(nn->rnode);
            }
        }
        alltodel.erase(find(alltodel.begin(), alltodel.end(), tobecut));
        for (int i = 0; i < alltodel.size(); i++)
            delete alltodel[i];
        tobecut->lnode = tobecut->rnode = NULL;

        candidateBestTree.push_back(curretroot);
    }

    //get the best tree
    int minError = validationdataset->size() + 1;//larger than any possible error count, so th is always set
    CartNode*besttree = NULL;
    int th = -1;
    vector<int>candidateBestTreeErrorNums;
    for (int i = 0; i < candidateBestTree.size(); i++)
    {
        int errorNum = 0;
        for (int j = 0; j < validationdataset->size(); j++)
        {
            errorNum += CART_classify((*validationdataset)[j],
                candidateBestTree[i]) == (*validationdataset)[j].label ? 0 : 1;
        }
        //error /= (*validationdataset).size();
        candidateBestTreeErrorNums.push_back(errorNum);
        if (errorNum < minError)
        {
            minError = errorNum;
            th = i;
        }
    }

    test(candidateBestTree[th]);

    double SE = sqrt(double(minError*(validationdataset->size() - minError)) / validationdataset->size());
    for (int i = candidateBestTree.size() - 1; i >= 0; i--)
    {
        if (candidateBestTreeErrorNums[i] <= minError + SE)
        {
            besttree = candidateBestTree[i];
            th = i;
            break;
        }
    }
    candidateBestTree.erase(candidateBestTree.begin() + th);
    for (int i = 0; i < candidateBestTree.size(); i++)
        destroyTree(candidateBestTree[i]);
    _ASSERTE(besttree != NULL);
    root = besttree;
    cout << "Error rate on the validation set after pruning: " << (double)candidateBestTreeErrorNums[th] / validationdataset->size() << endl;
}


void cart::cal_alpha(CartNode*node)
{
    _ASSERTE(node->lnode != NULL&&node->rnode != NULL);
    //R(t): misclassified records if t became a leaf with its own label,
    //as a fraction of all training records (r(t)*p(t))
    double Rt = double(node->record_number - node->labelcount[node->label]) / traindataset->size();
    double RTt = 0;
    vector<CartNode*>leafpool = getLeaf(node);
    for (int i = 0; i < leafpool.size(); i++)
    {
        RTt += double(leafpool[i]->record_number - leafpool[i]->labelcount[leafpool[i]->label]) /
            traindataset->size();
    }
    node->alpha = (Rt - RTt) / (leafpool.size() - 1);
}


vector<cart::CartNode*>cart::getLeaf(CartNode*node)
{
    vector<CartNode*>leafpool, que;
    que.push_back(node);
    while (!que.empty())
    {
        CartNode*nn = que.back();
        que.pop_back();
        if (nn->lnode != NULL)
            que.push_back(nn->lnode);
        else
        {
            _ASSERTE(nn->rnode == NULL);
            if (find(leafpool.begin(), leafpool.end(), nn) == leafpool.end())
                leafpool.push_back(nn);
        }

        if (nn->rnode != NULL)
            que.push_back(nn->rnode);
        else
        {
            _ASSERTE(nn->lnode == NULL);
            if (find(leafpool.begin(), leafpool.end(), nn) == leafpool.end())
                leafpool.push_back(nn);
        }
    }
    return leafpool;
}


pair<vector<int>, vector<int>>cart::split_dataset(const int&selectedDiscreteAttriID,
    vector<int>&selected, const vector<int>&subdatasetbyID, const vector<Record>*dataset)
{
    vector<int>aa, bb;
    for (int i = 0; i < subdatasetbyID.size(); i++)
    {
        if (find(selected.begin(), selected.end(), (*dataset)[subdatasetbyID[i]].
            discrete_attri[selectedDiscreteAttriID]) == selected.end())
        {
            bb.push_back(subdatasetbyID[i]);
        }
        else
            aa.push_back(subdatasetbyID[i]);
    }
    return pair<vector<int>, vector<int>>(aa, bb);
}

pair<vector<int>, vector<int>>cart::split_dataset(const int&selectedContiuousAttriID,
    const double partition, const vector<int>&subdatasetbyID, const vector<Record>*dataset)
{
    vector<int>aa, bb;
    for (int i = 0; i < subdatasetbyID.size(); i++)
    {
        if ((*dataset)[subdatasetbyID[i]].continuous_attti[selectedContiuousAttriID] >= partition)
        {
            bb.push_back(subdatasetbyID[i]);
        }
        else
            aa.push_back(subdatasetbyID[i]);
    }
    return pair<vector<int>, vector<int>>(aa, bb);

}
set<set<int>>solu;
void select(set<int>&selected, vector<int>&remain, int toselect)
{
    if (selected.size() == toselect)
    {
        if (solu.find(selected) == solu.end())
        {
            solu.insert(selected);
            //for (set<int>::iterator it = selected.begin(); it != selected.end(); it++)
            //  cout << *it << ",";
            //cout << endl;
        }
        return;
    }
    for (int i = 0; i < remain.size(); i++)
    {
        vector<int> re = remain;
        set<int>se = selected;
        se.insert(re[i]);
        re.erase(re.begin() + i);
        select(se, re, toselect);
    }
}
void Combination(vector<int>remain, int toselect)//combinations
{
    solu.clear();
    set<int>selected;
    select(selected, remain, toselect);
    //cout << "There are " << solu.size() << " combinations" << endl;
}

vector<pair<vector<int>, vector<int>>>cart::make_two_heap(const int kk)
{
    vector<pair<vector<int>, vector<int>>>toret;
    int len = nums_of_value_each_discreteAttri[kk];
    set<set<int>>re;
    vector<int>remain;
    for (int i = 0; i < len; i++)
        remain.push_back(i);
    for (int i = 1; i < len / 2 + 1; i++)
    {
        Combination(vector<int>(remain), i);
        re.insert(solu.begin(), solu.end());
    }
    for (set<set<int>>::iterator it = re.begin(); it != re.end(); it++)
    {
        vector<int>aa, bb;//bb(*it);
        //aa = remain \ *it, the complementary group of values
        set_difference(remain.begin(), remain.end(),
            it->begin(), it->end(), inserter(aa, aa.begin()));
        bb.insert(bb.begin(), it->begin(), it->end());

        toret.push_back(pair<vector<int>, vector<int>>(aa, bb));
    }
    return toret;
}

void cart::create_root()
{
    if (root == NULL)
    {
        root = new CartNode;
        for (int i = 0; i < nums_of_value_each_discreteAttri.size(); i++)
            root->remianDiscreteAttriID.push_back(i);
        root->height = 1;

    }
}

int cart::allthesame(vector<int>&subdatasetbyID, const vector<Record>*dataset)
{
    vector<int>count(labelNums);
    int label = ((*dataset)[subdatasetbyID[0]]).label;
    for (int i = 1; i < subdatasetbyID.size(); i++)
        if (((*dataset)[subdatasetbyID[i]]).label != label)
            return -1;
    return label;
}

//build classify tree recursively
void cart::build_tree_classify(vector<int>&subdatasetbyID,
    CartNode*node, const vector<Record>*dataset)
{
    node->record_number = subdatasetbyID.size();
    double basegini = calGiniIndex(subdatasetbyID, dataset, node);
    int currentlabel = allthesame(subdatasetbyID, dataset);
    if (currentlabel >= 0)
    {
        node->label = currentlabel;
        return;
    }
    if (node->height >= CL_max_height)
    {
        node->label = labelNode(node);
        return;
    }
    node->label = labelNode(node);
    double mingini = 10000000000;
    int selected = -1;
    bool isSelectedDiscrete = true;
    vector<int>selectedDiscreteAttriValues;
    pair<vector<int>, vector<int>>splited_subdataset;
    bool lnodeDecreaseDiscreteAttri = false;//is node's lnode's discrete attribute nums decrease
    bool rnodeDecreaseDiscreteAttri = false;



    //for discrete features,calculate giniindex
    for (int i = 0; i < node->remianDiscreteAttriID.size(); i++)
    {
        if (nums_of_value_each_discreteAttri[node->remianDiscreteAttriID[i]] > 2)
        {
            vector<pair<vector<int>, vector<int>>>bipart = make_two_heap(node->remianDiscreteAttriID[i]);
            for (int j = 0; j < bipart.size(); j++)
            {
                //j indexes the value grouping, i indexes the attribute
                pair<vector<int>, vector<int>>two_subdataset = split_dataset(
                    node->remianDiscreteAttriID[i], bipart[j].first, subdatasetbyID, dataset);
                if (two_subdataset.first.size() > 0 && two_subdataset.second.size() > 0)
                {
                    double gini1 = calGiniIndex(two_subdataset.first, dataset);
                    double gini2 = calGiniIndex(two_subdataset.second, dataset);
                    double gini = double(two_subdataset.first.size()) / subdatasetbyID.size()*gini1
                        + double(two_subdataset.second.size()) / subdatasetbyID.size()*gini2;
                    if (gini < mingini)
                    {
                        if (bipart[j].first.size() == 1)
                            lnodeDecreaseDiscreteAttri = true;
                        else
                            lnodeDecreaseDiscreteAttri = false;
                        if (bipart[j].second.size() == 1)
                            rnodeDecreaseDiscreteAttri = true;
                        else
                            rnodeDecreaseDiscreteAttri = false;
                        mingini = gini;
                        selected = node->remianDiscreteAttriID[i];
                        splited_subdataset = two_subdataset;
                        selectedDiscreteAttriValues = bipart[j].first;
                    }
                }
            }
        }
        else
        {
            vector<int>aa;
            aa.push_back(0);
            pair<vector<int>, vector<int>>two_subdataset = split_dataset(node->remianDiscreteAttriID[i],
                aa, subdatasetbyID, dataset);
            if (two_subdataset.first.size() > 0 && two_subdataset.second.size() > 0)
            {
                double gini1 = calGiniIndex(two_subdataset.first, dataset);
                double gini2 = calGiniIndex(two_subdataset.second, dataset);
                double gini = double(two_subdataset.first.size()) / subdatasetbyID.size()*gini1
                    + double(two_subdataset.second.size()) / subdatasetbyID.size()*gini2;
                if (gini < mingini)
                {
                    mingini = gini;
                    selected = node->remianDiscreteAttriID[i];
                    splited_subdataset = two_subdataset;
                    lnodeDecreaseDiscreteAttri = true;
                    rnodeDecreaseDiscreteAttri = true;
                    selectedDiscreteAttriValues.clear();
                    selectedDiscreteAttriValues.push_back(0);
                }
            }
        }
    }
    // function object for ascending/descending sort
    struct CompNameEx{
        CompNameEx(bool asce, int k, const vector<Record>*dataset) : asce_(asce), kk(k), dataset(dataset)
        {}
        bool operator()(int const& pl, int const& pr)
        {
            return asce_ ? (*dataset)[pl].continuous_attti[kk] < (*dataset)[pr].continuous_attti[kk]
                : (*dataset)[pr].continuous_attti[kk] < (*dataset)[pl].continuous_attti[kk];
            // Effective STL, Item 21: always have comparison functions return false for equal values
        }
    private:
        bool asce_;
        int kk;
        const vector<Record>*dataset;
    };

    //for continuous features,calculate giniindex
    double partitionpoint;
    for (int i = 0; i < ContinuousAttriNums; i++)
    {
        sort(subdatasetbyID.begin(), subdatasetbyID.end(),
            CompNameEx(true, i, dataset));
        for (int j = 0; j < subdatasetbyID.size() - 1; j++)
        {
            double partition = 0.5*(*dataset)[subdatasetbyID[j]].continuous_attti[i] +
                0.5*(*dataset)[subdatasetbyID[j + 1]].continuous_attti[i];
            pair<vector<int>, vector<int>>two_subdataset =
                split_dataset(i, partition, subdatasetbyID, dataset);
            if (two_subdataset.first.size() > 0 && two_subdataset.second.size() > 0)
            {
                double gini1 = calGiniIndex(two_subdataset.first, dataset);
                double gini2 = calGiniIndex(two_subdataset.second, dataset);
                double gini = double(two_subdataset.first.size()) / subdatasetbyID.size()*gini1
                    + double(two_subdataset.second.size()) / subdatasetbyID.size()*gini2 + log(double(subdatasetbyID.size() - 2) / dataset->size()) / log(2.0);
                if (gini < mingini)
                {
                    partitionpoint = partition;
                    mingini = gini;
                    selected = i;
                    isSelectedDiscrete = false;
                    splited_subdataset = two_subdataset;
                }
            }
        }
    }

    //we have prune,so regardless of ginigain
    //double ginigain = basegini - mingini;//if not greater than miniumginigain;current node should not grow 

    if (splited_subdataset.first.size() > 0 && splited_subdataset.second.size() > 0)//&&ginigain>miniumginigain)
    {
        CartNode*lchild = new CartNode;
        CartNode*rchild = new CartNode;
        node->lnode = lchild;
        node->rnode = rchild;
        lchild->height = node->height + 1;
        rchild->height = node->height + 1;
        lchild->remianDiscreteAttriID = node->remianDiscreteAttriID;
        rchild->remianDiscreteAttriID = node->remianDiscreteAttriID;
        node->selectedAttriID = selected;
        if (isSelectedDiscrete)
        {
            if (lnodeDecreaseDiscreteAttri)
            {
                lchild->remianDiscreteAttriID.erase(find(lchild->
                    remianDiscreteAttriID.begin(), lchild->remianDiscreteAttriID.end(), selected));
            }
            if (rnodeDecreaseDiscreteAttri)
            {
                rchild->remianDiscreteAttriID.erase(find(rchild->
                    remianDiscreteAttriID.begin(), rchild->remianDiscreteAttriID.end(), selected));
            }
            node->selectedDiscreteAttriValues = selectedDiscreteAttriValues;
        }
        else
        {
            node->isSelectedAttriIDDiscrete = false;
            node->continuousAttriPartitionValue = partitionpoint;
        }
        //recursively call  build_tree_classify()
        build_tree_classify(splited_subdataset.first, lchild, dataset);

        build_tree_classify(splited_subdataset.second, rchild, dataset);
    }
}


double cart::calGiniIndex(vector<int>&subdatasetbyID, const vector<Record>*dataset, CartNode*node)
{
    _ASSERTE(subdatasetbyID.size() > 0);
    _ASSERTE(dataset != NULL);
    vector<int>count;
    count.resize(labelNums);
    for (int i = 0; i < subdatasetbyID.size(); i++)
    {
        count[((*dataset)[subdatasetbyID[i]]).label]++;
    }
    if (node != NULL)
    {
        node->labelcount = count;
        node->record_number = subdatasetbyID.size();
    }
    vector<double> probalblity;
    probalblity.resize(labelNums);
    double re = 1;
    for (int i = 0; i < labelNums; i++)
    {
        probalblity[i] = double(count[i]) / subdatasetbyID.size();
        re -= pow(probalblity[i], 2);
    }
    _ASSERTE(re >= 0);
    return re;
}

int cart::labelNode(CartNode*node)
{
    int label = -1;
    double maxpro = 0;
    for (int i = 0; i < labelNums; i++)
    {
        double temppro = double(node->labelcount[i]) / node->record_number;
        temppro /= double(root->labelcount[i]) / root->record_number;
        if (temppro > maxpro)
        {
            maxpro = temppro;
            label = i;
        }
    }
    _ASSERTE(label >= 0);
    return label;
}






int split(const std::string& str, std::vector<std::string>& ret_, std::string sep = ",")
{
    if (str.empty())
    {
        return 0;
    }

    std::string tmp;
    std::string::size_type pos_begin = str.find_first_not_of(sep);
    std::string::size_type comma_pos = 0;

    while (pos_begin != std::string::npos)
    {
        comma_pos = str.find(sep, pos_begin);
        if (comma_pos != std::string::npos)
        {
            tmp = str.substr(pos_begin, comma_pos - pos_begin);
            pos_begin = comma_pos + sep.length();
        }
        else
        {
            tmp = str.substr(pos_begin);
            pos_begin = comma_pos;
        }

        if (!tmp.empty())
        {
            ret_.push_back(tmp);
            tmp.clear();
        }
    }
    return 0;
}





//Note: education, workclass, marital-status, occupation, and native-country have too many values and are not used
void cart::load_adult_dataset()
{
    vector<Record>*traindataset;//as it's name
    vector<Record>*validatedataset;
    string filename = "adult.data";
    ifstream infile(filename.c_str());
    string temp;
    cout << endl;
    int count = 0;
    //vector<vector<std::string>>ss;
    traindataset = new vector < Record > ;
    validatedataset = new vector < Record > ;
    this->traindataset = traindataset;
    this->validatedataset = validatedataset;
    testdataset = new vector < Record > ;
    //Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked

    /*map<string, int>workclass;
    workclass["Private"] = 0;
    workclass["Self-emp-not-inc"] = 1;
    workclass["Self-emp-inc"] = 2;
    workclass["Federal-gov"] = 3;
    workclass["Local-gov"] = 4;
    workclass["State-gov"] = 5;
    workclass["Without-pay"] = 6;
    workclass["Never-worked"] = 7;*/

    //education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th,
    // 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.

    /*map<string, int>education;
    education["Bachelors"] = 0;
    education["Some-college"] = 1;
    education["11th"] = 2;
    education["HS-grad"] = 3;
    education["Prof-school"] = 4;
    education["Assoc-acdm"] = 5;
    education["Assoc-voc"] = 6;
    education["9th"] = 7;
    education["7th-8th"] = 8;
    education["12th"] = 9;
    education["Masters"] = 10;
    education["1st-4th"] = 11;
    education["10th"] = 12;
    education["Doctorate"] = 13;
    education["5th-6th"] = 14;
    education["Preschool"] = 15;
    */
    //marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed,
    // Married-spouse-absent, Married-AF-spouse.
    /*map<string, int>marital_status;
    marital_status["Married-civ-spouse"] = 0;
    marital_status["Divorced"] = 1;
    marital_status["Never-married"] = 2;
    marital_status["Separated"] = 3;
    marital_status["Widowed"] = 4;
    marital_status["Married-spouse-absent"] = 5;
    marital_status["Married-AF-spouse"] = 6;*/

    //occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, 
    //Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing,
    // Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.
    /*map<string, int>occupation;
    occupation["Tech-support"] = 0;
    occupation["Craft-repair"] = 1;
    occupation["Other-service"] = 2;
    occupation["Sales"] = 3;
    occupation["Exec-managerial"] = 4;
    occupation["Prof-specialty"] = 5;
    occupation["Handlers-cleaners"] = 6;
    occupation["Machine-op-inspct"] = 7;
    occupation["Adm-clerical"] = 8;
    occupation["Farming-fishing"] = 9;
    occupation["Transport-moving"] = 10;
    occupation["Priv-house-serv"] = 11;
    occupation["Protective-serv"] = 12;
    occupation["Armed-Forces"] = 13;
    */

    //relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.

    map<string, int>relationship;
    relationship["Wife"] = 0;
    relationship["Own-child"] = 1;
    relationship["Husband"] = 2;
    relationship["Not-in-family"] = 3;
    relationship["Other-relative"] = 4;
    relationship["Unmarried"] = 5;

    //race: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.

    map<string, int>race;
    race["White"] = 0;
    race["Asian-Pac-Islander"] = 1;
    race["Amer-Indian-Eskimo"] = 2;
    race["Other"] = 3;
    race["Black"] = 4;

    //sex: Female, Male.
    map<string, int>sex;
    sex["Female"] = 0;
    sex["Male"] = 1;

    //native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, 
    //Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran,
    // Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, 
    //Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia,
    // Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, 
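    //Trinadad&Tobago, Peru, Hong, Holand-Netherlands.
    //(native-country is not used as an attribute in this version, so no code table is built for it.)

    //label (income): <=50K, >50K.
    map<string, int>label;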
    label["<=50K"] = 0;
    label[">50K"] = 1;


    while (getline(infile, temp) && count < 7000)
    {

        Record rd;
        rd.continuous_attti.resize(6);
        rd.discrete_attri.resize(3);
        //cout << temp << endl;

        std::vector<std::string>re;
        split(temp, re, std::string(", "));
        bool desert = false;  // set when any field has an unknown value; the record is then discarded
        if (re.size() == 15)
        {

            /*age: continuous.
            workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.
            fnlwgt: continuous.
            education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.
            education-num: continuous.
            marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.
            occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.
            relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.
            race: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.
            sex: Female, Male.
            capital-gain: continuous.
            capital-loss: continuous.
            hours-per-week: continuous.
            native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.*/


            //age continuous
            rd.continuous_attti[0] = atoi(re[0].c_str());

            //workclass discrete
            /*if (workclass.find(re[1]) != workclass.end())
                rd.discrete_attri[0] = workclass[re[1]];
                else
                desert=true;*/

            //fnlwgt: continuous
            rd.continuous_attti[1] = atoi(re[2].c_str());

            //education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.
            /*if (education.find(re[3]) != education.end())
                rd.discrete_attri[1] = education[re[3]];
                else
                desert=true;*/

            //education-num: continuous.
            rd.continuous_attti[2] = atoi(re[4].c_str());

            //marital-status
            /*if (marital_status.find(re[5]) != marital_status.end())
                rd.discrete_attri[1] = marital_status[re[5]];
                else
                desert=true;*/

            //relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.
            if (relationship.find(re[7]) != relationship.end())
                rd.discrete_attri[0] = relationship[re[7]];
            else
                desert = true;

            //race
            if (race.find(re[8]) != race.end())
                rd.discrete_attri[1] = race[re[8]];
            else
                desert = true;

            //sex
            if (sex.find(re[9]) != sex.end())
                rd.discrete_attri[2] = sex[re[9]];
            else
                desert = true;

            //capital-gain: continuous.
            rd.continuous_attti[3] = atoi(re[10].c_str());

            //capital-loss: continuous.
            rd.continuous_attti[4] = atoi(re[11].c_str());

            //hours-per-week: continuous
            rd.continuous_attti[5] = atoi(re[12].c_str());

            if (label.find(re[14]) != label.end())
                rd.label = label[re[14]];
            else
                desert = true;
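            // Split by line index: lines 0-3499 train, 3500-4499 validation, 4500-6999 test.
            // Braces make the nesting explicit (the original relied on the dangling-else rule).
            if (!desert)
            {
                if (count < 3500)
                {
                    traindataset->push_back(rd);
                }
                else if (count < 4500)
                {
                    validatedataset->push_back(rd);
                }
                else
                {
                    testdataset->push_back(rd);
                }
            }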
        }
        count++;
    }
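    ContinuousAttriNums = 6;  // age, fnlwgt, education-num, capital-gain, capital-loss, hours-per-week
    labelNums = 2;            // <=50K, >50K
    // number of distinct values of each discrete attribute: relationship (6), race (5), sex (2)
    nums_of_value_each_discreteAttri.push_back(6);
    nums_of_value_each_discreteAttri.push_back(5);
    nums_of_value_each_discreteAttri.push_back(2);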


}

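int main()
{
    cart model;                  // object of the `cart` class defined above
    model.load_adult_dataset();  // read the adult data and split it into train/validation/test
    model.CART_trian();          // train; method name kept as declared in the class
    model.test();
    system("pause");             // Windows-only: keeps the console window open
    return 0;
}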
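The loader above treats every categorical field the same way: split the line on ", ", look each value up in a code table, and discard the whole record if any lookup fails. The sketch below isolates that pattern in standalone form so it can be tried on a single line. It is a minimal illustration under stated assumptions, not the post's code: `ParsedRecord`, `split_on`, `encode`, and `parse_adult_line` are hypothetical names, C++17 is assumed for `std::optional`, and the field positions (0, 2, 4, 10, 11, 12 for the continuous attributes; 7, 8, 9 for relationship, race, sex; 14 for the label) mirror the loop above.

```cpp
#include <cstdlib>
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for the post's Record: continuous values, discrete codes, label.
struct ParsedRecord {
    std::vector<int> continuous;  // age, fnlwgt, education-num, capital-gain, capital-loss, hours-per-week
    std::vector<int> discrete;    // relationship, race, sex
    int label = -1;               // 0 for <=50K, 1 for >50K
};

// Split `line` on every occurrence of `sep`.
std::vector<std::string> split_on(const std::string& line, const std::string& sep) {
    std::vector<std::string> parts;
    std::size_t start = 0;
    std::size_t pos;
    while ((pos = line.find(sep, start)) != std::string::npos) {
        parts.push_back(line.substr(start, pos - start));
        start = pos + sep.size();
    }
    parts.push_back(line.substr(start));  // last field
    return parts;
}

// Look a categorical value up in its code table; std::nullopt means "unknown value".
std::optional<int> encode(const std::map<std::string, int>& table, const std::string& value) {
    auto it = table.find(value);
    if (it == table.end()) return std::nullopt;
    return it->second;
}

// Parse one line; returns std::nullopt when the record should be discarded.
std::optional<ParsedRecord> parse_adult_line(const std::string& line) {
    static const std::map<std::string, int> relationship = {
        {"Wife", 0}, {"Own-child", 1}, {"Husband", 2},
        {"Not-in-family", 3}, {"Other-relative", 4}, {"Unmarried", 5}};
    static const std::map<std::string, int> race = {
        {"White", 0}, {"Asian-Pac-Islander", 1}, {"Amer-Indian-Eskimo", 2},
        {"Other", 3}, {"Black", 4}};
    static const std::map<std::string, int> sex = {{"Female", 0}, {"Male", 1}};
    static const std::map<std::string, int> label = {{"<=50K", 0}, {">50K", 1}};

    const std::vector<std::string> f = split_on(line, ", ");
    if (f.size() != 15) return std::nullopt;  // malformed or empty line

    const auto rel = encode(relationship, f[7]);
    const auto rc = encode(race, f[8]);
    const auto sx = encode(sex, f[9]);
    const auto lb = encode(label, f[14]);
    if (!rel || !rc || !sx || !lb) return std::nullopt;  // unknown value: discard the record

    ParsedRecord rd;
    for (int idx : {0, 2, 4, 10, 11, 12})
        rd.continuous.push_back(std::atoi(f[idx].c_str()));  // atoi, as in the loader above
    rd.discrete = {*rel, *rc, *sx};
    rd.label = *lb;
    return rd;
}
```

Returning `std::optional` makes the discard decision explicit at the call site, instead of threading a flag such as `desert` through every field.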

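This implementation may not be complete: the overall framework is in place, but some details may not be handled well. Corrections and suggestions are welcome.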

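Open issues:

Prior probabilities and class balancing

Introducing misclassification costs

Support for weights: assigning different weight values to different samples

Dynamic feature construction

Cost-sensitive learning

Probability trees

Regression tree details

Model trees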
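The first two open issues, prior probabilities and misclassification costs, both enter CART at the point where a node is given its label, which the tree-building section notes matters for unbalanced data such as 100 positive samples versus 1 negative. The sketch below follows the standard CART node-assignment rule: with priors π(j), the within-node class probability is p(j|t) ∝ π(j)·N_j(t)/N_j, and the node takes the label i that minimizes Σ_j C(i|j)·p(j|t). It is a minimal sketch, assuming class counts arrive as parallel vectors and the cost matrix is indexed `cost[predicted][true]`; `node_class_probs` and `cost_sensitive_label` are hypothetical helpers, not members of the post's `cart` class.

```cpp
#include <cstddef>
#include <limits>
#include <vector>

// p(j|t) proportional to prior[j] * N_j(t) / N_j, where node_counts[j] = N_j(t) is the
// number of class-j samples in node t and total_counts[j] = N_j is the number in the
// whole training set. The result is normalized to sum to 1.
std::vector<double> node_class_probs(const std::vector<int>& node_counts,
                                     const std::vector<int>& total_counts,
                                     const std::vector<double>& prior) {
    std::vector<double> p(node_counts.size(), 0.0);
    double sum = 0.0;
    for (std::size_t j = 0; j < p.size(); ++j) {
        if (total_counts[j] > 0)
            p[j] = prior[j] * node_counts[j] / total_counts[j];
        sum += p[j];
    }
    if (sum > 0.0)
        for (double& v : p) v /= sum;
    return p;
}

// Pick the label i minimizing the expected cost sum_j cost[i][j] * p(j|t),
// where cost[i][j] is the cost of predicting class i when the true class is j.
int cost_sensitive_label(const std::vector<double>& p,
                         const std::vector<std::vector<double>>& cost) {
    int best = 0;
    double best_cost = std::numeric_limits<double>::infinity();
    for (std::size_t i = 0; i < cost.size(); ++i) {
        double c = 0.0;
        for (std::size_t j = 0; j < p.size(); ++j) c += cost[i][j] * p[j];
        if (c < best_cost) {
            best_cost = c;
            best = static_cast<int>(i);
        }
    }
    return best;
}
```

With equal priors and unit costs, a node holding 10 of the 100 positives (class 0) and the single negative (class 1) is labeled negative, because it contains all of the negatives but only 10% of the positives. Setting the priors to the training frequencies (100/101, 1/101) reproduces plain majority voting, and raising the cost of missing a negative can move the label back to the minority class.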