Skip to content

Commit 2de8779

Browse files
committed
update decision tree
1 parent ed0b8fd commit 2de8779

File tree

5 files changed

+107
-27
lines changed

5 files changed

+107
-27
lines changed

DecisionTree/DT.py

+49-26
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
import numpy as np
99

1010
class DecitionTree():
11-
"""this is a decision tree classifier. """
11+
"""This is a decision tree classifier. """
1212

13-
def __init__(self, criteria='C4.5'):
13+
def __init__(self, criteria='ID3'):
1414
self._tree = None
1515
if criteria == 'ID3' or criteria == 'C4.5':
1616
self._criteria = criteria
@@ -19,9 +19,9 @@ def __init__(self, criteria='C4.5'):
1919

2020
def _calEntropy(slef, y):
2121
'''
22-
_calEntropy用于计算香农熵 e=-sum(pi*log pi)
23-
其中y为数组array
24-
输出entropy
22+
功能:_calEntropy用于计算香农熵 e=-sum(pi*log pi)
23+
参数:其中y为数组array
24+
输出:信息熵entropy
2525
'''
2626
n = y.shape[0]
2727
labelCounts = {}
@@ -38,7 +38,8 @@ def _calEntropy(slef, y):
3838

3939
def _splitData(self, X, y, axis, cutoff):
4040
"""
41-
该函数返回数据集中特征下标为axis,特征值等于cutoff的子数据集
41+
参数:X为特征,y为label,axis为某个特征的下标,cutoff是下标为axis的特征的取值
42+
输出:返回数据集中特征下标为axis,特征值等于cutoff的子数据集
4243
"""
4344
ret = []
4445
featVec = X[:,axis]
@@ -51,7 +52,9 @@ def _splitData(self, X, y, axis, cutoff):
5152

5253
def _chooseBestSplit(self, X, y):
5354
"""ID3 & C4.5
54-
根据信息增益或者信息增益率来获取最好的划分特征
55+
参数:X为特征,y为label
56+
功能:根据信息增益或者信息增益率来获取最好的划分特征
57+
输出:返回最好划分特征的下标
5558
"""
5659
numFeat = X.shape[1]
5760
baseEntropy = self._calEntropy(y)
@@ -83,17 +86,23 @@ def _chooseBestSplit(self, X, y):
8386

8487
def _majorityCnt(self, labellist):
8588
"""
86-
返回labellist中出现次数最多的label
89+
参数:labellist是类标签,序列类型为list
90+
输出:返回labellist中出现次数最多的label
8791
"""
8892
labelCount={}
8993
for vote in labellist:
90-
if vote not in labelCount.keys(): labelCount[vote] = 0
94+
if vote not in labelCount.keys():
95+
labelCount[vote] = 0
9196
labelCount[vote] += 1
9297
sortedClassCount = sorted(labelCount.iteritems(), key=lambda x:x[1], \
9398
reverse=True)
9499
return sortedClassCount[0][0]
95100

