Skip to content

Commit da28402

Browse files
committed
Update disjoint set module to include non-integer terms
1 parent f07317e commit da28402

File tree

2 files changed

+50
-7
lines changed

2 files changed

+50
-7
lines changed

lib/algorithms/disjoint_set.ex

+49-7
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,20 @@ defmodule AdventOfCode.Algorithms.DisjointSet do
2828
}
2929
3030
"""
31-
@spec new(non_neg_integer()) :: t()
31+
@spec new(non_neg_integer() | List.t()) :: t()
3232
def new(0), do: %__MODULE__{}
3333

34-
def new(size) do
34+
def new(size) when is_integer(size) do
3535
%__MODULE__{
36-
ranks: 0..(size - 1) |> Enum.map(&{&1, 1}) |> Enum.into(%{}),
37-
parents: 0..(size - 1) |> Enum.map(&{&1, &1}) |> Enum.into(%{})
36+
ranks: 0..(size - 1) |> Map.new(&{&1, 1}),
37+
parents: 0..(size - 1) |> Map.new(&{&1, &1})
38+
}
39+
end
40+
41+
def new(lst) when is_list(lst) do
42+
%__MODULE__{
43+
ranks: lst |> Map.new(&{&1, 1}),
44+
parents: lst |> Map.new(&{&1, &1})
3845
}
3946
end
4047

@@ -97,7 +104,8 @@ defmodule AdventOfCode.Algorithms.DisjointSet do
97104
end
98105

99106
@doc """
100-
Performs a union between two elements and returns the updated set.
107+
Performs a union between two elements and returns the updated set. `:error` case is matched so that it fails
108+
in a piped flow.
101109
102110
## Example
103111
@@ -118,9 +126,12 @@ defmodule AdventOfCode.Algorithms.DisjointSet do
118126
iex> DisjointSet.new(1) |> DisjointSet.union(100, 200)
119127
:error
120128
129+
iex> DisjointSet.union(:error, 100, 200)
130+
:error
131+
121132
"""
122-
@spec union(t(), value(), value()) :: t() | :error
123-
def union(disjoint_set, a, b) do
133+
@spec union(t() | :error, value(), value()) :: t() | :error
134+
def union(%__MODULE__{} = disjoint_set, a, b) do
124135
with {root_a, disjoint_set} <- find(disjoint_set, a),
125136
{root_b, disjoint_set} <- find(disjoint_set, b) do
126137
union_by_rank(disjoint_set, root_a, root_b)
@@ -129,6 +140,37 @@ defmodule AdventOfCode.Algorithms.DisjointSet do
129140
end
130141
end
131142

143+
def union(:error, _, _), do: :error
144+
145+
@doc """
146+
Returns the connected components of a set of data. `:error` case is matched so that it fails
147+
in a piped flow.
148+
149+
## Example
150+
151+
iex> DisjointSet.new([{0, 0}, {0, 1}, {0, 2}, {10, 11}, {10, 12}, {100, 200}])
152+
...> |> DisjointSet.union({0, 0}, {0, 1})
153+
...> |> DisjointSet.union({0, 1}, {0, 2})
154+
...> |> DisjointSet.union({10, 11}, {10, 12})
155+
...> |> DisjointSet.components()
156+
[MapSet.new([{0, 0}, {0, 1}, {0, 2}]), MapSet.new([{10, 11}, {10, 12}]), MapSet.new([{100, 200}])]
157+
158+
iex> DisjointSet.new(10)
159+
...> |> DisjointSet.union(20, 30)
160+
...> |> DisjointSet.components()
161+
:error
162+
163+
"""
164+
@spec components(t() | :error) :: [[term()]]
165+
def components(%__MODULE__{parents: parents}) do
166+
parents
167+
|> Enum.group_by(&elem(&1, 1), fn {a, _} -> a end)
168+
|> Map.values()
169+
|> Enum.map(&Enum.into(&1, %MapSet{}))
170+
end
171+
172+
def components(:error), do: :error
173+
132174
defp union_by_rank(disjoint_set, parent, parent), do: disjoint_set
133175

134176
defp union_by_rank(%__MODULE__{ranks: ranks} = disjoint_set, root_a, root_b) do

test/algorithms/disjoint_set_test.exs

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
defmodule AdventOfCode.Algorithms.DisjointSetTest do
22
use ExUnit.Case, async: true
3+
@moduletag :algorithm_disjoint_set
34

45
alias AdventOfCode.Algorithms.DisjointSet
56

0 commit comments

Comments
 (0)