Skip to content

Commit 0b9ebf3

Browse files
resolved classification bug in decision tree using n_classes
1 parent 58a8c25 commit 0b9ebf3

File tree

3 files changed

+29
-12
lines changed

3 files changed

+29
-12
lines changed

examples/random_forest.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
from timeit import default_timer
2+
start = default_timer()
13
import logging
24

3-
from sklearn.datasets import make_classification
5+
import numpy as np
6+
from sklearn.datasets import make_classification, load_boston, load_digits, load_breast_cancer, load_iris
47
from sklearn.datasets import make_regression
5-
from sklearn.metrics import roc_auc_score
8+
from sklearn.metrics import roc_auc_score, accuracy_score
69

710
try:
811
from sklearn.model_selection import train_test_split
@@ -20,13 +23,15 @@ def classification():
2023
X, y = make_classification(
2124
n_samples=500, n_features=10, n_informative=10, random_state=1111, n_classes=2, class_sep=2.5, n_redundant=0
2225
)
26+
#X,y = load_breast_cancer(return_X_y=True)
2327

2428
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.15, random_state=1111)
2529

26-
model = RandomForestClassifier(n_estimators=10, max_depth=4)
30+
model = RandomForestClassifier(n_estimators=5, max_depth=4)
2731
model.fit(X_train, y_train)
28-
predictions = model.predict(X_test)[:, 1]
29-
# print(predictions)
32+
predictions = model.predict(X_test)[:,1]
33+
#predictions = np.argmax(model.predict(X_test),axis=1)
34+
print(predictions.shape)
3035
print("classification, roc auc score: %s" % roc_auc_score(y_test, predictions))
3136

3237

@@ -46,3 +51,5 @@ def regression():
4651
if __name__ == "__main__":
4752
classification()
4853
# regression()
54+
end = default_timer()
55+
print(end-start)

mla/ensemble/random_forest.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,15 @@ def fit(self, X, y):
3939
self._train()
4040

4141
def _train(self):
42+
n_classes = None if self.trees[0].regression else len(np.unique(self.y))
4243
for tree in self.trees:
4344
tree.train(
4445
self.X,
4546
self.y,
4647
max_features=self.max_features,
4748
min_samples_split=self.min_samples_split,
4849
max_depth=self.max_depth,
50+
n_classes=n_classes
4951
)
5052

5153
def _predict(self, X=None):
@@ -78,10 +80,14 @@ def _predict(self, X=None):
7880
for i in range(X.shape[0]):
7981
row_pred = np.zeros(y_shape)
8082
for tree in self.trees:
81-
row_pred += tree.predict_row(X[i, :])
83+
tmp = tree.predict_row(X[i, :])
84+
print(tmp,row_pred.shape,row_pred)
85+
row_pred += tmp
86+
8287

8388
row_pred /= self.n_estimators
8489
predictions[i, :] = row_pred
90+
print(f"i={i},{row_pred}\n")
8591
return predictions
8692

8793

mla/ensemble/tree.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ def _find_best_split(self, X, target, n_features):
6464
max_col, max_val, max_gain = column, value, gain
6565
return max_col, max_val, max_gain
6666

67-
def train(self, X, target, max_features=None, min_samples_split=10, max_depth=None, minimum_gain=0.01, loss=None):
67+
def train(self, X, target, max_features=None, min_samples_split=10, max_depth=None,
68+
minimum_gain=0.01, loss=None, n_classes = None):
6869
"""Build a decision tree from training set.
6970
7071
Parameters
@@ -84,6 +85,8 @@ def train(self, X, target, max_features=None, min_samples_split=10, max_depth=No
8485
Minimum gain required for splitting.
8586
loss : function, default None
8687
Loss function for gradient boosting.
88+
n_classes : int, default None
89+
No of unique labels in case of classification
8790
"""
8891

8992
if not isinstance(target, dict):
@@ -118,17 +121,17 @@ def train(self, X, target, max_features=None, min_samples_split=10, max_depth=No
118121
# Grow left and right child
119122
self.left_child = Tree(self.regression, self.criterion)
120123
self.left_child.train(
121-
left_X, left_target, max_features, min_samples_split, max_depth - 1, minimum_gain, loss
124+
left_X, left_target, max_features, min_samples_split, max_depth - 1, minimum_gain, loss, n_classes
122125
)
123126

124127
self.right_child = Tree(self.regression, self.criterion)
125128
self.right_child.train(
126-
right_X, right_target, max_features, min_samples_split, max_depth - 1, minimum_gain, loss
129+
right_X, right_target, max_features, min_samples_split, max_depth - 1, minimum_gain, loss, n_classes
127130
)
128131
except AssertionError:
129-
self._calculate_leaf_value(target)
132+
self._calculate_leaf_value(target, n_classes)
130133

131-
def _calculate_leaf_value(self, targets):
134+
def _calculate_leaf_value(self, targets, n_classes):
132135
"""Find optimal value for leaf."""
133136
if self.loss is not None:
134137
# Gradient boosting
@@ -140,7 +143,8 @@ def _calculate_leaf_value(self, targets):
140143
self.outcome = np.mean(targets["y"])
141144
else:
142145
# Probability for classification task
143-
self.outcome = stats.itemfreq(targets["y"])[:, 1] / float(targets["y"].shape[0])
146+
#self.outcome = stats.itemfreq(targets["y"])[:, 1] / float(targets["y"].shape[0])
147+
self.outcome = np.bincount(targets["y"], minlength=n_classes) / targets["y"].shape[0]
144148

145149
def predict_row(self, row):
146150
"""Predict single row."""

0 commit comments

Comments
 (0)