Skip to content

Commit 368c0ba

Browse files
authored
Merge pull request #41 Fix orm for tables in directories from kabulov/fix_orm_for_tables_in_directories_2
2 parents 6dc5578 + db5ce93 commit 368c0ba

File tree

2 files changed

+72
-3
lines changed

2 files changed

+72
-3
lines changed

test/test_orm.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import pytest
2+
import sqlalchemy as sa
3+
from types import MethodType
4+
from sqlalchemy import Column, Integer, Unicode
5+
from sqlalchemy.orm import declarative_base, sessionmaker
6+
from sqlalchemy.testing.fixtures import TablesTest, config
7+
8+
9+
class TestDirectories(TablesTest):
10+
__backend__ = True
11+
12+
def prepare_table(self, engine):
13+
base = declarative_base()
14+
15+
class Table(base):
16+
__tablename__ = "dir/test"
17+
id = Column(Integer, primary_key=True)
18+
text = Column(Unicode)
19+
20+
base.metadata.create_all(engine)
21+
session = sessionmaker(bind=engine)()
22+
session.add(Table(id=2, text="foo"))
23+
session.commit()
24+
return base, Table, session
25+
26+
def try_update(self, session, Table):
27+
row = session.query(Table).first()
28+
row.text = "bar"
29+
session.commit()
30+
return row
31+
32+
def drop_table(self, base, engine):
33+
base.metadata.drop_all(engine)
34+
35+
def bind_old_method_to_dialect(self, dialect):
36+
def _handle_column_name(self, variable):
37+
return variable
38+
39+
dialect._handle_column_name = MethodType(_handle_column_name, dialect)
40+
41+
def test_directories(self):
42+
engine_good = sa.create_engine(config.db_url)
43+
base, Table, session = self.prepare_table(engine_good)
44+
row = self.try_update(session, Table)
45+
assert row.id == 2
46+
assert row.text == "bar"
47+
self.drop_table(base, engine_good)
48+
49+
engine_bad = sa.create_engine(config.db_url)
50+
self.bind_old_method_to_dialect(engine_bad.dialect)
51+
base, Table, session = self.prepare_table(engine_bad)
52+
with pytest.raises(Exception) as excinfo:
53+
self.try_update(session, Table)
54+
assert "Unknown name: $dir" in str(excinfo.value)
55+
self.drop_table(base, engine_bad)

ydb_sqlalchemy/sqlalchemy/__init__.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,13 @@ class YqlDialect(StrCompileDialect):
584584
def import_dbapi(cls: Any):
585585
return dbapi.YdbDBApi()
586586

587-
def __init__(self, json_serializer=None, json_deserializer=None, _add_declare_for_yql_stmt_vars=False, **kwargs):
587+
def __init__(
588+
self,
589+
json_serializer=None,
590+
json_deserializer=None,
591+
_add_declare_for_yql_stmt_vars=False,
592+
**kwargs,
593+
):
588594
super().__init__(**kwargs)
589595

590596
self._json_deserializer = json_deserializer
@@ -673,6 +679,9 @@ def do_rollback(self, dbapi_connection: dbapi.Connection) -> None:
673679
def do_commit(self, dbapi_connection: dbapi.Connection) -> None:
674680
dbapi_connection.commit()
675681

682+
def _handle_column_name(self, variable):
683+
return "`" + variable + "`"
684+
676685
def _format_variables(
677686
self,
678687
statement: str,
@@ -694,15 +703,20 @@ def _format_variables(
694703
variable_names = set(parameters.keys())
695704
formatted_parameters = {f"${k}": v for k, v in parameters.items()}
696705

697-
formatted_variable_names = {variable_name: f"${variable_name}" for variable_name in variable_names}
706+
formatted_variable_names = {
707+
variable_name: f"${self._handle_column_name(variable_name)}" for variable_name in variable_names
708+
}
698709
formatted_statement = formatted_statement % formatted_variable_names
699710

700711
formatted_statement = formatted_statement.replace("%%", "%")
701712
return formatted_statement, formatted_parameters
702713

703714
def _add_declare_for_yql_stmt_vars_impl(self, statement, parameters_types):
704715
declarations = "\n".join(
705-
[f"DECLARE {param_name} as {str(param_type)};" for param_name, param_type in parameters_types.items()]
716+
[
717+
f"DECLARE $`{param_name[1:]}` as {str(param_type)};"
718+
for param_name, param_type in parameters_types.items()
719+
]
706720
)
707721
return f"{declarations}\n{statement}"
708722

0 commit comments

Comments
 (0)