Skip to content

Commit ab9f959

Browse files
committed
Separate as() and into()
1 parent 058012e commit ab9f959

File tree

8 files changed

+165
-152
lines changed

8 files changed

+165
-152
lines changed

src/FunSQL.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ export
5555
funsql_group,
5656
funsql_highlight,
5757
funsql_in,
58+
funsql_into,
5859
funsql_iterate,
5960
funsql_is_not_null,
6061
funsql_is_null,

src/link.jl

Lines changed: 69 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -114,19 +114,15 @@ function dismantle(n::GroupNode, ctx)
114114
Group(over = over′, by = by′, sets = n.sets, name = n.name, label_map = n.label_map)
115115
end
116116

117-
function dismantle(n::IterateNode, ctx)
117+
function dismantle(n::IntoNode, ctx)
118118
over′ = dismantle(n.over, ctx)
119-
iterator′ = dismantle(n.iterator, ctx)
120-
Iterate(over = over′, iterator = iterator′)
119+
Into(over = over′, name = n.name)
121120
end
122121

123-
function dismantle(n::JoinNode, ctx)
124-
rt = row_type(n.joinee)
125-
router = JoinRouter(Set(keys(rt.fields)), !isa(rt.group, EmptyType))
122+
function dismantle(n::IterateNode, ctx)
126123
over′ = dismantle(n.over, ctx)
127-
joinee′ = dismantle(n.joinee, ctx)
128-
on′ = dismantle_scalar(n.on, ctx)
129-
RoutedJoin(over = over′, joinee = joinee′, on = on′, router = router, left = n.left, right = n.right, optional = n.optional)
124+
iterator′ = dismantle(n.iterator, ctx)
125+
Iterate(over = over′, iterator = iterator′)
130126
end
131127

132128
function dismantle(n::LimitNode, ctx)
@@ -172,6 +168,13 @@ function dismantle_scalar(n::ResolvedNode, ctx)
172168
end
173169
end
174170

171+
function dismantle(n::RoutedJoinNode, ctx)
172+
over′ = dismantle(n.over, ctx)
173+
joinee′ = dismantle(n.joinee, ctx)
174+
on′ = dismantle_scalar(n.on, ctx)
175+
RoutedJoin(over = over′, joinee = joinee′, on = on′, name = n.name, left = n.left, right = n.right, optional = n.optional)
176+
end
177+
175178
function dismantle(n::SelectNode, ctx)
176179
over′ = dismantle(n.over, ctx)
177180
args′ = dismantle_scalar(n.args, ctx)
@@ -219,16 +222,7 @@ function link(n::AppendNode, ctx)
219222
end
220223

221224
function link(n::AsNode, ctx)
222-
refs = SQLNode[]
223-
for ref in ctx.refs
224-
if @dissect(ref, over |> Nested(name = name))
225-
@assert name == n.name
226-
push!(refs, over)
227-
else
228-
error()
229-
end
230-
end
231-
over′ = link(n.over, ctx, refs)
225+
over′ = link(n.over, ctx)
232226
As(over = over′, name = n.name)
233227
end
234228

@@ -276,10 +270,8 @@ function link(n::FromIterateNode, ctx)
276270
end
277271

278272
function link(n::FromTableExpressionNode, ctx)
279-
refs = ctx.cte_refs[(n.name, n.depth)]
280-
for ref in ctx.refs
281-
push!(refs, Nested(over = ref, name = n.name))
282-
end
273+
cte_refs = ctx.cte_refs[(n.name, n.depth)]
274+
append!(cte_refs, ctx.refs)
283275
n
284276
end
285277

@@ -320,6 +312,20 @@ function link(n::GroupNode, ctx)
320312
Group(over = over′, by = n.by, sets = n.sets, name = n.name, label_map = n.label_map)
321313
end
322314

315+
function link(n::IntoNode, ctx)
316+
refs = SQLNode[]
317+
for ref in ctx.refs
318+
if @dissect(ref, over |> Nested(name = name))
319+
@assert name == n.name
320+
push!(refs, over)
321+
else
322+
error()
323+
end
324+
end
325+
over′ = link(n.over, ctx, refs)
326+
Into(over = over′, name = n.name)
327+
end
328+
323329
function link(n::IterateNode, ctx)
324330
iterator′ = n.iterator
325331
defs = copy(ctx.defs)
@@ -351,53 +357,6 @@ function link(n::IterateNode, ctx)
351357
Padding(over = n′)
352358
end
353359

