Skip to content

Commit 492fc6b

Browse files
author
Andrey Oskin
committed
Removed LightElkan and updated MLJ
1 parent d9e06e2 commit 492fc6b

File tree

9 files changed

+42
-309
lines changed

9 files changed

+42
-309
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ParallelKMeans"
22
uuid = "42b8e9d4-006b-409a-8472-7f34b3fb58af"
33
authors = ["Bernard Brenyah", "Andrey Oskin"]
4-
version = "0.1.1"
4+
version = "0.1.2"
55

66
[deps]
77
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"

src/ParallelKMeans.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,12 @@ import Distances
77

88
include("seeding.jl")
99
include("kmeans.jl")
10-
include("light_elkan.jl")
1110
include("lloyd.jl")
1211
include("hamerly.jl")
1312
include("elkan.jl")
1413
include("mlj_interface.jl")
1514

1615
export kmeans
17-
export Lloyd, LightElkan, Hamerly, Elkan
16+
export Lloyd, Hamerly, Elkan
1817

1918
end # module

src/elkan.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,12 @@ function kmeans!(alg::Elkan, containers, X, k;
2929
@parallelize n_threads ncol chunk_initialize(alg, containers, centroids, X)
3030

3131
converged = false
32-
niters = 1
32+
niters = 0
3333
J_previous = 0.0
3434

3535
# Update centroids & labels with closest members until convergence
36-
while niters <= max_iters
36+
while niters < max_iters
37+
niters += 1
3738
# Core iteration
3839
@parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X)
3940

@@ -66,7 +67,6 @@ function kmeans!(alg::Elkan, containers, X, k;
6667
# Step 1 in original paper, calculation of distance d(c, c')
6768
update_containers(alg, containers, centroids, n_threads)
6869
J_previous = J
69-
niters += 1
7070
end
7171

7272
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids)

src/hamerly.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ function kmeans!(alg::Hamerly, containers, X, k;
5757
end
5858

5959
J_previous = J
60-
6160
end
6261

6362
@parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids)

src/light_elkan.jl

Lines changed: 0 additions & 191 deletions
This file was deleted.

src/lloyd.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ function chunk_update_centroids(::Lloyd, containers, centroids, X, r, idx)
125125
containers.J[idx] = J
126126
end
127127

128-
function collect_containers(alg::T, containers, centroids, n_threads) where {T <: Union{LightElkan, Lloyd}}
128+
function collect_containers(alg::Lloyd, containers, centroids, n_threads)
129129
if n_threads == 1
130130
@inbounds centroids .= containers.centroids_new[1] ./ containers.centroids_cnt[1]'
131131
else

src/mlj_interface.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ const ParallelKMeans_Desc = "Parallel & lightning fast implementation of all ava
55
# available variants for reference
66
const MLJDICT = Dict(:Lloyd => Lloyd(),
77
:Hamerly => Hamerly(),
8-
:LightElkan => LightElkan())
8+
:Elkan => Elkan())
99

1010
####
1111
#### MODEL DEFINITION
@@ -24,7 +24,7 @@ mutable struct KMeans <: MLJModelInterface.Unsupervised
2424
end
2525

2626

