Skip to content

Commit 4e3c06f

Browse files
authored
Merge pull request #28 from Yuan-Ru-Lin/add-affinity-propagation
Initial commit for implementation of Affinity Propagation
2 parents 0be6b1a + da6e69c commit 4e3c06f

File tree

4 files changed

+175
-6
lines changed

4 files changed

+175
-6
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
fail-fast: false
1616
matrix:
1717
version:
18-
- '1.6'
18+
- '1.10'
1919
- '1'
2020
os:
2121
- ubuntu-latest

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,14 @@ version = "0.1.11"
66
[deps]
77
Clustering = "aaaa29a8-35af-508c-8bc3-b662a17a0fe5"
88
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
9+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
910
MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
11+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1012

1113
[compat]
1214
Clustering = "0.15"
1315
Distances = "0.9, 0.10"
16+
LinearAlgebra = "1"
1417
MLJModelInterface = "1.4"
15-
julia = "1.6"
18+
StatsBase = "0.34"
19+
julia = "1.10"

src/MLJClusteringInterface.jl

Lines changed: 136 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@ import MLJModelInterface: Continuous, Count, Finite, Multiclass, Table, OrderedF
1313
@mlj_model, metadata_model, metadata_pkg
1414

1515
using Distances
16+
using LinearAlgebra
17+
using StatsBase
1618

1719
# ===================================================================
1820
## EXPORTS
19-
export KMeans, KMedoids, DBSCAN, HierarchicalClustering
21+
export KMeans, KMedoids, AffinityPropagation, DBSCAN, HierarchicalClustering
2022

2123
# ===================================================================
2224
## CONSTANTS
@@ -95,7 +97,6 @@ function MMI.transform(model::KMedoids, fitresult, X)
9597
return MMI.table(X̃, prototype=X)
9698
end
9799

98-
99100
# # PREDICT FOR K_MEANS AND K_MEDOIDS
100101

101102
function MMI.predict(model::Union{KMeans,KMedoids}, fitresult, Xnew)
@@ -208,10 +209,66 @@ end
208209

209210
MMI.reporting_operations(::Type{<:HierarchicalClustering}) = (:predict,)
210211

212+
# # AFFINITY_PROPAGATION
213+
214+
@mlj_model mutable struct AffinityPropagation <: MMI.Static
215+
damp::Float64 = 0.5::(0.0 ≤ _ < 1.0)
216+
maxiter::Int = 200::(_ > 0)
217+
tol::Float64 = 1e-6::(_ > 0)
218+
preference::Union{Nothing,Float64} = nothing
219+
metric::SemiMetric = SqEuclidean()
220+
end
221+
222+
function MMI.predict(model::AffinityPropagation, ::Nothing, X)
223+
Xarray = MMI.matrix(X)'
224+
225+
# Compute similarity matrix using negative pairwise distances
226+
S = -pairwise(model.metric, Xarray, dims=2)
227+
228+
diagonal_element = if !isnothing(model.preference)
229+
model.preference
230+
else
231+
# Get the median out of all pairs of similarity, that is, values above
232+
# the diagonal line.
233+
# Such default choice is mentioned in the algorithm's wiki article
234+
iuppertri = triu!(trues(size(S)),1)
235+
median(S[iuppertri])
236+
end
237+
238+
fill!(view(S, diagind(S)), diagonal_element)
239+
240+
result = Cl.affinityprop(
241+
S,
242+
maxiter=model.maxiter,
243+
tol=model.tol,
244+
damp=model.damp
245+
)
246+
247+
# Get number of clusters and labels
248+
exemplars = result.exemplars
249+
k = length(exemplars)
250+
cluster_labels = MMI.categorical(1:k)
251+
252+
# Store exemplar points as centers (similar to KMeans/KMedoids)
253+
centers = view(Xarray, :, exemplars)
254+
255+
report = (
256+
exemplars=exemplars,
257+
centers=centers,
258+
cluster_labels=cluster_labels,
259+
iterations=result.iterations,
260+
converged=result.converged
261+
)
262+
263+
return MMI.categorical(result.assignments), report
264+
end
265+
266+
MMI.reporting_operations(::Type{<:AffinityPropagation}) = (:predict,)
267+
211268
# # METADATA
212269

