Skip to content

Commit 33d2662

Browse files
committed
Ensemble methods: Optimize split search, update comments
1 parent 2e4398d commit 33d2662

File tree

2 files changed

+13 additions, -12 deletions (lines changed)

2 files changed

+13 additions, -12 deletions (lines changed)

mla/ensemble/gbm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def transform(self, pred):
3838
return pred
3939

4040
def gain(self, actual, predicted):
41-
"""Gain for split finding."""
41+
"""Calculate gain for split search."""
4242
nominator = self.grad(actual, predicted).sum() ** 2
4343
denominator = (self.hess(actual, predicted).sum() + self.regularization)
4444
return 0.5 * (nominator / denominator)
@@ -70,7 +70,7 @@ def transform(self, output):
7070

7171

7272
class GradientBoosting(BaseEstimator):
73-
"""Gradient boosting trees with taylor expansion approximation (as in xgboost)."""
73+
"""Gradient boosting trees with Taylor's expansion approximation (as in xgboost)."""
7474

7575
def __init__(self, n_estimators, learning_rate=0.1, max_features=10, max_depth=2, min_samples_split=10):
7676
self.min_samples_split = min_samples_split

mla/ensemble/tree.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,17 @@ def __init__(self, regression=False, criterion=None):
2727
def is_terminal(self):
2828
return not bool(self.left_child and self.right_child)
2929

30-
def _find_splits(self, X, y):
30+
def _find_splits(self, X):
3131
"""Find all possible split values."""
32+
split_values = set()
3233

33-
# Sort feature set
34-
df = np.rec.fromarrays([X, y], names='x,y')
35-
df.sort(order='x')
34+
# Get unique values in a sorted order
35+
x_unique = list(np.unique(X))
36+
for i in range(1, len(x_unique)):
37+
# Find a point between two values
38+
average = (x_unique[i - 1] + x_unique[i]) / 2.0
39+
split_values.add(average)
3640

37-
split_values = set()
38-
for i in range(1, X.shape[0]):
39-
if df.y[i - 1] != df.y[i]:
40-
average = (df.x[i - 1] + df.x[i]) / 2.0
41-
split_values.add(average)
4241
return list(split_values)
4342

4443
def _find_best_split(self, X, target, n_features):
@@ -49,7 +48,7 @@ def _find_best_split(self, X, target, n_features):
4948
max_gain, max_col, max_val = None, None, None
5049

5150
for column in subset:
52-
split_values = self._find_splits(X[:, column], target['y'])
51+
split_values = self._find_splits(X[:, column])
5352
for value in split_values:
5453
if self.loss is None:
5554
# Random forest
@@ -112,6 +111,7 @@ def train(self, X, target, max_features=None, min_samples_split=10, max_depth=No
112111
# Split dataset
113112
left_X, right_X, left_target, right_target = split_dataset(X, target, column, value)
114113

114+
# Grow left and right child
115115
self.left_child = Tree(self.regression, self.criterion)
116116
self.left_child.train(left_X, left_target, max_features, min_samples_split, max_depth - 1,
117117
minimum_gain, loss)
@@ -137,6 +137,7 @@ def _calculate_leaf_value(self, targets):
137137
self.outcome = stats.itemfreq(targets['y'])[:, 1] / float(targets['y'].shape[0])
138138

139139
def predict_row(self, row):
140+
"""Predict single row."""
140141
if not self.is_terminal:
141142
if row[self.column_index] < self.threshold:
142143
return self.left_child.predict_row(row)

0 commit comments

Comments
 (0)