Skip to content

Commit d6264e6

Browse files
committed
SA14: Adjust CrateDB dialect compiler patch to SqlAlchemy 1.4.36
The original code for `visit_update` and `_get_crud_params` from SQLAlchemy 1.4.0b1 has been vendored into the CrateDB dialect the other day, in order to amend it due to dialect-specific purposes. This patch reflects the changes from SA 1.4.0b1 to SA 1.4.36 on this code.
1 parent 7238b3b commit d6264e6

File tree

3 files changed

+26
-10
lines changed

3 files changed

+26
-10
lines changed

src/crate/client/sqlalchemy/compiler.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@
2323
from collections import defaultdict
2424

2525
import sqlalchemy as sa
26-
from sqlalchemy.sql import crud, selectable
27-
from sqlalchemy.sql import compiler
2826
from .types import MutableDict
2927
from .sa_version import SA_VERSION, SA_1_1, SA_1_4
3028

29+
crud = sa.sql.crud
30+
selectable = sa.sql.selectable
31+
compiler = sa.sql.compiler
32+
3133

3234
INSERT_SELECT_WITHOUT_PARENTHESES_MIN_VERSION = (1, 0, 1)
3335

@@ -521,6 +523,10 @@ def visit_update_14(self, update_stmt, **kw):
521523
else:
522524
dialect_hints = None
523525

526+
if update_stmt._independent_ctes:
527+
for cte in update_stmt._independent_ctes:
528+
cte._compiler_dispatch(self, **kw)
529+
524530
text += table_text
525531

526532
text += " SET "
@@ -580,8 +586,9 @@ def visit_update_14(self, update_stmt, **kw):
580586
update_stmt, self.returning or update_stmt._returning
581587
)
582588

583-
if self.ctes and toplevel:
584-
text = self._render_cte_clause() + text
589+
if self.ctes:
590+
nesting_level = len(self.stack) if not toplevel else None
591+
text = self._render_cte_clause(nesting_level=nesting_level) + text
585592

586593
self.stack.pop(-1)
587594

@@ -602,7 +609,7 @@ def _get_crud_params_14(compiler, stmt, compile_state, **kw):
602609
from sqlalchemy.sql.crud import _create_bind_param
603610
from sqlalchemy.sql.crud import REQUIRED
604611
from sqlalchemy.sql.crud import _get_stmt_parameter_tuples_params
605-
from sqlalchemy.sql.crud import _get_multitable_params
612+
from sqlalchemy.sql.crud import _get_update_multitable_params
606613
from sqlalchemy.sql.crud import _scan_insert_from_select_cols
607614
from sqlalchemy.sql.crud import _scan_cols
608615
from sqlalchemy import exc # noqa: F401
@@ -682,7 +689,7 @@ def _get_crud_params_14(compiler, stmt, compile_state, **kw):
682689
# special logic that only occurs for multi-table UPDATE
683690
# statements
684691
if compile_state.isupdate and compile_state.is_multitable:
685-
_get_multitable_params(
692+
_get_update_multitable_params(
686693
compiler,
687694
stmt,
688695
compile_state,
@@ -738,9 +745,18 @@ def _get_crud_params_14(compiler, stmt, compile_state, **kw):
738745

739746
if compile_state._has_multi_parameters:
740747
values = _extend_values_for_multiparams(
741-
compiler, stmt, compile_state, values, kw
748+
compiler,
749+
stmt,
750+
compile_state,
751+
values,
752+
_column_as_key,
753+
kw,
742754
)
743-
elif not values and compiler.for_executemany:
755+
elif (
756+
not values
757+
and compiler.for_executemany # noqa: W503
758+
and compiler.dialect.supports_default_metavalue # noqa: W503
759+
):
744760
# convert an "INSERT DEFAULT VALUES"
745761
# into INSERT (firstcol) VALUES (DEFAULT) which can be turned
746762
# into an in-place multi values. This supports

src/crate/client/sqlalchemy/dialect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
# FIXME: Workaround to be able to use SQLAlchemy 1.4.
2727
# Caveat: This purges the ``cresultproxy`` extension
2828
# at runtime, so it will impose a speed bump.
29-
import crate.client.sqlalchemy.monkey # noqa:F401
29+
import crate.client.sqlalchemy.monkey # noqa:F401, lgtm [py/unused-import]
3030

3131
from sqlalchemy import types as sqltypes
3232
from sqlalchemy.engine import default, reflection

src/crate/client/sqlalchemy/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ class Any(expression.ColumnElement):
168168

169169
def __init__(self, left, right, operator=operators.eq):
170170
self.type = sqltypes.Boolean()
171-
self.left = expression._literal_as_binds(left)
171+
self.left = expression.literal(left)
172172
self.right = right
173173
self.operator = operator
174174

0 commit comments

Comments
 (0)