27-
function KMeans(; algo=:Lloyd, k_init="k-means++",
27+
function KMeans(; algo=:Hamerly, k_init="k-means++",
2828
k=3, tol=1e-6, max_iters=300, copy=true,
2929
threads=Threads.nthreads(), verbosity=0, init=nothing)
3030

@@ -39,12 +39,12 @@ function MLJModelInterface.clean!(m::KMeans)
3939
warning = ""
4040

4141
if !(m.algo ∈ keys(MLJDICT))
42-
warning *= "Unsupported KMeans variant, Defauting to KMeans++ seeding algorithm."
43-
m.algo = :Lloyd
42+
warning *= "Unsupported KMeans variant, Defaulting to Hamerly algorithm."
43+
m.algo = :Hamerly
4444

4545
elseif m.k_init != "k-means++"
46-
warning *= "Only `k-means++` or random seeding algorithms are supported. Defaulting to random seeding."
47-
m.k_init = "random"
46+
warning *= "Only `k-means++` or random seeding algorithms are supported. Defaulting to k-means++ seeding."
47+
m.k_init = "kmeans++"
4848

4949
elseif m.k < 1
5050
warning *= "Number of clusters must be greater than 0. Defaulting to 3 clusters."
@@ -63,8 +63,8 @@ function MLJModelInterface.clean!(m::KMeans)
6363
m.threads = Threads.nthreads()
6464

6565
elseif !(m.verbosity ∈ (0, 1))
66-
warning *= "Verbosity must be either 0 (no info) or 1 (info requested). Defaulting to 0."
67-
m.verbosity = 0
66+
warning *= "Verbosity must be either 0 (no info) or 1 (info requested). Defaulting to 1."
67+
m.verbosity = 1
6868
end
6969
return warning
7070
end

test/test04_elkan.jl

Lines changed: 1 addition & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,9 @@
11
module TestElkan
22

33
using ParallelKMeans
4-
using ParallelKMeans: update_containers
54
using Test
65
using Random
76

8-
@testset "centroid distances" begin
9-
containers = (centroids_dist = Matrix{Float64}(undef, 3, 3), )
10-
centroids = [1.0 2.0 4.0; 2.0 1.0 3.0]
11-
update_containers(LightElkan(), containers, centroids, 1)
12-
centroids_dist = containers.centroids_dist
13-
@test centroids_dist[1, 2] == centroids_dist[2, 1]
14-
@test centroids_dist[1, 3] == centroids_dist[3, 1]
15-
@test centroids_dist[2, 3] == centroids_dist[3, 2]
16-
@test centroids_dist[1, 2] == 0.5
17-
@test centroids_dist[1, 3] == 2.5
18-
@test centroids_dist[2, 3] == 2.0
19-
@test centroids_dist[1, 1] == 0.5
20-
@test centroids_dist[2, 2] == 0.5
21-
@test centroids_dist[3, 3] == 2.0
22-
end
23-
24-
@testset "basic kmeans light elkan" begin
25-
X = [1. 2. 4.;]
26-
res = kmeans(LightElkan(), X, 1; n_threads = 1, tol = 1e-6, verbose = false)
27-
@test res.assignments == [1, 1, 1]
28-
@test res.centers[1] ≈ 2.3333333333333335
29-
@test res.totalcost ≈ 4.666666666666666
30-
@test res.converged
31-
32-
res = kmeans(LightElkan(), X, 2; n_threads = 1, init = [1.0 4.0], tol = 1e-6, verbose = false)
33-
@test res.assignments == [1, 1, 2]
34-
@test res.centers ≈ [1.5 4.0]
35-
@test res.totalcost ≈ 0.5
36-
@test res.converged
37-
end
38-
39-
@testset "no convergence yield last result" begin
40-
X = [1. 2. 4.;]
41-
res = kmeans(LightElkan(), X, 2; n_threads = 1, init = [1.0 4.0], tol = 1e-6, max_iters = 1, verbose = false)
42-
@test !res.converged
43-
@test res.totalcost ≈ 0.5
44-
end
45-
46-
@testset "singlethread linear separation" begin
47-
Random.seed!(2020)
48-
49-
X = rand(3, 100)
50-
res = kmeans(LightElkan(), X, 3; n_threads = 1, tol = 1e-6, verbose = false)
51-
52-
@test res.totalcost ≈ 14.16198704459199
53-
@test res.converged
54-
@test res.iterations == 11
55-
end
56-
57-
@testset "multithread linear separation quasi two threads" begin
58-
Random.seed!(2020)
59-
60-
X = rand(3, 100)
61-
res = kmeans(LightElkan(), X, 3; n_threads = 2, tol = 1e-6, verbose = false)
62-
63-
@test res.totalcost ≈ 14.16198704459199
64-
@test res.converged
65-
end
66-
677
@testset "basic kmeans elkan" begin
688
X = [1. 2. 4.;]
699
res = kmeans(Elkan(), X, 1; n_threads = 1, tol = 1e-6, verbose = false)
@@ -94,7 +34,7 @@ end
9434

9535
@test res.totalcost ≈ 14.16198704459199
9636
@test !res.converged
97-
@test res.iterations == 11
37+
@test res.iterations == 10
9838
end
9939

10040
@testset "elkan multithread linear separation quasi two threads" begin

0 commit comments

Comments
 (0)