96101
def _createTree(self, X, y, featureIndex):
102+
"""
103+
参数:X为特征,y为label,featureIndex类型是元组,记录X特征在原始数据中的下标
104+
输出:根据当前的featureIndex创建一棵完整的树
105+
"""
97106
labelList = list(y)
98107
if labelList.count(labelList[0]) == len(labelList):
99108
return labelList[0]
@@ -110,12 +119,16 @@ def _createTree(self, X, y, featureIndex):
110119
for value in uniqueVals:
111120
#对每个value递归地创建树
112121
sub_X, sub_y = self._splitData(X,y, bestFeatIndex, value)
113-
myTree[bestFeatAxis][value] = self._createTree(sub_X,sub_y,\
122+
myTree[bestFeatAxis][value] = self._createTree(sub_X, sub_y, \
114123
featureIndex)
115124
return myTree
116125

117126
def fit(self, X, y):
118-
#对数据X和y进行类型检测,保证其为array
127+
"""
128+
参数:X是特征,y是类标签
129+
注意事项:对数据X和y进行类型检测,保证其为array
130+
输出:self本身
131+
"""
119132
if isinstance(X, np.ndarray) and isinstance(y, np.ndarray):
120133
pass
121134
else:
@@ -129,41 +142,49 @@ def fit(self, X, y):
129142
return self #allow using: clf.fit().predict()
130143

131144
def _classify(self, tree, sample):
132-
featIndex = tree.keys()[0]
133-
secondDict = tree[featIndex]
134-
key = sample[int(featIndex[1:])]
135-
valueOfkey = secondDict[key]
136-
if type(valueOfkey).__name__=='dict':
137-
return self._classify(valueOfkey, sample)
145+
"""
146+
用训练好的模型对输入数据进行分类
147+
注意:决策树的构建是一个递归的过程,用决策树分类也是一个递归的过程
148+
_classify()一次只能对一个样本(sample)分类
149+
"""
150+
featIndex = tree.keys()[0] #得到树的根节点值
151+
secondDict = tree[featIndex] #得到以featIndex为划分特征的结果
152+
axis=featIndex[1:] #得到根节点特征在原始数据中的下标
153+
key = sample[int(axis)] #获取待分类样本中下标为axis的值
154+
valueOfKey = secondDict[key] #获取secondDict中keys为key的value值
155+
if type(valueOfKey).__name__=='dict': #如果value为dict,则继续递归分类
156+
return self._classify(valueOfKey, sample)
138157
else:
139-
return valueOfkey
158+
return valueOfKey
140159

141160
def predict(self, X):
142161
if self._tree==None:
143162
raise NotImplementedError("Estimator not fitted, call `fit` first")
144-
if isinstance(X,np.ndarray):
163+
#对X的类型进行检测,判断其是否是数组
164+
if isinstance(X, np.ndarray):
145165
pass
146166
else:
147167
try:
148168
X = np.array(X)
149169
except:
150170
raise TypeError("numpy.ndarray required for X")
151171

152-
if len(X.shape)==1:
172+
if len(X.shape) == 1:
153173
return self._classify(self._tree, X)
154174
else:
155175
result = []
156176
for i in range(X.shape[0]):
157-
result.append(self._classify(self._tree, X[i]))
177+
value = self._classify(self._tree, X[i])
178+
print str(i+1)+"-th sample is classfied as:", value
179+
result.append(value)
158180
return np.array(result)
159181

160-
def show(self):
182+
def show(self, outpdf):
161183
if self._tree==None:
162184
pass
163185
#plot the tree using matplotlib
164186
import treePlotter
165-
treePlotter.createPlot(self._tree)
166-
187+
treePlotter.createPlot(self._tree, outpdf)
167188

168189
if __name__=="__main__":
169190
trainfile=r"data\train.txt"
@@ -173,8 +194,10 @@ def show(self):
173194
import dataload as dload
174195
train_x, train_y = dload.loadData(trainfile)
175196
test_x, test_y = dload.loadData(testfile)
176-
clf = DecitionTree()
197+
198+
clf = DecitionTree(criteria="C4.5")
177199
clf.fit(train_x, train_y)
178200
result = clf.predict(test_x)
179-
clf.show()
201+
outpdf = r"tree.pdf"
202+
clf.show(outpdf)
180203

DecisionTree/readme.md

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
## Decision Tree
2+
3+
决策树理论详解:http://www.csuldw.com/2015/05/08/2015-05-08-decision%20tree/
4+
5+
- data:存放数据集
6+
- calIG.py:计算信息增益的实例代码
7+
- DT.py:决策树实现
8+
- treePlotter.py:决策树的可视化绘制
9+
10+
## 相关知识
11+
12+
- python
13+
- numpy
14+
- matplotlib
15+
16+
## dataset
17+
18+
- 训练集:./data/train.txt
19+
- 测试集:./data/test.txt
20+
21+
## Run
22+
23+
```
24+
if __name__=="__main__":
25+
trainfile=r"data\train.txt"
26+
testfile=r"data\test.txt"
27+
import sys
28+
sys.path.append(r"F:\CSU\Github\MachineLearning\lib")
29+
import dataload as dload
30+
train_x, train_y = dload.loadData(trainfile)
31+
test_x, test_y = dload.loadData(testfile)
32+
33+
clf = DecitionTree(criteria="C4.5")
34+
clf.fit(train_x, train_y)
35+
result = clf.predict(test_x)
36+
outpdf = r"tree.pdf"
37+
#clf.show(outpdf)
38+
```
39+
40+
## Result
41+
42+
训练得到的树:https://github.com/csuldw/MachineLearning/tree/master/DecisionTree/tree.pdf
43+
44+
对test分类的结果:
45+
46+
```
47+
1-th sample is classfied as: 1
48+
2-th sample is classfied as: 0
49+
3-th sample is classfied as: 0
50+
```
51+
52+
## 参考资料
53+
54+
- 机器学习实战
55+
- Andrew Ng 机器学习公开课

DecisionTree/tree.pdf

67.5 KB
Binary file not shown.

DecisionTree/treePlotter.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
@author: Peter Harrington
55
'''
66
import matplotlib.pyplot as plt
7+
from matplotlib.pyplot import savefig
78

89
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
910
leafNode = dict(boxstyle="round4", fc="0.8")
@@ -59,7 +60,7 @@ def plotTree(myTree, parentPt, nodeTxt):#if the first key tells you what feat wa
5960
plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalD
6061
#if you do get a dictonary you know it's a tree, and the first element will be another dict
6162

62-
def createPlot(inTree):
63+
def createPlot(inTree, outpdf):
6364
fig = plt.figure(1, facecolor='white')
6465
fig.clf()
6566
axprops = dict(xticks=[], yticks=[])
@@ -69,6 +70,7 @@ def createPlot(inTree):
6970
plotTree.totalD = float(getTreeDepth(inTree))
7071
plotTree.xOff = -0.5/plotTree.totalW; plotTree.yOff = 1.0;
7172
plotTree(inTree, (0.5,1.0), '')
73+
plt.savefig(outpdf)
7274
plt.show()
7375

7476
#def createPlot():

DecisionTree/treePlotter.pyc

71 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)