354-
function route(r::JoinRouter, ref::SQLNode)
355-
if @dissect(ref, over |> Nested(name = name)) && name in r.label_set
356-
return 1
357-
end
358-
if @dissect(ref, Get(name = name)) && name in r.label_set
359-
return 1
360-
end
361-
if @dissect(ref, over |> Agg()) && r.group
362-
return 1
363-
end
364-
return -1
365-
end
366-
367-
function link(n::RoutedJoinNode, ctx)
368-
lrefs = SQLNode[]
369-
rrefs = SQLNode[]
370-
for ref in ctx.refs
371-
turn = route(n.router, ref)
372-
push!(turn < 0 ? lrefs : rrefs, ref)
373-
end
374-
if n.optional && isempty(rrefs)
375-
return link(n.over, ctx)
376-
end
377-
ln_ext_refs = length(lrefs)
378-
rn_ext_refs = length(rrefs)
379-
refs′ = SQLNode[]
380-
lateral_refs = SQLNode[]
381-
gather!(n.joinee, ctx, lateral_refs)
382-
append!(lrefs, lateral_refs)
383-
lateral = !isempty(lateral_refs)
384-
gather!(n.on, ctx, refs′)
385-
for ref in refs′
386-
turn = route(n.router, ref)
387-
push!(turn < 0 ? lrefs : rrefs, ref)
388-
end
389-
over′ = Linked(lrefs, ln_ext_refs, over = link(n.over, ctx, lrefs))
390-
joinee′ = Linked(rrefs, rn_ext_refs, over = link(n.joinee, ctx, rrefs))
391-
RoutedJoinNode(
392-
over = over′,
393-
joinee = joinee′,
394-
on = n.on,
395-
router = n.router,
396-
left = n.left,
397-
right = n.right,
398-
lateral = lateral)
399-
end
400-
401360
function link(n::LimitNode, ctx)
402361
over′ = Linked(ctx.refs, over = link(n.over, ctx))
403362
Limit(over = over′, offset = n.offset, limit = n.limit)
@@ -446,6 +405,46 @@ function link(n::PartitionNode, ctx)
446405
Partition(over = over′, by = n.by, order_by = n.order_by, frame = n.frame, name = n.name)
447406
end
448407

408+
function link(n::RoutedJoinNode, ctx)
409+
lrefs = SQLNode[]
410+
rrefs = SQLNode[]
411+
for ref in ctx.refs
412+
if @dissect(ref, over |> Nested(name = name)) && name === n.name
413+
push!(rrefs, ref)
414+
else
415+
push!(lrefs, ref)
416+
end
417+
end
418+
if n.optional && isempty(rrefs)
419+
return link(n.over, ctx)
420+
end
421+
ln_ext_refs = length(lrefs)
422+
rn_ext_refs = length(rrefs)
423+
refs′ = SQLNode[]
424+
lateral_refs = SQLNode[]
425+
gather!(n.joinee, ctx, lateral_refs)
426+
append!(lrefs, lateral_refs)
427+
lateral = !isempty(lateral_refs)
428+
gather!(n.on, ctx, refs′)
429+
for ref in refs′
430+
if @dissect(ref, over |> Nested(name = name)) && name === n.name
431+
push!(rrefs, ref)
432+
else
433+
push!(lrefs, ref)
434+
end
435+
end
436+
over′ = Linked(lrefs, ln_ext_refs, over = link(n.over, ctx, lrefs))
437+
joinee′ = Linked(rrefs, rn_ext_refs, over = link(Into(over = n.joinee, name = n.name), ctx, rrefs))
438+
RoutedJoinNode(
439+
over = over′,
440+
joinee = joinee′,
441+
on = n.on,
442+
name = n.name,
443+
left = n.left,
444+
right = n.right,
445+
lateral = lateral)
446+
end
447+
449448
function link(n::SelectNode, ctx)
450449
refs = SQLNode[]
451450
gather!(n.args, ctx, refs)

src/nodes.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,7 @@ include("nodes/get.jl")
696696
include("nodes/group.jl")
697697
include("nodes/highlight.jl")
698698
include("nodes/internal.jl")
699+
include("nodes/into.jl")
699700
include("nodes/iterate.jl")
700701
include("nodes/join.jl")
701702
include("nodes/limit.jl")

