In "Machine Learning Study Notes (18) -- The CART (Classification And Regression Tree) Algorithm" we derived the feature-splitting rule for a CART classification tree. In this post we implement a CART classification tree in Python (cartctree.py), adapted from 《Python机器学习算法:原理,实现与案例》:
import numpy as np


class CartClassificationTree:
    class Node:
        def __init__(self):
            self.value = None
            # Attributes used by internal (non-leaf) nodes
            self.feature_index = None
            self.feature_value = None
            self.left = None
            self.right = None

    def __init__(self, gini_threshold=0.01, gini_dec_threshold=0.,
                 min_samples_split=2):
        # Threshold on the Gini index below which a node is not split further
        self.gini_threshold = gini_threshold
        # Minimum decrease of the Gini index required to accept a split
        self.gini_dec_threshold = gini_dec_threshold
        # Minimum number of samples a dataset must have to be split further
        self.min_samples_split = min_samples_split

    def _gini(self, y):
        # Compute the Gini index of a label array
        values = np.unique(y)
        s = 0.
        for v in values:
            y_sub = y[y == v]
            s += (y_sub.size / y.size) ** 2
        return 1 - s

    def _gini_split(self, y, feature, value):
        # Compute the Gini index after splitting on a feature at the given value
        # Split the dataset into two subsets according to the feature value
        indices = feature > value
        y1 = y[indices]
        y2 = y[~indices]
        # Gini index of each subset
        gini1 = self._gini(y1)
        gini2 = self._gini(y2)
        # Weighted Gini index after the split
        gini = (y1.size * gini1 + y2.size * gini2) / y.size
        return gini

    def _get_split_points(self, feature):
        # Get all candidate split points of a continuous feature
        # Collect all distinct values of the feature, sorted
        values = np.unique(feature)
        # Split points are the midpoints of adjacent values
        split_points = [(v1 + v2) / 2 for v1, v2 in zip(values[:-1], values[1:])]
        return split_points

    def _select_feature(self, X, y):
        # Select the best split feature and split point
        # Index of the best split feature
        best_feature_index = None
        # Best split point
        best_split_value = None
        min_gini = np.inf
        _, n = X.shape
        for feature_index in range(n):
            # Iterate over every feature
            feature = X[:, feature_index]
            # All candidate split points of this feature
            split_points = self._get_split_points(feature)
            for value in split_points:
                # For every split point, compute the Gini index after splitting at it
                gini = self._gini_split(y, feature, value)
                # Keep the split with the smallest Gini index
                if gini < min_gini:
                    min_gini = gini
                    best_feature_index = feature_index
                    best_split_value = value
        # Check whether the decrease of the Gini index exceeds the threshold
        if self._gini(y) - min_gini < self.gini_dec_threshold:
            best_feature_index = None
            best_split_value = None
        return best_feature_index, best_split_value, min_gini

    def _node_value(self, y):
        # Compute the value of a node
        # Count the occurrences of each class label in the dataset
        labels_count = np.bincount(y)
        # The node value is always the most frequent class label
        return np.argmax(labels_count)

    def _build_tree(self, X, y):
        # Decision tree construction (recursive)
        # Create a node
        node = CartClassificationTree.Node()
        # Compute the node value
        node.value = self._node_value(y)
        # If the dataset is smaller than min_samples_split, return a leaf node
        if y.size < self.min_samples_split:
            return node
        # If the Gini index of the dataset is below gini_threshold, return a leaf node
        if self._gini(y) < self.gini_threshold:
            return node
        # Select the best split feature
        feature_index, feature_value, min_gini = self._select_feature(X, y)
        if feature_index is not None:
            # A suitable split feature exists, so this becomes an internal node
            node.feature_index = feature_index
            node.feature_value = feature_value
            # Split the dataset into two subsets using the selected feature and split point
            feature = X[:, feature_index]
            indices = feature > feature_value
            X1, y1 = X[indices], y[indices]
            X2, y2 = X[~indices], y[~indices]
            # Build the left and right subtrees from the two subsets
            node.left = self._build_tree(X1, y1)
            node.right = self._build_tree(X2, y2)
        return node

    def _predict_one(self, x):
        # Walk down the tree to predict a single instance
        node = self.tree_
        while node.left:
            if x[node.feature_index] > node.feature_value:
                node = node.left
            else:
                node = node.right
        return node.value

    def train(self, X_train, y_train):
        # Train: build the tree and store its root
        self.tree_ = self._build_tree(X_train, y_train)

    def predict(self, X):
        # Apply _predict_one to every instance and collect the results into an array
        return np.apply_along_axis(self._predict_one, axis=1, arr=X)
This implementation works for features that can be represented numerically. Using the method from the previous article for splitting on a continuous feature, it enumerates candidate thresholds for each feature, picks the threshold that minimizes the Gini index, and thereby also picks the best feature column to split on. The tree is then built by recursively constructing left and right subtrees until the stopping conditions are met.
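To make the split search concrete, here is a minimal sketch (the toy feature and label arrays below are invented purely for illustration) of how _get_split_points enumerates candidate thresholds and how _gini_split scores them:

import numpy as np
from cartctree import CartClassificationTree

t = CartClassificationTree()
feature = np.array([1.0, 2.0, 2.0, 3.0, 4.0])   # one continuous feature
y = np.array([0, 0, 0, 1, 1])                   # class labels

# Candidate thresholds are the midpoints of adjacent unique values: 1.5, 2.5, 3.5
print(t._get_split_points(feature))

# Weighted Gini index after splitting at each candidate; 2.5 separates the two
# classes perfectly, so it scores 0.0 and would be chosen by _select_feature.
print(t._gini_split(y, feature, 1.5))   # 0.4
print(t._gini_split(y, feature, 2.5))   # 0.0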
Next, let's verify the implementation on the iris dataset:
Download the iris.data and iris.names files. The iris.names file describes the dataset's attributes:
| Column | Name | Feature / class label | Possible values |
| ------ | ---- | --------------------- | --------------- |
| 1 | sepal length | feature | continuous real values |
| 2 | sepal width | feature | continuous real values |
| 3 | petal length | feature | continuous real values |
| 4 | petal width | feature | continuous real values |
| 5 | class | class label | Iris-setosa, Iris-versicolor, Iris-virginica |
The dataset contains 150 records. Load it and train the model with the following code:
>>> import numpy as np
>>> X = np.genfromtxt('iris.data', delimiter=',', usecols=range(4), dtype=float)
>>> y = np.genfromtxt('iris.data', delimiter=',', usecols=4, dtype=str)
>>> y
array(['Iris-setosa', 'Iris-setosa', 'Iris-setosa', 'Iris-setosa',......'Iris-virginica', 'Iris-virginica', 'Iris-virginica','Iris-virginica', 'Iris-virginica'], dtype='<U15')
>>> from sklearn.preprocessing import LabelEncoder
>>> le = LabelEncoder()
>>> y = le.fit_transform(y)
>>> y
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], dtype=int32)
>>> from cartctree import CartClassificationTree
>>> cct = CartClassificationTree()
>>> from sklearn.model_selection import train_test_split
>>> X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.3)
>>> cct.train(X_train, y_train)
Now evaluate the trained model on the test set:
>>> from sklearn.metrics import accuracy_score
>>> y_predict = cct.predict(X_test)
>>> y_predict
array([2, 2, 1, 1, 2, 2, 1, 1, 2, 2, 1, 2, 2, 1, 2, 1, 2, 2, 1, 2, 1, 0,2, 2, 0, 2, 2, 0, 1, 1, 0, 1, 2, 2, 0, 2, 2, 1, 0, 1, 0, 0, 2, 0,1], dtype=int32)
>>> y_test
array([2, 2, 1, 1, 2, 2, 1, 1, 2, 2, 2, 2, 2, 1, 2, 1, 2, 2, 1, 2, 1, 0,2, 2, 0, 2, 2, 0, 1, 1, 0, 1, 1, 2, 0, 2, 2, 1, 0, 1, 0, 0, 2, 0,1], dtype=int32)
>>> accuracy_score(y_test, y_predict)
0.9555555555555556
The accuracy exceeds 95% (the exact value varies from run to run, because train_test_split shuffles the data randomly).
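As an optional sanity check, scikit-learn's DecisionTreeClassifier (which also implements CART with the Gini criterion) can be trained on the same split; one would expect a comparable accuracy on this small dataset:

>>> from sklearn.tree import DecisionTreeClassifier
>>> sk_tree = DecisionTreeClassifier(criterion='gini')
>>> sk_tree = sk_tree.fit(X_train, y_train)
>>> accuracy_score(y_test, sk_tree.predict(X_test))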
The trained decision tree is still not very easy to interpret, so let's draw it. In "Machine Learning Study Notes (16) -- Drawing Decision Trees with Matplotlib" we plotted a decision tree, but the Node structure used there differs from the Node of this CART classification tree, so that code cannot be reused directly and needs to be adapted (treeplotter2.py):
import matplotlib.pyplot as plt
from cartctree import CartClassificationTree


class TreePlotter2:
    def __init__(self, tree, feature_names, label_names):
        self.decision_node = dict(boxstyle="sawtooth", fc="0.8")
        self.leaf_node = dict(boxstyle="round4", fc="0.8")
        self.arrow_args = dict(arrowstyle="<-")
        # Store the decision tree (root node)
        self.tree_ = tree
        # Store the feature name dictionary
        self.feature_names = feature_names
        # Store the class label name dictionary
        self.label_names = label_names
        self.totalW = self.totalD = None
        self.xOff = None
        self.yOff = None

    def _get_num_leafs(self, node):
        '''Count the leaf nodes of a subtree'''
        if not node.left:
            return 1
        num_leafs = 0
        num_leafs += self._get_num_leafs(node.left)
        num_leafs += self._get_num_leafs(node.right)
        return num_leafs

    def _get_tree_depth(self, node):
        '''Compute the depth of a subtree'''
        if not node.left:
            return 1
        this_depth1 = 1 + self._get_tree_depth(node.left)
        this_depth2 = 1 + self._get_tree_depth(node.right)
        if this_depth1 > this_depth2:
            max_depth = this_depth1
        else:
            max_depth = this_depth2
        return max_depth

    def _plot_mid_text(self, cntrpt, parentpt, txtstring, ax1):
        '''Write text halfway between a node and its parent'''
        x_mid = (parentpt[0] - cntrpt[0]) / 2.0 + cntrpt[0]
        y_mid = (parentpt[1] - cntrpt[1]) / 2.0 + cntrpt[1]
        ax1.text(x_mid, y_mid, txtstring)

    def _plot_node(self, nodetxt, centerpt, parentpt, nodetype, ax1):
        ax1.annotate(nodetxt, xy=parentpt, xycoords='axes fraction',
                     xytext=centerpt, textcoords='axes fraction',
                     va="center", ha="center", bbox=nodetype,
                     arrowprops=self.arrow_args)

    def _plot_tree(self, tree, parentpt, nodetxt, ax1):
        # Number of leaf nodes of this subtree (its total width)
        num_leafs = self._get_num_leafs(tree)
        # Name of the subtree's root node
        tree_name = self.feature_names[tree.feature_index]['name']
        # Compute the position of the subtree's root node
        cntrpt = (self.xOff + (1.0 + float(num_leafs)) / 2.0 / self.totalW, self.yOff)
        # Write the text between the subtree root and its parent
        self._plot_mid_text(cntrpt, parentpt, nodetxt, ax1)
        # Draw the subtree root and the arrow connecting it to its parent
        self._plot_node(tree_name, cntrpt, parentpt, self.decision_node, ax1)
        # Compute the y position of the next level
        self.yOff = self.yOff - 1.0 / self.totalD
        if tree.left:
            child = tree.left
            if child.left:
                # The child is a subtree: recurse into _plot_tree
                self._plot_tree(child, cntrpt, self.feature_names[tree.feature_index]['value_names'][1] + str(tree.feature_value), ax1)
            else:
                # The child is a leaf: compute its x position
                self.xOff = self.xOff + 1.0 / self.totalW
                # Draw the leaf node and the arrow connecting it to its parent
                self._plot_node(self.label_names[child.value], (self.xOff, self.yOff), cntrpt, self.leaf_node, ax1)
                # Write the text between the leaf node and its parent
                self._plot_mid_text((self.xOff, self.yOff), cntrpt, self.feature_names[tree.feature_index]['value_names'][1] + str(tree.feature_value), ax1)
            child = tree.right
            if child.right:
                # The child is a subtree: recurse into _plot_tree
                self._plot_tree(child, cntrpt, self.feature_names[tree.feature_index]['value_names'][2] + str(tree.feature_value), ax1)
            else:
                # The child is a leaf: compute its x position
                self.xOff = self.xOff + 1.0 / self.totalW
                # Draw the leaf node and the arrow connecting it to its parent
                self._plot_node(self.label_names[child.value], (self.xOff, self.yOff), cntrpt, self.leaf_node, ax1)
                # Write the text between the leaf node and its parent
                self._plot_mid_text((self.xOff, self.yOff), cntrpt, self.feature_names[tree.feature_index]['value_names'][2] + str(tree.feature_value), ax1)
        # Restore self.yOff
        self.yOff = self.yOff + 1.0 / self.totalD

    def create_plot(self):
        fig = plt.figure(1, facecolor='white')
        fig.clf()
        # Remove the axes and frame
        axprops = dict(xticks=[], yticks=[])
        ax1 = plt.subplot(111, frameon=False, **axprops)
        # Number of leaf nodes of the whole tree (total width)
        self.totalW = float(self._get_num_leafs(self.tree_))
        # Depth of the whole tree (total height)
        self.totalD = float(self._get_tree_depth(self.tree_))
        self.xOff = -0.5 / self.totalW
        self.yOff = 1.0
        # The root is always placed at (0.5, 1.0), the top center of the figure
        self._plot_tree(self.tree_, (0.5, 1.0), '', ax1)
        plt.show()
The following code draws the CART classification tree trained above:
>>> features_dict = {
...     0: {'name': 'sepal length', 'value_names': {1: '>', 2: '<='}},
...     1: {'name': 'sepal width',  'value_names': {1: '>', 2: '<='}},
...     2: {'name': 'petal length', 'value_names': {1: '>', 2: '<='}},
...     3: {'name': 'petal width',  'value_names': {1: '>', 2: '<='}}
... }
>>> label_dict = {
...     0: 'Iris-setosa',
...     1: 'Iris-versicolor',
...     2: 'Iris-virginica'
... }
>>> from treeplotter2 import TreePlotter2
>>> plotter = TreePlotter2(cct.tree_, features_dict, label_dict)
>>> plotter.create_plot()
This produces the CART binary classification tree shown in the figure below:
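If matplotlib is not available, a small helper like the one below (a hypothetical add-on, not part of cartctree.py or treeplotter2.py) can dump the same tree structure as indented text:

def print_tree(node, feature_names, label_names, depth=0):
    # Recursively print the CART tree: internal nodes show the split rule,
    # leaves show the predicted class label.
    indent = '    ' * depth
    if node.left is None:
        # Leaf node
        print(indent + label_names[node.value])
        return
    name = feature_names[node.feature_index]['name']
    print('%s%s > %s ?' % (indent, name, node.feature_value))
    print_tree(node.left, feature_names, label_names, depth + 1)    # the "> value" branch
    print_tree(node.right, feature_names, label_names, depth + 1)   # the "<= value" branch

# Usage: print_tree(cct.tree_, features_dict, label_dict)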
References:
《Python机器学习算法:原理,实现与案例》 (Python Machine Learning Algorithms: Principles, Implementation and Cases), by Liu Shuo