Skip to content

Separate as() and into() #78

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/FunSQL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@ export
funsql_from,
funsql_fun,
funsql_group,
funsql_hide,
funsql_highlight,
funsql_in,
funsql_into,
funsql_iterate,
funsql_is_not_null,
funsql_is_null,
Expand All @@ -80,6 +82,7 @@ export
funsql_rank,
funsql_row_number,
funsql_select,
funsql_show,
funsql_sort,
funsql_sum,
funsql_with
Expand Down
168 changes: 88 additions & 80 deletions src/link.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,28 @@ struct LinkContext
knot_refs)
end

function link(n::SQLNode)
@dissect(n, WithContext(over = over, catalog = catalog)) || throw(ILLFormedError())
ctx = LinkContext(catalog)
t = row_type(over)
function _select(t::RowType)
refs = SQLNode[]
t.visible || return refs
for (f, ft) in t.fields
if ft isa ScalarType
ft.visible || continue
push!(refs, Get(f))
else
nested_refs = _select(ft)
for nested_ref in nested_refs
push!(refs, Nested(over = nested_ref, name = f))
end
end
end
refs
end

function link(n::SQLNode)
@dissect(n, WithContext(over = over, catalog = catalog)) || throw(ILLFormedError())
ctx = LinkContext(catalog)
t = row_type(over)
refs = _select(t)
over′ = Linked(refs, over = link(dismantle(over, ctx), ctx, refs))
WithContext(over = over′, catalog = catalog, defs = ctx.defs)
end
Expand Down Expand Up @@ -114,19 +126,15 @@ function dismantle(n::GroupNode, ctx)
Group(over = over′, by = by′, sets = n.sets, name = n.name, label_map = n.label_map)
end

function dismantle(n::IterateNode, ctx)
function dismantle(n::IntoNode, ctx)
over′ = dismantle(n.over, ctx)
iterator′ = dismantle(n.iterator, ctx)
Iterate(over = over′, iterator = iterator′)
Into(over = over′, name = n.name)
end

function dismantle(n::JoinNode, ctx)
rt = row_type(n.joinee)
router = JoinRouter(Set(keys(rt.fields)), !isa(rt.group, EmptyType))
function dismantle(n::IterateNode, ctx)
over′ = dismantle(n.over, ctx)
joinee′ = dismantle(n.joinee, ctx)
on′ = dismantle_scalar(n.on, ctx)
RoutedJoin(over = over′, joinee = joinee′, on = on′, router = router, left = n.left, right = n.right, optional = n.optional)
iterator′ = dismantle(n.iterator, ctx)
Iterate(over = over′, iterator = iterator′)
end

function dismantle(n::LimitNode, ctx)
Expand Down Expand Up @@ -172,6 +180,13 @@ function dismantle_scalar(n::ResolvedNode, ctx)
end
end

function dismantle(n::RoutedJoinNode, ctx)
over′ = dismantle(n.over, ctx)
joinee′ = dismantle(n.joinee, ctx)
on′ = dismantle_scalar(n.on, ctx)
RoutedJoin(over = over′, joinee = joinee′, on = on′, name = n.name, left = n.left, right = n.right, optional = n.optional)
end

function dismantle(n::SelectNode, ctx)
over′ = dismantle(n.over, ctx)
args′ = dismantle_scalar(n.args, ctx)
Expand Down Expand Up @@ -219,16 +234,7 @@ function link(n::AppendNode, ctx)
end

function link(n::AsNode, ctx)
refs = SQLNode[]
for ref in ctx.refs
if @dissect(ref, over |> Nested(name = name))
@assert name == n.name
push!(refs, over)
else
error()
end
end
over′ = link(n.over, ctx, refs)
over′ = link(n.over, ctx)
As(over = over′, name = n.name)
end

Expand Down Expand Up @@ -276,10 +282,8 @@ function link(n::FromIterateNode, ctx)
end

function link(n::FromTableExpressionNode, ctx)
refs = ctx.cte_refs[(n.name, n.depth)]
for ref in ctx.refs
push!(refs, Nested(over = ref, name = n.name))
end
cte_refs = ctx.cte_refs[(n.name, n.depth)]
append!(cte_refs, ctx.refs)
n
end

Expand Down Expand Up @@ -320,6 +324,20 @@ function link(n::GroupNode, ctx)
Group(over = over′, by = n.by, sets = n.sets, name = n.name, label_map = n.label_map)
end

function link(n::IntoNode, ctx)
refs = SQLNode[]
for ref in ctx.refs
if @dissect(ref, over |> Nested(name = name))
@assert name == n.name
push!(refs, over)
else
error()
end
end
over′ = link(n.over, ctx, refs)
Into(over = over′, name = n.name)
end

