@@ -27,18 +27,17 @@ def __init__(self, regression=False, criterion=None):
27
27
def is_terminal(self):
    """Report whether this node is a leaf.

    A node counts as terminal unless both ``left_child`` and
    ``right_child`` are set to truthy values.
    """
    if self.left_child and self.right_child:
        return False
    return True
29
29
30
- def _find_splits (self , X , y ):
30
+ def _find_splits (self , X ):
31
31
"""Find all possible split values."""
32
+ split_values = set ()
32
33
33
- # Sort feature set
34
- df = np .rec .fromarrays ([X , y ], names = 'x,y' )
35
- df .sort (order = 'x' )
34
+ # Get unique values in a sorted order
35
+ x_unique = list (np .unique (X ))
36
+ for i in range (1 , len (x_unique )):
37
+ # Find a point between two values
38
+ average = (x_unique [i - 1 ] + x_unique [i ]) / 2.0
39
+ split_values .add (average )
36
40
37
- split_values = set ()
38
- for i in range (1 , X .shape [0 ]):
39
- if df .y [i - 1 ] != df .y [i ]:
40
- average = (df .x [i - 1 ] + df .x [i ]) / 2.0
41
- split_values .add (average )
42
41
return list (split_values )
43
42
44
43
def _find_best_split (self , X , target , n_features ):
@@ -49,7 +48,7 @@ def _find_best_split(self, X, target, n_features):
49
48
max_gain , max_col , max_val = None , None , None
50
49
51
50
for column in subset :
52
- split_values = self ._find_splits (X [:, column ], target [ 'y' ] )
51
+ split_values = self ._find_splits (X [:, column ])
53
52
for value in split_values :
54
53
if self .loss is None :
55
54
# Random forest
@@ -112,6 +111,7 @@ def train(self, X, target, max_features=None, min_samples_split=10, max_depth=No
112
111
# Split dataset
113
112
left_X , right_X , left_target , right_target = split_dataset (X , target , column , value )
114
113
114
+ # Grow left and right child
115
115
self .left_child = Tree (self .regression , self .criterion )
116
116
self .left_child .train (left_X , left_target , max_features , min_samples_split , max_depth - 1 ,
117
117
minimum_gain , loss )
@@ -137,6 +137,7 @@ def _calculate_leaf_value(self, targets):
137
137
self .outcome = stats .itemfreq (targets ['y' ])[:, 1 ] / float (targets ['y' ].shape [0 ])
138
138
139
139
def predict_row (self , row ):
140
+ """Predict single row."""
140
141
if not self .is_terminal :
141
142
if row [self .column_index ] < self .threshold :
142
143
return self .left_child .predict_row (row )
0 commit comments