Skip to content

Commit 8a912ab

Browse files
authored
Merge pull request #756 from lucidfrontier45/trie-improve
Improvements of Trie Structure
2 parents 84ed83a + ff24088 commit 8a912ab

File tree

4 files changed

+66
-4
lines changed

4 files changed

+66
-4
lines changed

docs/src/trie.md

+7
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,10 @@ given string, use:
3131
```julia
3232
seen_prefix(t::Trie, str) = any(v -> v.is_key, partial_path(t, str))
3333
```
34+
35+
`find_prefixes` can be used to find all keys which are prefixes of the given string.
36+
37+
```julia
38+
t = Trie(["A", "ABC", "ABCD", "BCE"])
39+
find_prefixes(t, "ABCDE") # "A", "ABC", "ABCD"
40+
```

src/DataStructures.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ module DataStructures
2828
export heapify!, heapify, heappop!, heappush!, isheap
2929
export BinaryMinMaxHeap, popmin!, popmax!, popall!
3030

31-
export Trie, subtrie, keys_with_prefix, partial_path
31+
export Trie, subtrie, keys_with_prefix, partial_path, find_prefixes
3232

3333
export LinkedList, Nil, Cons, nil, cons, head, tail, list, filter, cat,
3434
reverse

src/trie.jl

+32-3
Original file line numberDiff line numberDiff line change
@@ -111,14 +111,43 @@ end
111111
# since the root of the trie corresponds to a length 0 prefix of str.
112112
function Base.iterate(it::TrieIterator, (t, i) = (it.t, 0))
    # State is `(node, i)`: the trie node reached so far and the index in `it.str` of
    # the next character to consume. `i == 0` marks the very first call.
    if i == 0
        # Emit the root first, then point the state at the first character.
        return it.t, (it.t, firstindex(it.str))
    elseif i > lastindex(it.str) || !(it.str[i] in keys(t.children))
        # Stop when the string is exhausted or no child matches the next character.
        return nothing
    else
        t = t.children[it.str[i]]
        # Use `nextind` rather than `i + 1`: string indices are not consecutive for
        # multi-byte characters, so `i + 1` may not be a valid character index.
        return (t, (t, nextind(it.str, i)))
    end
end
122122

123123
# Build an iterator over the trie nodes reached while following `str` from `t`,
# starting with `t` itself; iteration ends at the first character with no matching child.
partial_path(t::Trie, str::AbstractString) = TrieIterator(t, str)
124124
# The number of nodes yielded depends on how much of the string matches the trie,
# so the length is not known up front.
Base.IteratorSize(::Type{TrieIterator}) = Base.SizeUnknown()
125+
"""
    find_prefixes(t::Trie, str::AbstractString)

Return every key stored in `t` that is a prefix of `str`, ordered from shortest to
longest.

# Examples
```julia-repl
julia> t = Trie(["A", "ABC", "ABCD", "BCE"])

julia> find_prefixes(t, "ABCDE")
3-element Vector{AbstractString}:
 "A"
 "ABC"
 "ABCD"
```
"""
function find_prefixes(t::Trie, str::AbstractString)
    matches = AbstractString[]
    # `stop` is the index of the last character covered by the current node;
    # 0 stands for the empty prefix that belongs to the root.
    stop = 0
    for node in partial_path(t, str)
        node.is_key && push!(matches, str[firstindex(str):stop])
        stop = nextind(str, stop)
    end
    return matches
end

test/test_trie.jl

+26
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,30 @@
4040
@test collect(partial_path(t, "ro")) == [t0, t1, t2]
4141
@test collect(partial_path(t, "roa")) == [t0, t1, t2]
4242
end
43+
44+
@testset "partial_path iterator non-ascii" begin
    # Multi-byte characters: valid string indices are not consecutive integers here,
    # so this exercises the index stepping of the iterator.
    t = Trie(["東京都"])
    t0 = t
    t1 = t0.children['東']
    t2 = t1.children['京']
    t3 = t2.children['都']
    @test collect(partial_path(t, "西")) == [t0]
    @test collect(partial_path(t, "東京都")) == [t0, t1, t2, t3]
    @test collect(partial_path(t, "東京都渋谷区")) == [t0, t1, t2, t3]
    @test collect(partial_path(t, "東京")) == [t0, t1, t2]
    @test collect(partial_path(t, "東京スカイツリー")) == [t0, t1, t2]
end
56+
57+
@testset "find_prefixes" begin
    # "ABD" and "BCD" are keys but not prefixes of the query, so they must not appear.
    trie = Trie(["A", "ABC", "ABD", "BCD"])
    @test find_prefixes(trie, "ABCDE") == ["A", "ABC"]
end
62+
63+
@testset "find_prefixes non-ascii" begin
    # Same check with multi-byte characters; "東京都新宿区" diverges from the query
    # after the shared prefix and must be excluded.
    trie = Trie(["東京都", "東京都渋谷区", "東京都新宿区"])
    @test find_prefixes(trie, "東京都渋谷区東") == ["東京都", "東京都渋谷区"]
end
68+
4369
end # @testset Trie

0 commit comments

Comments
 (0)