function link(n::IterateNode, ctx)
iterator′ = n.iterator
defs = copy(ctx.defs)
Expand Down Expand Up @@ -351,53 +369,6 @@ function link(n::IterateNode, ctx)
Padding(over = n′)
end

function route(r::JoinRouter, ref::SQLNode)
if @dissect(ref, over |> Nested(name = name)) && name in r.label_set
return 1
end
if @dissect(ref, Get(name = name)) && name in r.label_set
return 1
end
if @dissect(ref, over |> Agg()) && r.group
return 1
end
return -1
end

function link(n::RoutedJoinNode, ctx)
lrefs = SQLNode[]
rrefs = SQLNode[]
for ref in ctx.refs
turn = route(n.router, ref)
push!(turn < 0 ? lrefs : rrefs, ref)
end
if n.optional && isempty(rrefs)
return link(n.over, ctx)
end
ln_ext_refs = length(lrefs)
rn_ext_refs = length(rrefs)
refs′ = SQLNode[]
lateral_refs = SQLNode[]
gather!(n.joinee, ctx, lateral_refs)
append!(lrefs, lateral_refs)
lateral = !isempty(lateral_refs)
gather!(n.on, ctx, refs′)
for ref in refs′
turn = route(n.router, ref)
push!(turn < 0 ? lrefs : rrefs, ref)
end
over′ = Linked(lrefs, ln_ext_refs, over = link(n.over, ctx, lrefs))
joinee′ = Linked(rrefs, rn_ext_refs, over = link(n.joinee, ctx, rrefs))
RoutedJoinNode(
over = over′,
joinee = joinee′,
on = n.on,
router = n.router,
left = n.left,
right = n.right,
lateral = lateral)
end

function link(n::LimitNode, ctx)
over′ = Linked(ctx.refs, over = link(n.over, ctx))
Limit(over = over′, offset = n.offset, limit = n.limit)
Expand Down Expand Up @@ -446,6 +417,46 @@ function link(n::PartitionNode, ctx)
Partition(over = over′, by = n.by, order_by = n.order_by, frame = n.frame, name = n.name)
end

function link(n::RoutedJoinNode, ctx)
lrefs = SQLNode[]
rrefs = SQLNode[]
for ref in ctx.refs
if @dissect(ref, over |> Nested(name = name)) && name === n.name
push!(rrefs, ref)
else
push!(lrefs, ref)
end
end
if n.optional && isempty(rrefs)
return link(n.over, ctx)
end
ln_ext_refs = length(lrefs)
rn_ext_refs = length(rrefs)
refs′ = SQLNode[]
lateral_refs = SQLNode[]
gather!(n.joinee, ctx, lateral_refs)
append!(lrefs, lateral_refs)
lateral = !isempty(lateral_refs)
gather!(n.on, ctx, refs′)
for ref in refs′
if @dissect(ref, over |> Nested(name = name)) && name === n.name
push!(rrefs, ref)
else
push!(lrefs, ref)
end
end
over′ = Linked(lrefs, ln_ext_refs, over = link(n.over, ctx, lrefs))
joinee′ = Linked(rrefs, rn_ext_refs, over = link(Into(over = n.joinee, name = n.name), ctx, rrefs))
RoutedJoinNode(
over = over′,
joinee = joinee′,
on = n.on,
name = n.name,
left = n.left,
right = n.right,
lateral = lateral)
end

function link(n::SelectNode, ctx)
refs = SQLNode[]
gather!(n.args, ctx, refs)
Expand Down Expand Up @@ -540,12 +551,9 @@ end
function gather!(n::IsolatedNode, ctx)
def = ctx.defs[n.idx]
!@dissect(def, Linked()) || return
refs = SQLNode[]
for (f, ft) in n.type.fields
if ft isa ScalarType
push!(refs, Get(f))
break
end
refs = _select(n.type)
if !isempty(refs)
refs = refs[1:1]
end
def′ = Linked(refs, over = link(def, ctx, refs))
ctx.defs[n.idx] = def′
Expand Down
2 changes: 2 additions & 0 deletions src/nodes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,7 @@ include("nodes/get.jl")
include("nodes/group.jl")
include("nodes/highlight.jl")
include("nodes/internal.jl")
include("nodes/into.jl")
include("nodes/iterate.jl")
include("nodes/join.jl")
include("nodes/limit.jl")
Expand All @@ -704,6 +705,7 @@ include("nodes/order.jl")
include("nodes/over.jl")
include("nodes/partition.jl")
include("nodes/select.jl")
include("nodes/show.jl")
include("nodes/sort.jl")
include("nodes/variable.jl")
include("nodes/where.jl")
Expand Down
17 changes: 8 additions & 9 deletions src/nodes/as.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ AsNode(name; over = nothing) =
As(name; over = nothing)
name => over

