利用回归树DecisionTreeRegressor进行回归预测(复习10)
阅读原文时间:2021年04月21日阅读:1

本文是个人学习笔记,内容主要基于回归树DecisionTreeRegressor对boston数据集学习回归模型和利用模型预测。

树模型可以解决非线性特征的问题,树模型不要求对特征标准化和统一量化(即数值型和类目型特征都可以直接被用到树模型的构建和预测过程),树模型可以直观地输出决策过程,使得预测结果具有可解释性。
使用树模型时要防止过拟合,对数据噪声的敏感度较高(预测稳定性较差),有训练数据构建最佳的树模型是NP难问题,因此实际操作时使用的类似贪婪算法的解法只能找到一些次优解。

回归树叶节点的数据类型是连续的,而分类树叶节点的数据类型是离散的。
回归树叶节点是一个个具体的值,而分类树叶节点是依据训练样本类别确定的预测类别。
回归树的叶节点返回的是“一团”训练数据的均值,而不是具体的、连续的预测值。

from sklearn.datasets import load_boston
boston=load_boston()
print(boston.DESCR)   #打印数据描述


print(boston.feature_names)

#Output:['CRIM' 'ZN' 'INDUS' 'CHAS' 'NOX' 'RM' 'AGE' 'DIS' 'RAD' 'TAX' 'PTRATIO' 'B' 'LSTAT']

from sklearn.cross_validation import train_test_split
import numpy as np
X=boston.data
y=boston.target
X_train,X_test,y_train,y_test=train_test_split(X,y,random_state=33,test_size=0.25)

print('The max target value is',np.max(boston.target))
print('The min target value is',np.min(boston.target))
print('The average target value is',np.mean(boston.target))

from sklearn.tree import DecisionTreeRegressor

dtr=DecisionTreeRegressor()
dtr.fit(X_train,y_train)
dtr_y_predict=dtr.predict(X_test)

from sklearn.metrics import r2_score,mean_absolute_error,mean_squared_error

print('R-squared value of DecisionTreeRegressor:',dtr.score(X_test,y_test))
print('The mean squared error of DecisionTreeRegressor:',mean_squared_error(y_test,dtr_y_predict))
print('The mean absolute error of DecisionTreeRegressor:',mean_absolute_error(y_test,dtr_y_predict))

import sys
import os
os.environ["PATH"] += os.pathsep + 'D:\PYTHON35\Anaconda3.4.2\Lib\site-packages\graphviz-2.38\bin'
#'D:\PYTHON35\Anaconda3.4.2\Lib\site-packages\graphviz-2.38\bin'是解压缩graphviz-2.38.zip包后bin文件夹所在位置

%matplotlib inline
import numpy as np
from IPython.display import Image  
from sklearn import tree
import pydotplus 
import graphviz
dot_data = tree.export_graphviz(dtr, out_file=None, 
                         feature_names=boston.feature_names,  
                         class_names=['0','1'], 
                         filled=True, rounded=True,  
                         special_characters=True)   #feature_names格式是np.array
graph = pydotplus.graph_from_dot_data(dot_data)  
Image(graph.create_png())

手机扫一扫

移动阅读更方便

阿里云服务器
腾讯云服务器
七牛云服务器

你可能感兴趣的文章