12
12
class Tree (object ):
13
13
"""Recursive implementation of decision tree."""
14
14
15
- def __init__ (self , regression = False , criterion = None ):
15
+ def __init__ (self , regression = False , criterion = None , n_classes = None ):
16
16
self .regression = regression
17
17
self .impurity = None
18
18
self .threshold = None
19
19
self .column_index = None
20
20
self .outcome = None
21
21
self .criterion = criterion
22
22
self .loss = None
23
+ self .n_classes = n_classes # Only for classification
23
24
24
25
self .left_child = None
25
26
self .right_child = None
@@ -64,6 +65,42 @@ def _find_best_split(self, X, target, n_features):
64
65
max_col , max_val , max_gain = column , value , gain
65
66
return max_col , max_val , max_gain
66
67
68
+ def _train (self , X , target , max_features = None , min_samples_split = 10 , max_depth = None , minimum_gain = 0.01 ):
69
+ try :
70
+ # Exit from recursion using assert syntax
71
+ assert X .shape [0 ] > min_samples_split
72
+ assert max_depth > 0
73
+
74
+ if max_features is None :
75
+ max_features = X .shape [1 ]
76
+
77
+ column , value , gain = self ._find_best_split (X , target , max_features )
78
+ assert gain is not None
79
+ if self .regression :
80
+ assert gain != 0
81
+ else :
82
+ assert gain > minimum_gain
83
+
84
+ self .column_index = column
85
+ self .threshold = value
86
+ self .impurity = gain
87
+
88
+ # Split dataset
89
+ left_X , right_X , left_target , right_target = split_dataset (X , target , column , value )
90
+
91
+ # Grow left and right child
92
+ self .left_child = Tree (self .regression , self .criterion , self .n_classes )
93
+ self .left_child ._train (
94
+ left_X , left_target , max_features , min_samples_split , max_depth - 1 , minimum_gain
95
+ )
96
+
97
+ self .right_child = Tree (self .regression , self .criterion , self .n_classes )
98
+ self .right_child ._train (
99
+ right_X , right_target , max_features , min_samples_split , max_depth - 1 , minimum_gain
100
+ )
101
+ except AssertionError :
102
+ self ._calculate_leaf_value (target )
103
+
67
104
def train (self , X , target , max_features = None , min_samples_split = 10 , max_depth = None , minimum_gain = 0.01 , loss = None ):
68
105
"""Build a decision tree from training set.
69
106
@@ -93,40 +130,12 @@ def train(self, X, target, max_features=None, min_samples_split=10, max_depth=No
93
130
if loss is not None :
94
131
self .loss = loss
95
132
96
- try :
97
- # Exit from recursion using assert syntax
98
- assert X .shape [0 ] > min_samples_split
99
- assert max_depth > 0
100
-
101
- if max_features is None :
102
- max_features = X .shape [1 ]
133
+ if not self .regression :
134
+ self .n_classes = len (np .unique (target ['y' ]))
103
135
104
- column , value , gain = self ._find_best_split (X , target , max_features )
105
- assert gain is not None
106
- if self .regression :
107
- assert gain != 0
108
- else :
109
- assert gain > minimum_gain
136
+ self ._train (X , target , max_features = max_features , min_samples_split = min_samples_split ,
137
+ max_depth = max_depth , minimum_gain = minimum_gain )
110
138
111
- self .column_index = column
112
- self .threshold = value
113
- self .impurity = gain
114
-
115
- # Split dataset
116
- left_X , right_X , left_target , right_target = split_dataset (X , target , column , value )
117
-
118
- # Grow left and right child
119
- self .left_child = Tree (self .regression , self .criterion )
120
- self .left_child .train (
121
- left_X , left_target , max_features , min_samples_split , max_depth - 1 , minimum_gain , loss
122
- )
123
-
124
- self .right_child = Tree (self .regression , self .criterion )
125
- self .right_child .train (
126
- right_X , right_target , max_features , min_samples_split , max_depth - 1 , minimum_gain , loss
127
- )
128
- except AssertionError :
129
- self ._calculate_leaf_value (target )
130
139
131
140
def _calculate_leaf_value (self , targets ):
132
141
"""Find optimal value for leaf."""
@@ -140,7 +149,7 @@ def _calculate_leaf_value(self, targets):
140
149
self .outcome = np .mean (targets ["y" ])
141
150
else :
142
151
# Probability for classification task
143
- self .outcome = stats . itemfreq (targets ["y" ])[:, 1 ] / float ( targets ["y" ].shape [0 ])
152
+ self .outcome = np . bincount (targets ["y" ], minlength = self . n_classes ) / targets ["y" ].shape [0 ]
144
153
145
154
def predict_row (self , row ):
146
155
"""Predict single row."""
0 commit comments