src/nodes/as.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,7 @@ AsNode(name; over = nothing) =
1818
As(name; over = nothing)
1919
name => over
2020
21-
In a scalar context, `As` specifies the name of the output column. When
22-
applied to tabular data, `As` wraps the data in a nested record.
21+
`As` specifies the name of the output column.
2322
2423
The arrow operator (`=>`) is a shorthand notation for `As`.
2524
@@ -37,19 +36,19 @@ SELECT "person_1"."person_id" AS "id"
3736
FROM "person" AS "person_1"
3837
```
3938
40-
*Show all patients together with their state of residence.*
39+
*Show all patients together with their primary care provider.*
4140
4241
```jldoctest
43-
julia> person = SQLTable(:person, columns = [:person_id, :year_of_birth, :location_id]);
42+
julia> person = SQLTable(:person, columns = [:person_id, :year_of_birth, :provider_id]);
4443
45-
julia> location = SQLTable(:location, columns = [:location_id, :state]);
44+
julia> provider = SQLTable(:provider, columns = [:provider_id, :provider_name]);
4645
4746
julia> q = From(:person) |>
48-
Join(From(:location) |> As(:location),
49-
on = Get.location_id .== Get.location.location_id) |>
50-
Select(Get.person_id, Get.location.state);
47+
Join(From(:provider) |> As(:pcp),
48+
on = Get.provider_id .== Get.pcp.provider_id) |>
49+
Select(Get.person_id, Get.pcp.provider_name);
5150
52-
julia> print(render(q, tables = [person, location]))
51+
julia> print(render(q, tables = [person, provider]))
5352
SELECT
5453
"person_1"."person_id",
5554
"location_1"."state"

src/nodes/internal.jl

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -202,30 +202,22 @@ PrettyPrinting.quoteof(n::FromFunctionNode, ctx::QuoteContext) =
202202
Expr(:kw, :columns, Expr(:vect, [QuoteNode(col) for col in n.columns]...)))
203203

204204
# Annotated Join node.
205-
struct JoinRouter
206-
label_set::Set{Symbol}
207-
group::Bool
208-
end
209-
210-
PrettyPrinting.quoteof(r::JoinRouter) =
211-
Expr(:call, nameof(JoinRouter), quoteof(r.label_set), quoteof(r.group))
212-
213205
mutable struct RoutedJoinNode <: TabularNode
214206
over::Union{SQLNode, Nothing}
215207
joinee::SQLNode
216208
on::SQLNode
217-
router::JoinRouter
209+
name::Symbol
218210
left::Bool
219211
right::Bool
220212
lateral::Bool
221213
optional::Bool
222214

223-
RoutedJoinNode(; over, joinee, on, router, left, right, lateral = false, optional = false) =
224-
new(over, joinee, on, router, left, right, lateral, optional)
215+
RoutedJoinNode(; over, joinee, on, name = label(joinee), left, right, lateral = false, optional = false) =
216+
new(over, joinee, on, name, left, right, lateral, optional)
225217
end
226218

227-
RoutedJoinNode(joinee, on; over = nothing, router, left = false, right = false, lateral = false, optional = false) =
228-
RoutedJoinNode(over = over, joinee = joinee, on = on, router, left = left, right = right, lateral = lateral, optional = optional)
219+
RoutedJoinNode(joinee, on; over = nothing, name = label(joinee), left = false, right = false, lateral = false, optional = false) =
220+
RoutedJoinNode(over = over, name = name, on = on, router, left = left, right = right, lateral = lateral, optional = optional)
229221

230222
RoutedJoin(args...; kws...) =
231223
RoutedJoinNode(args...; kws...) |> SQLNode
@@ -235,7 +227,7 @@ function PrettyPrinting.quoteof(n::RoutedJoinNode, ctx::QuoteContext)
235227
if !ctx.limit
236228
push!(ex.args, quoteof(n.joinee, ctx))
237229
push!(ex.args, quoteof(n.on, ctx))
238-
push!(ex.args, Expr(:kw, :router, quoteof(n.router)))
230+
push!(ex.args, Expr(:kw, :name, QuoteNode(n.name)))
239231
if n.left
240232
push!(ex.args, Expr(:kw, :left, n.left))
241233
end

