Skip to content

Commit 54c8d37

Browse files
author
Andrey Oskin
committed
Removed dead code and updated docs
1 parent 492fc6b commit 54c8d37

File tree

5 files changed

+22
-128
lines changed

5 files changed

+22
-128
lines changed

docs/src/index.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# ParallelKMeans.jl Package
1+
# [ParallelKMeans.jl Package](https://github.com/PyDataBlog/ParallelKMeans.jl)
22

33
```@contents
44
Depth = 4
@@ -59,7 +59,7 @@ git checkout experimental
5959

6060
- [X] Implementation of [Hamerly implementation](https://www.researchgate.net/publication/220906984_Making_k-means_Even_Faster).
6161
- [X] Interface for inclusion in Alan Turing Institute's [MLJModels](https://github.com/alan-turing-institute/MLJModels.jl#who-is-this-repo-for).
62-
- [ ] Full Implementation of Triangle inequality based on [Elkan - 2003 Using the Triangle Inequality to Accelerate K-Means"](https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf).
62+
- [X] Full Implementation of Triangle inequality based on [Elkan - 2003 Using the Triangle Inequality to Accelerate K-Means"](https://www.aaai.org/Papers/ICML/2003/ICML03-022.pdf).
6363
- [ ] Implementation of [Geometric methods to accelerate k-means algorithm](http://cs.baylor.edu/~hamerly/papers/sdm2016_rysavy_hamerly.pdf).
6464
- [ ] Native support for tabular data inputs outside of MLJModels' interface.
6565
- [ ] Refactoring and finalizaiton of API desgin.
@@ -177,6 +177,7 @@ ________________________________________________________________________________
177177

178178
- 0.1.0 Initial release
179179
- 0.1.1 Added interface for MLJ
180+
- 0.1.2 Added Elkan algorithm
180181

181182
## Contributing
182183

src/ParallelKMeans.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
module ParallelKMeans
22

33
using StatsBase
4-
using MLJModelInterface
4+
import MLJModelInterface
55
import Base.Threads: @spawn
66
import Distances
77

8+
const MMI = MLJModelInterface
9+
810
include("seeding.jl")
911
include("kmeans.jl")
1012
include("lloyd.jl")

src/kmeans.jl

Lines changed: 0 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -109,18 +109,6 @@ design matrix(x), centroids (centre), and the number of desired groups (k).
109109
110110
A Float type representing the computed metric is returned.
111111
"""
112-
function sum_of_squares(x, labels, centre)
113-
s = 0.0
114-
115-
@inbounds for j in axes(x, 2)
116-
for i in axes(x, 1)
117-
s += (x[i, j] - centre[i, labels[j]])^2
118-
end
119-
end
120-
121-
return s
122-
end
123-
124112
function sum_of_squares(containers, x, labels, centre, r, idx)
125113
s = 0.0
126114

@@ -171,100 +159,3 @@ function kmeans(alg, design_matrix, k;
171159
k_init = k_init, max_iters = max_iters, tol = tol,
172160
verbose = verbose, init = init)
173161
end
174-
175-
176-
"""
177-
Kmeans!(alg::AbstractKMeansAlg, containers, design_matrix, k; n_threads = nthreads(), k_init="k-means++", max_iters=300, tol=1e-6, verbose=false)
178-
179-
Mutable version of `kmeans` function. Definition of arguments and results can be
180-
found in `kmeans`.
181-
182-
Argument `containers` represent algorithm specific containers, such as labels, intermidiate
183-
centroids and so on, which are used during calculations.
184-
"""
185-
function kmeans!(alg, containers, design_matrix, k;
186-
n_threads = Threads.nthreads(),
187-
k_init = "k-means++", max_iters = 300,
188-
tol = 1e-6, verbose = false, init = nothing)
189-
nrow, ncol = size(design_matrix)
190-
centroids = init == nothing ? smart_init(design_matrix, k, n_threads, init=k_init).centroids : deepcopy(init)
191-
192-
converged = false
193-
niters = 0
194-
J_previous = 0.0
195-
196-
# Update centroids & labels with closest members until convergence
197-
198-
while niters < max_iters
199-
niters += 1
200-
201-
update_containers!(containers, alg, centroids, n_threads)
202-
J = update_centroids!(centroids, containers, alg, design_matrix, n_threads)
203-
204-
if verbose
205-
# Show progress and terminate if J stopped decreasing.
206-
println("Iteration $niters: Jclust = $J")
207-
end
208-
209-
# Check for convergence
210-
if (niters > 1) & (abs(J - J_previous) < (tol * J))
211-
converged = true
212-
break
213-
end
214-
215-
J_previous = J
216-
217-
end
218-
219-
totalcost = sum_of_squares(design_matrix, containers.labels, centroids)
220-
221-
# Terminate algorithm with the assumption that K-means has converged
222-
if verbose & converged
223-
println("Successfully terminated with convergence.")
224-
end
225-
226-
# TODO empty placeholder vectors should be calculated
227-
# TODO Float64 type definitions is too restrictive, should be relaxed
228-
# especially during GPU related development
229-
return KmeansResult(centroids, containers.labels, Float64[], Int[], Float64[], totalcost, niters, converged)
230-
end
231-
232-
"""
233-
update_centroids!(centroids, containers, alg, design_matrix, n_threads)
234-
235-
Internal function, used to update centroids by utilizing one of `alg`. It works as
236-
a wrapper of internal `chunk_update_centroids!` function, splitting incoming
237-
`design_matrix` in chunks and combining results together.
238-
"""
239-
function update_centroids!(centroids, containers, alg, design_matrix, n_threads)
240-
ncol = size(design_matrix, 2)
241-
242-
if n_threads == 1
243-
r = axes(design_matrix, 2)
244-
J = chunk_update_centroids!(centroids, containers, alg, design_matrix, r, 1)
245-
246-
centroids .= containers.new_centroids[1] ./ containers.centroids_cnt[1]'
247-
else
248-
ranges = splitter(ncol, n_threads)
249-
250-
waiting_list = Vector{Task}(undef, n_threads - 1)
251-
252-
for i in 1:length(ranges) - 1
253-
waiting_list[i] = @spawn chunk_update_centroids!(centroids, containers,
254-
alg, design_matrix, ranges[i], i + 1)
255-
end
256-
257-
J = chunk_update_centroids!(centroids, containers, alg, design_matrix, ranges[end], 1)
258-
259-
J += sum(fetch.(waiting_list))
260-
261-
for i in 1:length(ranges) - 1
262-
containers.new_centroids[1] .+= containers.new_centroids[i + 1]
263-
containers.centroids_cnt[1] .+= containers.centroids_cnt[i + 1]
264-
end
265-
266-
centroids .= containers.new_centroids[1] ./ containers.centroids_cnt[1]'
267-
end
268-
269-
return J
270-
end

src/mlj_interface.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ const MLJDICT = Dict(:Lloyd => Lloyd(),
1111
#### MODEL DEFINITION
1212
####
1313

14-
mutable struct KMeans <: MLJModelInterface.Unsupervised
14+
mutable struct KMeans <: MMI.Unsupervised
1515
algo::Symbol
1616
k_init::String
1717
k::Int
@@ -29,13 +29,13 @@ function KMeans(; algo=:Hamerly, k_init="k-means++",
2929
threads=Threads.nthreads(), verbosity=0, init=nothing)
3030

3131
model = KMeans(algo, k_init, k, tol, max_iters, copy, threads, verbosity, init)
32-
message = MLJModelInterface.clean!(model)
32+
message = MMI.clean!(model)
3333
isempty(message) || @warn message
3434
return model
3535
end
3636

3737

38-
function MLJModelInterface.clean!(m::KMeans)
38+
function MMI.clean!(m::KMeans)
3939
warning = ""
4040

4141
if !(m.algo keys(MLJDICT))
@@ -78,14 +78,14 @@ end
7878
7979
See also the [package documentation](https://pydatablog.github.io/ParallelKMeans.jl/stable).
8080
"""
81-
function MLJModelInterface.fit(m::KMeans, X)
81+
function MMI.fit(m::KMeans, X)
8282
# convert tabular input data into the matrix model expects. Column assumed as features so input data is permuted
8383
if !m.copy
8484
# permutes dimensions of input table without copying and pass to model
85-
DMatrix = convert(Array{Float64, 2}, MLJModelInterface.matrix(X)')
85+
DMatrix = convert(Array{Float64, 2}, MMI.matrix(X)')
8686
else
8787
# permutes dimensions of input table as a column major matrix from a copy of the data
88-
DMatrix = convert(Array{Float64, 2}, MLJModelInterface.matrix(X, transpose=true))
88+
DMatrix = convert(Array{Float64, 2}, MMI.matrix(X, transpose=true))
8989
end
9090

9191
# lookup available algorithms
@@ -106,7 +106,7 @@ function MLJModelInterface.fit(m::KMeans, X)
106106
end
107107

108108

109-
function MLJModelInterface.fitted_params(model::KMeans, fitresult)
109+
function MMI.fitted_params(model::KMeans, fitresult)
110110
# extract what's relevant from `fitresult`
111111
results, _, _ = fitresult # unpack fitresult
112112
centers = results.centers
@@ -124,15 +124,15 @@ end
124124
#### PREDICT FUNCTION
125125
####
126126

127-
function MLJModelInterface.transform(m::KMeans, fitresult, Xnew)
127+
function MMI.transform(m::KMeans, fitresult, Xnew)
128128
# make predictions/assignments using the learned centroids
129129

130130
if !m.copy
131131
# permutes dimensions of input table without copying and pass to model
132-
DMatrix = convert(Array{Float64, 2}, MLJModelInterface.matrix(Xnew)')
132+
DMatrix = convert(Array{Float64, 2}, MMI.matrix(Xnew)')
133133
else
134134
# permutes dimensions of input table as a column major matrix from a copy of the data
135-
DMatrix = convert(Array{Float64, 2}, MLJModelInterface.matrix(Xnew, transpose=true))
135+
DMatrix = convert(Array{Float64, 2}, MMI.matrix(Xnew, transpose=true))
136136
end
137137

138138
# TODO: Warn users if fitresult is from a `non-converged` fit?
@@ -147,7 +147,7 @@ function MLJModelInterface.transform(m::KMeans, fitresult, Xnew)
147147
centroids = results.centers
148148
distances = Distances.pairwise(Distances.SqEuclidean(), DMatrix, centroids; dims=2)
149149
preds = argmin.(eachrow(distances))
150-
return MLJModelInterface.table(reshape(preds, :, 1), prototype=Xnew)
150+
return MMI.table(reshape(preds, :, 1), prototype=Xnew)
151151
end
152152

153153

@@ -156,7 +156,7 @@ end
156156
####
157157

158158
# TODO 4: metadata for the package and for each of the model interfaces
159-
metadata_pkg.(KMeans,
159+
MMI.metadata_pkg.(KMeans,
160160
name = "ParallelKMeans",
161161
uuid = "42b8e9d4-006b-409a-8472-7f34b3fb58af",
162162
url = "https://github.com/PyDataBlog/ParallelKMeans.jl",
@@ -166,9 +166,9 @@ metadata_pkg.(KMeans,
166166

167167

168168
# Metadata for ParaKMeans model interface
169-
metadata_model(KMeans,
170-
input = MLJModelInterface.Table(MLJModelInterface.Continuous),
171-
output = MLJModelInterface.Table(MLJModelInterface.Count),
169+
MMI.metadata_model(KMeans,
170+
input = MMI.Table(MMI.Continuous),
171+
output = MMI.Table(MMI.Count),
172172
weights = false,
173173
descr = ParallelKMeans_Desc,
174174
path = "ParallelKMeans.KMeans")

test/test02_kmeans.jl renamed to test/test02_lloyd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
module TestKMeans
1+
module TestLloyd
22

33
using ParallelKMeans
44
using Test

0 commit comments

Comments
 (0)