Skip to content

Commit 2c15265

Browse files
committed
Small fixes for k-means, Naive Bayes
1 parent 33d2662 commit 2c15265

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

mla/kmeans.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,11 @@ def _initialize_cetroids(self, init):
6060
raise ValueError('Unknown type of init parameter')
6161

6262
def _predict(self, X=None):
63-
"""Perform the clustering on the dataset."""
63+
"""Perform clustering on the dataset."""
6464
self._initialize_cetroids(self.init)
6565
centroids = self.centroids
66+
67+
# Optimize clusters
6668
for _ in range(self.max_iters):
6769
self._assign(centroids)
6870
centroids_old = centroids
@@ -95,6 +97,7 @@ def _assign(self, centroids):
9597
self.clusters[closest].append(row)
9698

9799
def _closest(self, fpoint, centroids):
100+
"""Find the closest centroid for a point."""
98101
closest_index = None
99102
closest_distance = None
100103
for i, point in enumerate(centroids):
@@ -109,6 +112,7 @@ def _get_centroid(self, cluster):
109112
return [np.mean(np.take(self.X[:, i], cluster)) for i in range(self.n_features)]
110113

111114
def _dist_from_centers(self):
115+
"""Calculate distance from centers."""
112116
return np.array([min([euclidean_distance(x, c) for c in self.centroids]) for x in self.X])
113117

114118
def _choose_next_center(self):
@@ -120,7 +124,11 @@ def _choose_next_center(self):
120124
return self.X[ind]
121125

122126
def _is_converged(self, centroids_old, centroids):
123-
return True if sum([euclidean_distance(centroids_old[i], centroids[i]) for i in range(self.K)]) == 0 else False
127+
"""Check if the distance between old and new centroids is zero."""
128+
distance = 0
129+
for i in range(self.K):
130+
distance += euclidean_distance(centroids_old[i], centroids[i])
131+
return distance == 0
124132

125133
def plot(self, data=None):
126134
sns.set(style="white")

mla/naive_bayes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ def fit(self, X, y=None):
2828

2929
def _predict(self, X=None):
3030
# Apply _predict_row for each row
31-
predictions = np.apply_along_axis(self._predict_proba, 1, X)
31+
predictions = np.apply_along_axis(self._predict_row, 1, X)
3232
# Normalize probabilities
3333
return softmax(predictions)
3434

35-
def _predict_proba(self, x):
35+
def _predict_row(self, x):
3636
"""Predict log likelihood for given row."""
3737
output = []
3838
for y in range(self.n_classes):

0 commit comments

Comments
 (0)