Skip to content

Commit 35bfd71

Browse files
committed
Support Named Schemas
1 parent e17c5ef commit 35bfd71

File tree

2 files changed

+98
-1
lines changed

2 files changed

+98
-1
lines changed

google/cloud/sqlalchemy_spanner/sqlalchemy_spanner.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,10 @@ class SpannerSQLCompiler(SQLCompiler):
233233

234234
compound_keywords = _compound_keywords
235235

236+
def __init__(self, *args, **kwargs):
237+
self.tablealiases = {}
238+
super().__init__(*args, **kwargs)
239+
236240
def get_from_hint_text(self, _, text):
237241
"""Return a hint text.
238242
@@ -378,8 +382,10 @@ def limit_clause(self, select, **kw):
378382
return text
379383

380384
def returning_clause(self, stmt, returning_cols, **kw):
385+
# Set include_table=False because although table names are allowed in
386+
# RETURNING clauses, schema names are not.
381387
columns = [
382-
self._label_select_column(None, c, True, False, {})
388+
self._label_select_column(None, c, True, False, {}, include_table=False)
383389
for c in expression._select_iterables(returning_cols)
384390
]
385391

@@ -391,6 +397,66 @@ def visit_sequence(self, seq, **kw):
391397
seq
392398
)
393399

400+
def visit_table(self, table, spanner_aliased=False, iscrud=False, **kwargs):
401+
"""Build the table name.
402+
403+
Schema names are not allowed in Spanner SELECT statements. When selecting
404+
from a schema-qualified table, alias the table to produce SQL like:
405+
406+
SELECT tbl_1.id, tbl_1.col
407+
FROM schema.tbl AS tbl_1
408+
"""
409+
# This closely code mirrors the mssql dialect which also
410+
# avoids schema-qualified columns in SELECTs, although the
411+
# behaviour is currently behind a deprecated
412+
# 'legacy_schema_aliasing' flag.
413+
if spanner_aliased is table or iscrud:
414+
return super().visit_table(table, **kwargs)
415+
416+
# alias schema-qualified tables
417+
alias = self._schema_aliased_table(table)
418+
if alias is not None:
419+
return self.process(alias, spanner_aliased=table, **kwargs)
420+
else:
421+
return super().visit_table(table, **kwargs)
422+
423+
def visit_alias(self, alias, **kw):
424+
# translate for schema-qualified table aliases
425+
kw["spanner_aliased"] = alias.element
426+
return super().visit_alias(alias, **kw)
427+
428+
def visit_column(self, column, add_to_result_map=None, **kw):
429+
if (
430+
column.table is not None
431+
and (not self.isupdate and not self.isdelete and not self.isinsert)
432+
or self.is_subquery()
433+
):
434+
# translate for schema-qualified table aliases
435+
t = self._schema_aliased_table(column.table)
436+
if t is not None:
437+
converted = elements._corresponding_column_or_error(t, column)
438+
if add_to_result_map is not None:
439+
add_to_result_map(
440+
column.name,
441+
column.name,
442+
(column, column.name, column.key),
443+
column.type,
444+
)
445+
446+
return super().visit_column(converted, **kw)
447+
448+
return super().visit_column(
449+
column, add_to_result_map=add_to_result_map, **kw
450+
)
451+
452+
def _schema_aliased_table(self, table):
453+
if getattr(table, "schema", None) is not None:
454+
if table not in self.tablealiases:
455+
self.tablealiases[table] = table.alias()
456+
return self.tablealiases[table]
457+
else:
458+
return None
459+
394460

395461
class SpannerDDLCompiler(DDLCompiler):
396462
"""Spanner DDL statements compiler."""

test/system/test_basics.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@
3131
from sqlalchemy.testing import eq_, is_true
3232
from sqlalchemy.testing.plugin.plugin_base import fixtures
3333

34+
import logging
35+
36+
logging.basicConfig()
37+
logging.getLogger("sqlalchemy.engine").setLevel(logging.INFO)
38+
logging.getLogger("sqlalchemy.pool").setLevel(logging.DEBUG)
3439

3540
class TestBasics(fixtures.TablesTest):
3641
@classmethod
@@ -58,6 +63,16 @@ def define_tables(cls, metadata):
5863
Column("name", String(20)),
5964
)
6065

66+
with cls.bind.begin() as conn:
67+
conn.execute(text('CREATE SCHEMA IF NOT EXISTS schema'))
68+
Table(
69+
"users",
70+
metadata,
71+
Column("ID", Integer, primary_key=True),
72+
Column("name", String(20)),
73+
schema="schema"
74+
)
75+
6176
def test_hello_world(self, connection):
6277
greeting = connection.execute(text("select 'Hello World'"))
6378
eq_("Hello World", greeting.fetchone()[0])
@@ -139,6 +154,12 @@ class User(Base):
139154
ID: Mapped[int] = mapped_column(primary_key=True)
140155
name: Mapped[str] = mapped_column(String(20))
141156

157+
class SchemaUser(Base):
158+
__tablename__ = "users"
159+
__table_args__ = {'schema': 'schema'}
160+
ID: Mapped[int] = mapped_column(primary_key=True)
161+
name: Mapped[str] = mapped_column(String(20))
162+
142163
engine = connection.engine
143164
with Session(engine) as session:
144165
number = Number(
@@ -156,3 +177,13 @@ class User(Base):
156177
users = session.scalars(statement).all()
157178
eq_(1, len(users))
158179
is_true(users[0].ID > 0)
180+
181+
with Session(engine) as session:
182+
user = SchemaUser(name="SchemaTest")
183+
session.add(user)
184+
session.commit()
185+
186+
statement = select(SchemaUser).filter_by(name="SchemaTest")
187+
users = session.scalars(statement).all()
188+
eq_(1, len(users))
189+
is_true(users[0].ID > 0)

0 commit comments

Comments
 (0)