Skip to content

Propagate query metadata to SQLCursor #98

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions docs/src/test/nodes.md
Original file line number Diff line number Diff line change
Expand Up @@ -4244,8 +4244,9 @@ On the next stage, the query object is converted to a SQL syntax tree.
│ left = true) |>
│ SELECT(ID(:person_2) |> ID(:person_id),
│ ID(:visit_group_1) |> ID(:max) |> AS(:max_visit_start_date)) |>
│ WITH_CONTEXT(columns = [SQLColumn(:person_id),
│ SQLColumn(:max_visit_start_date)])
│ WITH_CONTEXT(shape = SQLTable(:person,
│ SQLColumn(:person_id),
│ SQLColumn(:max_visit_start_date)))
└ @ FunSQL …
=#

Expand Down Expand Up @@ -4281,6 +4282,8 @@ Finally, the SQL tree is serialized into SQL.
│ FROM "visit_occurrence" AS "visit_occurrence_1"
│ GROUP BY "visit_occurrence_1"."person_id"
│ ) AS "visit_group_1" ON ("person_2"."person_id" = "visit_group_1"."person_id")""",
│ columns = [SQLColumn(:person_id), SQLColumn(:max_visit_start_date)])
│ shape = SQLTable(:person,
│ SQLColumn(:person_id),
│ SQLColumn(:max_visit_start_date)))
└ @ FunSQL …
=#
22 changes: 12 additions & 10 deletions docs/src/test/other.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Other Tests


## `SQLConnection` and `SQLStatement`
## `SQLConnection`, `SQLStatement`, and `SQLCursor`

A `SQLConnection` object encapsulates a raw database connection together
with the database catalog.
Expand Down Expand Up @@ -38,7 +38,7 @@ a FunSQL-specific `SQLStatement` object.
q = From(:person)

stmt = DBInterface.prepare(conn, q)
#-> SQLStatement(SQLConnection( … ), SQLite.Stmt( … ))
#-> SQLStatement(SQLConnection( … ), SQLite.Stmt( … ), shape = SQLTable( … ))

DBInterface.getconnection(stmt)
#-> SQLConnection( … )
Expand All @@ -47,7 +47,7 @@ The output of the statement is wrapped in a FunSQL-specific `SQLCursor`
object.

cr = DBInterface.execute(stmt)
#-> SQLCursor(SQLite.Query{false}( … ))
#-> SQLCursor(SQLite.Query{false}( … ), shape = SQLTable( … ))

`SQLCursor` implements standard interfaces by delegating supported methods
to the wrapped cursor object.
Expand Down Expand Up @@ -100,10 +100,10 @@ by name.
Where(Get.year_of_birth .>= Var.YEAR)

stmt = DBInterface.prepare(conn, q)
#-> SQLStatement(SQLConnection( … ), SQLite.Stmt( … ), vars = [:YEAR])
#-> SQLStatement(SQLConnection( … ), SQLite.Stmt( … ), vars = [:YEAR], shape = SQLTable( … ))

DBInterface.execute(stmt, YEAR = 1950)
#-> SQLCursor(SQLite.Query{false}( … ))
#-> SQLCursor(SQLite.Query{false}( … ), shape = SQLTable( … ))

DBInterface.close!(stmt)

Expand Down Expand Up @@ -425,14 +425,16 @@ A completely custom dialect can be specified.
String(sql)
#-> "SELECT * FROM person"

`SQLString` may carry a vector `columns` describing the output columns of
the query.
`SQLString` may specify the `shape` describing the output columns of the query.

sql = SQLString("SELECT person_id FROM person", columns = [SQLColumn(:person_id)])
#-> SQLString("SELECT person_id FROM person", columns = […1 column…])
sql = SQLString("SELECT person_id FROM person", shape = SQLTable(:person, SQLColumn(:person_id)))
#-> SQLString("SELECT person_id FROM person", shape = SQLTable(person, …1 column…))

display(sql)
#-> SQLString("SELECT person_id FROM person", columns = [SQLColumn(:person_id)])
#=>
SQLString("SELECT person_id FROM person",
shape = SQLTable(:person, SQLColumn(:person_id)))
=#

When the query has parameters, `SQLString` should include a vector of
parameter names in the order they should appear in `DBInterface.execute` call.
Expand Down
10 changes: 5 additions & 5 deletions src/clauses/internal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

struct WithContextClause <: AbstractSQLClause
dialect::SQLDialect
columns::Union{Vector{SQLColumn}, Nothing}
shape::SQLTable

WithContextClause(; dialect, columns = nothing) =
new(dialect, columns)
WithContextClause(; dialect, shape = SQLTable(name = :_, columns = [])) =
new(dialect, shape)
end

const WITH_CONTEXT = SQLSyntaxCtor{WithContextClause}(:WITH_CONTEXT)
Expand All @@ -17,8 +17,8 @@ function PrettyPrinting.quoteof(c::WithContextClause, ctx::QuoteContext)
if c.dialect !== default_dialect
push!(ex.args, Expr(:kw, :dialect, quoteof(c.dialect)))
end
if c.columns !== nothing
push!(ex.args, Expr(:kw, :columns, Expr(:vect, Any[quoteof(col) for col in c.columns]...)))
if c.shape.name !== :_ || !isempty(c.shape.columns) || !isempty(c.shape.metadata)
push!(ex.args, Expr(:kw, :shape, quoteof(c.shape)))
end
ex
end
90 changes: 78 additions & 12 deletions src/connections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,14 @@ struct SQLStatement{RawConnType, RawStmtType} <: DBInterface.Statement
conn::SQLConnection{RawConnType}
raw::RawStmtType
vars::Vector{Symbol}
shape::SQLTable

SQLStatement{RawConnType, RawStmtType}(conn::SQLConnection{RawConnType}, raw::RawStmtType; vars = Symbol[]) where {RawConnType, RawStmtType} =
new(conn, raw, vars)
SQLStatement{RawConnType, RawStmtType}(conn::SQLConnection{RawConnType}, raw::RawStmtType; vars = Symbol[], shape = SQLTable(name = :_, columns = [])) where {RawConnType, RawStmtType} =
new(conn, raw, vars, shape)
end

SQLStatement(conn::SQLConnection{RawConnType}, raw::RawStmtType; vars = Symbol[]) where {RawConnType, RawStmtType} =
SQLStatement{RawConnType, RawStmtType}(conn, raw, vars = vars)
SQLStatement(conn::SQLConnection{RawConnType}, raw::RawStmtType; vars = Symbol[], shape = SQLTable(name = :_, columns = [])) where {RawConnType, RawStmtType} =
SQLStatement{RawConnType, RawStmtType}(conn, raw, vars = vars, shape = shape)

function Base.show(io::IO, stmt::SQLStatement)
print(io, "SQLStatement(")
Expand All @@ -51,32 +52,70 @@ function Base.show(io::IO, stmt::SQLStatement)
print(io, ", vars = ")
show(io, stmt.vars)
end
if stmt.shape.name !== :_ || !isempty(stmt.shape.columns)
print(io, ", shape = SQLTable(", stmt.shape.name)
l = length(stmt.shape.columns)
print(io, l == 0 ? ")" : l == 1 ? ", …1 column…)" : ", …$l columns…)")
end
print(io, ')')
end

DataAPI.metadatasupport(::Type{:SQLStatement}) =
DataAPI.metadatasupport(SQLTable)

DataAPI.metadata(stmt::SQLStatement, key; style = false) =
DataAPI.metadata(stmt.shape, key; style)

DataAPI.metadata(stmt::SQLStatement, key, default; style = false) =
DataAPI.metadata(stmt.shape, key, default; style)

DataAPI.metadatakeys(stmt::SQLStatement) =
DataAPI.metadatakeys(stmt.shape)

DataAPI.colmetadatasupport(::Type{:SQLStatement}) =
DataAPI.colmetadatasupport(SQLTable)

DataAPI.colmetadata(stmt::SQLStatement, col, key; style = false) =
DataAPI.colmetadata(stmt.shape, col, key; style)

DataAPI.colmetadata(stmt::SQLStatement, col, key, default; style = false) =
DataAPI.colmetadata(stmt.shape, col, key, default; style)

DataAPI.colmetadatakeys(stmt::SQLStatement) =
DataAPI.colmetadatakeys(stmt.shape)

DataAPI.colmetadatakeys(stmt::SQLStatement, col) =
DataAPI.colmetadatakeys(stmt.shape, col)

"""
Shorthand for [`SQLConnection`](@ref).
"""
const DB = SQLConnection

"""
SQLCursor(raw)
SQLCursor(raw; shape)

