Skip to content

Commit 1bb3db8

Browse files
authored
Add KMedoids (#298)
* Initial implementation of KMedoids * Add basic test for KMedoids * Add KMedoids to docs * Add more tests for KMedoids * Use existing _nrows utility * Use _assert utility function * Retrieve distance type * Minor adjustments
1 parent 1de5b52 commit 1bb3db8

File tree

7 files changed

+181
-2
lines changed

7 files changed

+181
-2
lines changed

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
1818
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1919
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2020
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
21+
TableDistances = "e5d66e97-8c70-46bb-8b66-04a2d73ad782"
2122
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
2223
TransformsBase = "28dd2a49-a57a-4bfb-84ca-1a49db9b96b8"
2324
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
@@ -37,6 +38,7 @@ PrettyTables = "2"
3738
Random = "1.9"
3839
Statistics = "1.9"
3940
StatsBase = "0.33, 0.34"
41+
TableDistances = "1.0"
4042
Tables = "1.6"
4143
TransformsBase = "1.5"
4244
Unitful = "1.17"

docs/src/transforms.md

+6
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,12 @@ SDS
242242
ProjectionPursuit
243243
```
244244

245+
## KMedoids
246+
247+
```@docs
248+
KMedoids
249+
```
250+
245251
## Closure
246252

247253
```@docs

src/TableTransforms.jl

+4-2
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@ module TableTransforms
66

77
using Tables
88
using Unitful
9-
using Statistics
109
using PrettyTables
1110
using AbstractTrees
12-
using LinearAlgebra
11+
using TableDistances
1312
using DataScienceTraits
1413
using CategoricalArrays
14+
using LinearAlgebra
15+
using Statistics
1516
using Random
1617
using CoDa
1718

@@ -90,6 +91,7 @@ export
9091
DRS,
9192
SDS,
9293
ProjectionPursuit,
94+
KMedoids,
9395
Closure,
9496
Remainder,
9597
Compose,

src/transforms.jl

+1
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ include("transforms/quantile.jl")
286286
include("transforms/functional.jl")
287287
include("transforms/eigenanalysis.jl")
288288
include("transforms/projectionpursuit.jl")
289+
include("transforms/kmedoids.jl")
289290
include("transforms/closure.jl")
290291
include("transforms/remainder.jl")
291292
include("transforms/compose.jl")

src/transforms/kmedoids.jl

+142
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
# ------------------------------------------------------------------
2+
# Licensed under the MIT License. See LICENSE in the project root.
3+
# ------------------------------------------------------------------
4+
5+
"""
6+
KMedoids(k; tol=1e-4, maxiter=10, weights=nothing, rng=Random.default_rng())
7+
8+
Assign labels to rows of table using the `k`-medoids algorithm.
9+
10+
The iterative algorithm is interrupted if the relative change on
11+
the average distance to medoids is smaller than a tolerance `tol`
12+
or if the number of iterations exceeds the maximum number of
13+
iterations `maxiter`.
14+
15+
Optionally, specify a dictionary of `weights` for each column to
16+
affect the underlying table distance from TableDistances.jl, and
17+
a random number generator `rng` to obtain reproducible results.
18+
19+
## Examples
20+
21+
```julia
22+
KMedoids(3)
23+
KMedoids(4, maxiter=20)
24+
KMedoids(5, weights=Dict(:col1 => 1.0, :col2 => 2.0))
25+
```
26+
27+
## References
28+
29+
* Kaufman, L. & Rousseeuw, P. J. 1990. [Partitioning Around Medoids (Program PAM)]
30+
(https://onlinelibrary.wiley.com/doi/10.1002/9780470316801.ch2)
31+
32+
* Kaufman, L. & Rousseeuw, P. J. 1991. [Finding Groups in Data: An Introduction to Cluster Analysis]
33+
(https://www.jstor.org/stable/2532178)
34+
"""
35+
struct KMedoids{W,RNG} <: StatelessFeatureTransform
36+
k::Int
37+
tol::Float64
38+
maxiter::Int
39+
weights::W
40+
rng::RNG
41+
end
42+
43+
function KMedoids(k; tol=1e-4, maxiter=10, weights=nothing, rng=Random.default_rng())
44+
# sanity checks
45+
_assert(k > 0, "number of clusters must be positive")
46+
_assert(tol > 0, "tolerance on relative change must be positive")
47+
_assert(maxiter > 0, "maximum number of iterations must be positive")
48+
KMedoids(k, tol, maxiter, weights, rng)
49+
end
50+
51+
parameters(transform::KMedoids) = (; k=transform.k)
52+
53+
function applyfeat(transform::KMedoids, feat, prep)
54+
# retrieve parameters
55+
k = transform.k
56+
tol = transform.tol
57+
maxiter = transform.maxiter
58+
weights = transform.weights
59+
rng = transform.rng
60+
61+
# number of observations
62+
nobs = _nrows(feat)
63+
64+
# sanity checks
65+
k > nobs && throw(ArgumentError("requested number of clusters > number of observations"))
66+
67+
# normalize variables
68+
stdfeat = feat |> StdFeats()
69+
70+
# define table distance
71+
td = TableDistance(normalize=false, weights=weights)
72+
73+
# initialize medoids
74+
medoids = sample(rng, 1:nobs, k, replace=false)
75+
76+
# retrieve distance type
77+
s = Tables.subset(stdfeat, 1:1)
78+
D = eltype(pairwise(td, s))
79+
80+
# pre-allocate memory for labels and distances
81+
labels = fill(0, nobs)
82+
dists = fill(typemax(D), nobs)
83+
84+
# main loop
85+
iter = 0
86+
δcur = mean(dists)
87+
while iter < maxiter
88+
# update labels and medoids
89+
_updatelabels!(td, stdfeat, medoids, labels, dists)
90+
_updatemedoids!(td, stdfeat, medoids, labels)
91+
92+
# average distance to medoids
93+
δnew = mean(dists)
94+
95+
# break upon convergence
96+
abs(δnew - δcur) / δcur < tol && break
97+
98+
# update and continue
99+
δcur = δnew
100+
iter += 1
101+
end
102+
103+
newfeat = (; cluster=labels) |> Tables.materializer(feat)
104+
105+
newfeat, nothing
106+
end
107+
108+
function _updatelabels!(td, table, medoids, labels, dists)
109+
for (k, mₖ) in enumerate(medoids)
110+
inds = 1:_nrows(table)
111+
112+
X = Tables.subset(table, inds)
113+
μ = Tables.subset(table, [mₖ])
114+
115+
δ = pairwise(td, X, μ)
116+
117+
@inbounds for i in inds
118+
if δ[i] < dists[i]
119+
dists[i] = δ[i]
120+
labels[i] = k
121+
end
122+
end
123+
end
124+
end
125+
126+
function _updatemedoids!(td, table, medoids, labels)
127+
for k in eachindex(medoids)
128+
inds = findall(isequal(k), labels)
129+
130+
X = Tables.subset(table, inds)
131+
132+
j = _medoid(td, X)
133+
134+
@inbounds medoids[k] = inds[j]
135+
end
136+
end
137+
138+
function _medoid(td, table)
139+
Δ = pairwise(td, table)
140+
_, j = findmin(sum, eachcol(Δ))
141+
j
142+
end

test/transforms.jl

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ transformfiles = [
3131
"functional.jl",
3232
"eigenanalysis.jl",
3333
"projectionpursuit.jl",
34+
"kmedoids.jl",
3435
"closure.jl",
3536
"remainder.jl",
3637
"compose.jl",

test/transforms/kmedoids.jl

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
@testset "KMedoids" begin
2+
@test !isrevertible(KMedoids(3))
3+
4+
@test TT.parameters(KMedoids(3)) == (; k=3)
5+
6+
# basic test with continuous variables
7+
a = [randn(100); 10 .+ randn(100)]
8+
b = [randn(100); 10 .+ randn(100)]
9+
t = Table(; a, b)
10+
n = t |> KMedoids(2; rng)
11+
i1 = findall(isequal(1), n.cluster)
12+
i2 = findall(isequal(2), n.cluster)
13+
@test mean(t.a[i1]) > 5
14+
@test mean(t.b[i1]) > 5
15+
@test mean(t.a[i2]) < 5
16+
@test mean(t.b[i2]) < 5
17+
18+
# test with mixed variables
19+
a = [1, 2, 3]
20+
b = [1.0, 2.0, 3.0]
21+
c = ["a", "b", "c"]
22+
t = Table(; a, b, c)
23+
n = t |> KMedoids(3; rng)
24+
@test sort(n.cluster) == [1, 2, 3]
25+
end

0 commit comments

Comments
 (0)