In a scalar context, `As` specifies the name of the output column. When
applied to tabular data, `As` wraps the data in a nested record.
`As` specifies the name of the output column.

The arrow operator (`=>`) is a shorthand notation for `As`.

Expand All @@ -37,19 +36,19 @@ SELECT "person_1"."person_id" AS "id"
FROM "person" AS "person_1"
```

*Show all patients together with their state of residence.*
*Show all patients together with their primary care provider.*

```jldoctest
julia> person = SQLTable(:person, columns = [:person_id, :year_of_birth, :location_id]);
julia> person = SQLTable(:person, columns = [:person_id, :year_of_birth, :provider_id]);

julia> location = SQLTable(:location, columns = [:location_id, :state]);
julia> provider = SQLTable(:provider, columns = [:provider_id, :provider_name]);

julia> q = From(:person) |>
Join(From(:location) |> As(:location),
on = Get.location_id .== Get.location.location_id) |>
Select(Get.person_id, Get.location.state);
Join(From(:provider) |> As(:pcp),
on = Get.provider_id .== Get.pcp.provider_id) |>
Select(Get.person_id, Get.pcp.provider_name);

julia> print(render(q, tables = [person, location]))
julia> print(render(q, tables = [person, provider]))
SELECT
"person_1"."person_id",
"location_1"."state"
Expand Down
20 changes: 6 additions & 14 deletions src/nodes/internal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,30 +202,22 @@ PrettyPrinting.quoteof(n::FromFunctionNode, ctx::QuoteContext) =
Expr(:kw, :columns, Expr(:vect, [QuoteNode(col) for col in n.columns]...)))

# Annotated Join node.
struct JoinRouter
label_set::Set{Symbol}
group::Bool
end

PrettyPrinting.quoteof(r::JoinRouter) =
Expr(:call, nameof(JoinRouter), quoteof(r.label_set), quoteof(r.group))

mutable struct RoutedJoinNode <: TabularNode
over::Union{SQLNode, Nothing}
joinee::SQLNode
on::SQLNode
router::JoinRouter
name::Symbol
left::Bool
right::Bool
lateral::Bool
optional::Bool

RoutedJoinNode(; over, joinee, on, router, left, right, lateral = false, optional = false) =
new(over, joinee, on, router, left, right, lateral, optional)
RoutedJoinNode(; over, joinee, on, name = label(joinee), left, right, lateral = false, optional = false) =
new(over, joinee, on, name, left, right, lateral, optional)
end

RoutedJoinNode(joinee, on; over = nothing, router, left = false, right = false, lateral = false, optional = false) =
RoutedJoinNode(over = over, joinee = joinee, on = on, router, left = left, right = right, lateral = lateral, optional = optional)
RoutedJoinNode(joinee, on; over = nothing, name = label(joinee), left = false, right = false, lateral = false, optional = false) =
RoutedJoinNode(over = over, name = name, on = on, router, left = left, right = right, lateral = lateral, optional = optional)

RoutedJoin(args...; kws...) =
RoutedJoinNode(args...; kws...) |> SQLNode
Expand All @@ -235,7 +227,7 @@ function PrettyPrinting.quoteof(n::RoutedJoinNode, ctx::QuoteContext)
if !ctx.limit
push!(ex.args, quoteof(n.joinee, ctx))
push!(ex.args, quoteof(n.on, ctx))
push!(ex.args, Expr(:kw, :router, quoteof(n.router)))
push!(ex.args, Expr(:kw, :name, QuoteNode(n.name)))
if n.left
push!(ex.args, Expr(:kw, :left, n.left))
end
Expand Down
39 changes: 39 additions & 0 deletions src/nodes/into.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Wrap the output into a nested record.

mutable struct IntoNode <: TabularNode
over::Union{SQLNode, Nothing}
name::Symbol

IntoNode(;
over = nothing,
name::Union{Symbol, AbstractString}) =
new(over, Symbol(name))
end

IntoNode(name; over = nothing) =
IntoNode(over = over, name = name)

"""
Into(; over = nothing, name)
Into(name; over = nothing)

`Into` wraps output columns in a nested record.
"""
Into(args...; kws...) =
IntoNode(args...; kws...) |> SQLNode

const funsql_into = Into

dissect(scr::Symbol, ::typeof(Into), pats::Vector{Any}) =
dissect(scr, IntoNode, pats)

function PrettyPrinting.quoteof(n::IntoNode, ctx::QuoteContext)
ex = Expr(:call, nameof(Into), quoteof(n.name))
if n.over !== nothing
ex = Expr(:call, :|>, quoteof(n.over, ctx), ex)
end
ex
end

label(n::IntoNode) =
n.name
Loading
Loading