Wraps the query result.
"""
struct SQLCursor{RawCrType} <: DBInterface.Cursor
raw::RawCrType
shape::SQLTable

SQLCursor{RawCrType}(raw::RawCrType) where {RawCrType} =
new(raw)
SQLCursor{RawCrType}(raw::RawCrType; shape = SQLTable(name = :_, columns = [])) where {RawCrType} =
new(raw, shape)
end

SQLCursor(raw::RawCrType) where {RawCrType} =
SQLCursor{RawCrType}(raw)
SQLCursor(raw::RawCrType; shape = SQLTable(name = :_, columns = [])) where {RawCrType} =
SQLCursor{RawCrType}(raw; shape)

function Base.show(io::IO, cr::SQLCursor)
print(io, "SQLCursor(")
show(io, cr.raw)
if cr.shape.name !== :_ || !isempty(cr.shape.columns)
print(io, ", shape = SQLTable(", cr.shape.name)
l = length(cr.shape.columns)
print(io, l == 0 ? ")" : l == 1 ? ", …1 column…)" : ", …$l columns…)")
end
print(io, ")")
end

Expand Down Expand Up @@ -110,6 +149,33 @@ Tables.columns(cr::SQLCursor) =
Tables.schema(cr::SQLCursor) =
Tables.schema(cr.raw)

DataAPI.metadatasupport(::Type{<:SQLCursor}) =
DataAPI.metadatasupport(SQLTable)

DataAPI.metadata(cr::SQLCursor, key; style = false) =
DataAPI.metadata(cr.shape, key; style)

DataAPI.metadata(cr::SQLCursor, key, default; style = false) =
DataAPI.metadata(cr.shape, key, default; style)

DataAPI.metadatakeys(cr::SQLCursor) =
DataAPI.metadatakeys(cr.shape)

DataAPI.colmetadatasupport(::Type{<:SQLCursor}) =
DataAPI.colmetadatasupport(SQLTable)

DataAPI.colmetadata(cr::SQLCursor, col, key; style = false) =
DataAPI.colmetadata(cr.shape, col, key; style)

DataAPI.colmetadata(cr::SQLCursor, col, key, default; style = false) =
DataAPI.colmetadata(cr.shape, col, key, default; style)

DataAPI.colmetadatakeys(cr::SQLCursor) =
DataAPI.colmetadatakeys(cr.shape)

DataAPI.colmetadatakeys(cr::SQLCursor, col) =
DataAPI.colmetadatakeys(cr.shape, col)

"""
DBInterface.connect(DB{RawConnType},
args...;
Expand Down Expand Up @@ -150,8 +216,8 @@ DBInterface.prepare(conn::SQLConnection, sql::Union{SQLQuery, SQLSyntax}) =

Generate a prepared SQL statement.
"""
DBInterface.prepare(conn::SQLConnection, str::SQLString) =
SQLStatement(conn, DBInterface.prepare(conn.raw, str.raw), vars = str.vars)
DBInterface.prepare(conn::SQLConnection, sql::SQLString) =
SQLStatement(conn, DBInterface.prepare(conn.raw, sql.raw), vars = sql.vars, shape = sql.shape)

DBInterface.prepare(conn::SQLConnection, str::AbstractString) =
DBInterface.prepare(conn.raw, str)
Expand Down Expand Up @@ -183,7 +249,7 @@ DBInterface.close!(conn::SQLConnection) =
Execute the prepared SQL statement.
"""
DBInterface.execute(stmt::SQLStatement, params) =
SQLCursor(DBInterface.execute(stmt.raw, pack(stmt.vars, params)))
SQLCursor(DBInterface.execute(stmt.raw, pack(stmt.vars, params)), shape = stmt.shape)

DBInterface.getconnection(stmt::SQLStatement) =
stmt.conn
Expand Down
4 changes: 2 additions & 2 deletions src/serialize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ mutable struct SerializeContext <: IO
end

function serialize(s::SQLSyntax)
@dissect(s, WITH_CONTEXT(tail = (local s′), dialect = (local dialect), columns = (local columns))) || throw(IllFormedError())
@dissect(s, WITH_CONTEXT(tail = (local s′), dialect = (local dialect), shape = (local shape))) || throw(IllFormedError())
ctx = SerializeContext(dialect, s′)
serialize!(ctx)
raw = String(take!(ctx.io))
SQLString(raw, columns = columns, vars = ctx.vars)
SQLString(raw, vars = ctx.vars, shape = shape)
end

Base.write(ctx::SerializeContext, octet::UInt8) =
Expand Down
Loading
Loading