213270
metadata_pkg.(
214-
(KMeans, KMedoids, DBSCAN, HierarchicalClustering),
271+
(KMeans, KMedoids, DBSCAN, HierarchicalClustering, AffinityPropagation),
215272
name="Clustering",
216273
uuid="aaaa29a8-35af-508c-8bc3-b662a17a0fe5",
217274
url="https://github.com/JuliaStats/Clustering.jl",
@@ -251,6 +308,13 @@ metadata_model(
251308
path = "$(PKG).HierarchicalClustering"
252309
)
253310

311+
metadata_model(
312+
AffinityPropagation,
313+
human_name = "Affinity Propagation clusterer",
314+
input_scitype = MMI.Table(Continuous),
315+
path = "$(PKG).AffinityPropagation"
316+
)
317+
254318
"""
255319
$(MMI.doc_header(KMeans))
256320
@@ -618,4 +682,73 @@ report(mach).cutter(h = 2.5)
618682
"""
619683
HierarchicalClustering
620684

685+
"""
686+
$(MMI.doc_header(AffinityPropagation))
687+
688+
[Affinity Propagation](https://en.wikipedia.org/wiki/Affinity_propagation) is a clustering algorithm based on the concept of "message passing" between data points. More information is available at the [Clustering.jl documentation](https://juliastats.org/Clustering.jl/stable/index.html). Use `predict` to get cluster assignments. Indices of the exemplars, their values, etc, are accessed from the machine report (see below).
689+
690+
This is a static implementation, i.e., it does not generalize to new data instances, and
691+
there is no training data. For clusterers that do generalize, see [`KMeans`](@ref) or
692+
[`KMedoids`](@ref).
693+
694+
In MLJ or MLJBase, create a machine with
695+
696+
mach = machine(model)
697+
698+
# Hyper-parameters
699+
700+
- `damp = 0.5`: damping factor
701+
702+
- `maxiter = 200`: maximum number of iteration
703+
704+
- `tol = 1e-6`: tolerance for converenge
705+
706+
- `preference = nothing`: the (single float) value of the diagonal elements of the similarity matrix. If unspecified, choose median (negative) similarity of all pairs as mentioned [here](https://en.wikipedia.org/wiki/Affinity_propagation#Algorithm)
707+
708+
- `metric = Distances.SqEuclidean()`: metric (see `Distances.jl` for available metrics)
709+
710+
# Operations
711+
712+
- `predict(mach, X)`: return cluster label assignments, as an unordered
713+
`CategoricalVector`. Here `X` is any table of input features (eg, a `DataFrame`) whose
714+
columns are of scitype `Continuous`; check column scitypes with `schema(X)`.
715+
716+
# Report
717+
718+
After calling `predict(mach)`, the fields of `report(mach)` are:
719+
720+
- exemplars: indices of the data picked as exemplars in `X`
721+
722+
- centers: positions of the exemplars in the feature space
723+
724+
- cluster_labels: labels of clusters given to each datum in `X`
725+
726+
- iterations: the number of iteration run by the algorithm
727+
728+
- converged: whether or not the algorithm converges by the maximum iteration
729+
730+
# Examples
731+
732+
```
733+
using MLJ
734+
735+
X, labels = make_moons(400, noise=0.9, rng=1)
736+
737+
AffinityPropagation = @load AffinityPropagation pkg=Clustering
738+
model = AffinityPropagation(preference=-10.0)
739+
mach = machine(model)
740+
741+
# compute and output cluster assignments for observations in `X`:
742+
yhat = predict(mach, X)
743+
744+
# Get the positions of the exemplars
745+
report(mach).centers
746+
747+
# Plot clustering result
748+
using GLMakie
749+
scatter(MLJ.matrix(X)', color=yhat.refs)
750+
```
751+
"""
752+
AffinityPropagation
753+
621754
end # module

test/runtests.jl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,40 @@ end
150150
@test report(mach).dendrogram.heights == dendro.heights
151151
end
152152

153+
# # AffinityPropagation
154+
155+
@testset "AffinityPropagation" begin
156+
X = table(stack(Iterators.partition(0.5:0.5:20, 5))')
157+
158+
# Test case 1: preference == median (negative) similarity (i.e. unspecified)
159+
mach = machine(AffinityPropagation())
160+
161+
yhat = predict(mach, X)
162+
@test yhat == [1, 1, 1, 1, 2, 2, 2, 2]
163+
164+
_report = report(mach)
165+
@test _report.exemplars == [2, 7]
166+
@test _report.centers == [3.0 15.5; 3.5 16.0; 4.0 16.5; 4.5 17.0; 5.0 17.5]
167+
@test _report.cluster_labels == [1, 2]
168+
@test _report.iterations == 50
169+
@test _report.converged == true
170+
171+
# Test case 2: |preference| too large
172+
mach2 = machine(AffinityPropagation(preference=-20.0))
173+
174+
yhat = predict(mach2, X)
175+
@test yhat == [1, 2, 3, 4, 5, 6, 7, 8]
176+
177+
_report = report(mach2)
178+
@test _report.exemplars == [1, 2, 3, 4, 5, 6, 7, 8]
179+
@test _report.centers == matrix(X)'
180+
@test _report.cluster_labels == [1, 2, 3, 4, 5, 6, 7, 8]
181+
@test _report.iterations == 32
182+
@test _report.converged == true
183+
end
184+
153185
@testset "MLJ interface" begin
154-
models = [KMeans, KMedoids, DBSCAN, HierarchicalClustering]
186+
models = [KMeans, KMedoids, DBSCAN, HierarchicalClustering, AffinityPropagation]
155187
failures, summary = MLJTestInterface.test(
156188
models,
157189
X;

0 commit comments

Comments
 (0)