src/nodes/into.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Wrap the output into a nested record.
2+
3+
mutable struct IntoNode <: TabularNode
4+
over::Union{SQLNode, Nothing}
5+
name::Symbol
6+
7+
IntoNode(;
8+
over = nothing,
9+
name::Union{Symbol, AbstractString}) =
10+
new(over, Symbol(name))
11+
end
12+
13+
IntoNode(name; over = nothing) =
14+
IntoNode(over = over, name = name)
15+
16+
"""
17+
Into(; over = nothing, name)
18+
Into(name; over = nothing)
19+
20+
`Into` wraps output columns in a nested record.
21+
"""
22+
Into(args...; kws...) =
23+
IntoNode(args...; kws...) |> SQLNode
24+
25+
const funsql_into = Into
26+
27+
dissect(scr::Symbol, ::typeof(Into), pats::Vector{Any}) =
28+
dissect(scr, IntoNode, pats)
29+
30+
function PrettyPrinting.quoteof(n::IntoNode, ctx::QuoteContext)
31+
ex = Expr(:call, nameof(Into), quoteof(n.name))
32+
if n.over !== nothing
33+
ex = Expr(:call, :|>, quoteof(n.over, ctx), ex)
34+
end
35+
ex
36+
end
37+
38+
label(n::IntoNode) =
39+
n.name

src/resolve.jl

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -160,9 +160,8 @@ end
160160

161161
function resolve(n::AsNode, ctx)
162162
over′ = resolve(n.over, ctx)
163-
t = row_type(over′)
164163
n′ = As(name = n.name, over = over′)
165-
Resolved(RowType(FieldTypeMap(n.name => t)), over = n′)
164+
Resolved(type(over′), over = n′)
166165
end
167166

168167
function resolve_scalar(n::AsNode, ctx)
@@ -357,6 +356,13 @@ resolve(n::HighlightNode, ctx) =
357356
resolve_scalar(n::HighlightNode, ctx) =
358357
resolve_scalar(n.over, ctx)
359358

359+
function resolve(n::IntoNode, ctx)
360+
over′ = resolve(n.over, ctx)
361+
t = row_type(over′)
362+
n′ = Into(name = n.name, over = over′)
363+
Resolved(RowType(FieldTypeMap(n.name => t)), over = n′)
364+
end
365+
360366
function resolve(n::IterateNode, ctx)
361367
over′ = resolve(n.over, ResolveContext(ctx, knot_type = nothing, implicit_knot = false))
362368
t = row_type(over′)
@@ -374,21 +380,18 @@ end
374380
function resolve(n::JoinNode, ctx)
375381
over′ = resolve(n.over, ctx)
376382
lt = row_type(over′)
383+
name = label(n.joinee)
377384
joinee′ = resolve(n.joinee, ResolveContext(ctx, row_type = lt, implicit_knot = false))
378385
rt = row_type(joinee′)
379386
fields = FieldTypeMap()
380387
for (f, ft) in lt.fields
381-
fields[f] = get(rt.fields, f, ft)
388+
fields[f] = ft
382389
end
383-
for (f, ft) in rt.fields
384-
if !haskey(fields, f)
385-
fields[f] = ft
386-
end
387-
end
388-
group = rt.group isa EmptyType ? lt.group : rt.group
390+
fields[name] = rt
391+
group = lt.group
389392
t = RowType(fields, group)
390393
on′ = resolve_scalar(n.on, ctx, t)
391-
n′ = Join(over = over′, joinee = joinee′, on = on′, left = n.left, right = n.right, optional = n.optional)
394+
n′ = RoutedJoin(over = over′, joinee = joinee′, on = on′, name = name, left = n.left, right = n.right, optional = n.optional)
392395
Resolved(t, over = n′)
393396
end
394397

@@ -491,16 +494,7 @@ function resolve(n::Union{WithNode, WithExternalNode}, ctx)
491494
v = get(ctx.cte_types, name, nothing)
492495
depth = 1 + (v !== nothing ? v[1] : 0)
493496
t = row_type(args′[i])
494-
cte_t = get(t.fields, name, EmptyType())
495-
if !(cte_t isa RowType)
496-
throw(
497-
ReferenceError(
498-
REFERENCE_ERROR_TYPE.INVALID_TABLE_REFERENCE,
499-
name = name,
500-
path = get_path(ctx)))
501-
502-
end
503-
cte_types′ = Base.ImmutableDict(cte_types′, name => (depth, cte_t))
497+
cte_types′ = Base.ImmutableDict(cte_types′, name => (depth, t))
504498
end
505499
ctx′ = ResolveContext(ctx, cte_types = cte_types′)
506500
over′ = resolve(n.over, ctx′)

0 commit comments

Comments
 (0)