Skip to content

Commit 1c0b617

Browse files
committed
Update comments, fix KMeans estimator
1 parent 13913fd commit 1c0b617

File tree

4 files changed

+7
-5
lines changed

4 files changed

+7
-5
lines changed

mla/ensemble/random_forest.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
class RandomForest(BaseEstimator):
1010
def __init__(self, n_estimators=10, max_features=None, min_samples_split=10, max_depth=None, criterion=None):
11-
"""
11+
"""Base class for RandomForest.
1212
1313
Parameters
1414
----------
@@ -42,6 +42,9 @@ def _train(self):
4242
tree.train(self.X, self.y, max_features=self.max_features, min_samples_split=self.min_samples_split,
4343
max_depth=self.max_depth)
4444

45+
def _predict(self, X=None):
46+
raise NotImplementedError()
47+
4548

4649
class RandomForestClassifier(RandomForest):
4750
def __init__(self, n_estimators=10, max_features=None, min_samples_split=10, max_depth=None, criterion='entropy'):

mla/ensemble/tree.py

+2
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ def train(self, X, target, max_features=None, min_samples_split=10, max_depth=No
8282
Maximum depth of the tree.
8383
minimum_gain : float, default 0.01
8484
Minimum gain required for splitting.
85+
loss : function, default None
86+
Loss function for gradient boosting.
8587
"""
8688

8789
if not isinstance(target, dict):

mla/kmeans.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def _initialize_cetroids(self, init):
5959
else:
6060
raise ValueError('Unknown type of init parameter')
6161

62-
def predict(self):
62+
def _predict(self, X=None):
6363
"""Perform the clustering on the dataset."""
6464
self._initialize_cetroids(self.init)
6565
centroids = self.centroids

mla/metrics/tests/.cache/v/cache/lastfailed

-3
This file was deleted.

0 commit comments

Comments
 (0)