8
8
import numpy as np
9
9
10
10
class DecitionTree ():
11
- """this is a decision tree classifier. """
11
+ """This is a decision tree classifier. """
12
12
13
- def __init__ (self , criteria = 'C4.5 ' ):
13
+ def __init__ (self , criteria = 'ID3 ' ):
14
14
self ._tree = None
15
15
if criteria == 'ID3' or criteria == 'C4.5' :
16
16
self ._criteria = criteria
@@ -19,9 +19,9 @@ def __init__(self, criteria='C4.5'):
19
19
20
20
def _calEntropy (slef , y ):
21
21
'''
22
- _calEntropy用于计算香农熵 e=-sum(pi*log pi)
23
- 其中y为数组array
24
- 输出entropy
22
+ 功能: _calEntropy用于计算香农熵 e=-sum(pi*log pi)
23
+ 参数: 其中y为数组array
24
+ 输出:信息熵entropy
25
25
'''
26
26
n = y .shape [0 ]
27
27
labelCounts = {}
@@ -38,7 +38,8 @@ def _calEntropy(slef, y):
38
38
39
39
def _splitData (self , X , y , axis , cutoff ):
40
40
"""
41
- 该函数返回数据集中特征下标为axis,特征值等于cutoff的子数据集
41
+ 参数:X为特征,y为label,axis为某个特征的下标,cutoff是下标为axis特征取值值
42
+ 输出:返回数据集中特征下标为axis,特征值等于cutoff的子数据集
42
43
"""
43
44
ret = []
44
45
featVec = X [:,axis ]
@@ -51,7 +52,9 @@ def _splitData(self, X, y, axis, cutoff):
51
52
52
53
def _chooseBestSplit (self , X , y ):
53
54
"""ID3 & C4.5
54
- 根据信息增益或者信息增益率来获取最好的划分特征
55
+ 参数:X为特征,y为label
56
+ 功能:根据信息增益或者信息增益率来获取最好的划分特征
57
+ 输出:返回最好划分特征的下标
55
58
"""
56
59
numFeat = X .shape [1 ]
57
60
baseEntropy = self ._calEntropy (y )
@@ -83,17 +86,23 @@ def _chooseBestSplit(self, X, y):
83
86
84
87
def _majorityCnt (self , labellist ):
85
88
"""
86
- 返回labellist中出现次数最多的label
89
+ 参数:labellist是类标签,序列类型为list
90
+ 输出:返回labellist中出现次数最多的label
87
91
"""
88
92
labelCount = {}
89
93
for vote in labellist :
90
- if vote not in labelCount .keys (): labelCount [vote ] = 0
94
+ if vote not in labelCount .keys ():
95
+ labelCount [vote ] = 0
91
96
labelCount [vote ] += 1
92
97
sortedClassCount = sorted (labelCount .iteritems (), key = lambda x :x [1 ], \
93
98
reverse = True )
94
99
return sortedClassCount [0 ][0 ]
95
100
96
101
def _createTree (self , X , y , featureIndex ):
102
+ """
103
+ 参数:X为特征,y为label,featureIndex类型是元组,记录X特征在原始数据中的下标
104
+ 输出:根据当前的featureIndex创建一颗完整的树
105
+ """
97
106
labelList = list (y )
98
107
if labelList .count (labelList [0 ]) == len (labelList ):
99
108
return labelList [0 ]
@@ -110,12 +119,16 @@ def _createTree(self, X, y, featureIndex):
110
119
for value in uniqueVals :
111
120
#对每个value递归地创建树
112
121
sub_X , sub_y = self ._splitData (X ,y , bestFeatIndex , value )
113
- myTree [bestFeatAxis ][value ] = self ._createTree (sub_X ,sub_y ,\
122
+ myTree [bestFeatAxis ][value ] = self ._createTree (sub_X , sub_y , \
114
123
featureIndex )
115
124
return myTree
116
125
117
126
def fit (self , X , y ):
118
- #对数据X和y进行类型检测,保证其为array
127
+ """
128
+ 参数:X是特征,y是类标签
129
+ 注意事项:对数据X和y进行类型检测,保证其为array
130
+ 输出:self本身
131
+ """
119
132
if isinstance (X , np .ndarray ) and isinstance (y , np .ndarray ):
120
133
pass
121
134
else :
@@ -129,41 +142,49 @@ def fit(self, X, y):
129
142
return self #allow using: clf.fit().predict()
130
143
131
144
def _classify (self , tree , sample ):
132
- featIndex = tree .keys ()[0 ]
133
- secondDict = tree [featIndex ]
134
- key = sample [int (featIndex [1 :])]
135
- valueOfkey = secondDict [key ]
136
- if type (valueOfkey ).__name__ == 'dict' :
137
- return self ._classify (valueOfkey , sample )
145
+ """
146
+ 用训练好的模型对输入数据进行分类
147
+ 注意:决策树的构建是一个递归的过程,用决策树分类也是一个递归的过程
148
+ _classify()一次只能对一个样本(sample)分类
149
+ """
150
+ featIndex = tree .keys ()[0 ] #得到数的根节点值
151
+ secondDict = tree [featIndex ] #得到以featIndex为划分特征的结果
152
+ axis = featIndex [1 :] #得到根节点特征在原始数据中的下标
153
+ key = sample [int (axis )] #获取待分类样本中下标为axis的值
154
+ valueOfKey = secondDict [key ] #获取secondDict中keys为key的value值
155
+ if type (valueOfKey ).__name__ == 'dict' : #如果value为dict,则继续递归分类
156
+ return self ._classify (valueOfKey , sample )
138
157
else :
139
- return valueOfkey
158
+ return valueOfKey
140
159
141
160
def predict(self, X):
    """Predict class labels for one sample or a batch of samples.

    Parameters:
        X: a ``numpy.ndarray`` or anything ``numpy.array`` can convert.
            A 1-D input is treated as a single sample; otherwise each
            ``X[i]`` along the first axis is classified separately.

    Returns:
        The label of the sample for 1-D input, otherwise a
        ``numpy.ndarray`` holding one label per row. For batch input each
        prediction is also printed to stdout.

    Raises:
        NotImplementedError: if no tree has been fitted yet.
        TypeError: if ``X`` cannot be converted to a numpy array.
    """
    if self._tree is None:
        raise NotImplementedError("Estimator not fitted, call `fit` first")

    # Check the type of X and convert it to an array when necessary.
    if not isinstance(X, np.ndarray):
        try:
            X = np.array(X)
        except Exception:
            raise TypeError("numpy.ndarray required for X")

    if len(X.shape) == 1:
        return self._classify(self._tree, X)

    result = []
    for i in range(X.shape[0]):
        value = self._classify(self._tree, X[i])
        # A single parenthesised argument prints the same text on
        # Python 2 and Python 3.
        print("%d-th sample is classfied as: %s" % (i + 1, value))
        result.append(value)
    return np.array(result)
159
181
160
def show(self, outpdf):
    """Plot the fitted tree using the project's ``treePlotter`` module.

    Parameters:
        outpdf: forwarded unchanged as the second argument of
            ``treePlotter.createPlot``; presumably the path of the PDF to
            write -- confirm against ``treePlotter``.

    Returns:
        None. Nothing is plotted when no tree has been fitted.
    """
    if self._tree is None:
        # Nothing to draw yet; return instead of handing None to the plotter.
        return
    # Plot the tree using matplotlib. The import is local so the plotting
    # dependency is only loaded when a plot is requested.
    import treePlotter
    treePlotter.createPlot(self._tree, outpdf)
167
188
168
189
if __name__ == "__main__" :
169
190
trainfile = r"data\train.txt"
@@ -173,8 +194,10 @@ def show(self):
173
194
import dataload as dload
174
195
train_x , train_y = dload .loadData (trainfile )
175
196
test_x , test_y = dload .loadData (testfile )
176
- clf = DecitionTree ()
197
+
198
+ clf = DecitionTree (criteria = "C4.5" )
177
199
clf .fit (train_x , train_y )
178
200
result = clf .predict (test_x )
179
- clf .show ()
201
+ outpdf = r"tree.pdf"
202
+ clf .show (outpdf )
180
203
0 commit comments