Skip to content

Commit ab12386

Browse files
Introduce isolation mode option (#35)
Isolation mode option
1 parent c99cf8b commit ab12386

File tree

4 files changed

+109
-11
lines changed

4 files changed

+109
-11
lines changed

src/pytsql/tsql.py

+27-11
Original file line numberDiff line numberDiff line change
@@ -83,16 +83,19 @@ def __init__(self):
8383
# batches.
8484
self.dynamics: List[str] = []
8585

86-
def visit(self, tree: tsqlParser.Sql_clauseContext) -> str:
86+
def visit(
87+
self, tree: tsqlParser.Sql_clauseContext, prepend_dynamics: bool = True
88+
) -> str:
8789
dynamics = self.dynamics[:]
8890

8991
chunks = tree.accept(self)
9092

9193
# CREATE SCHEMA/VIEW must be the only statement in a batch
92-
if tree.ddl_clause() is not None and (
94+
is_create_schema_or_view = tree.ddl_clause() is not None and (
9395
tree.ddl_clause().create_schema() is not None
9496
or tree.ddl_clause().create_view() is not None
95-
):
97+
)
98+
if not prepend_dynamics or is_create_schema_or_view:
9699
return " ".join(chunks)
97100

98101
return " ".join(dynamics + chunks)
@@ -143,7 +146,7 @@ def syntaxError(
143146
raise ValueError(f"Error parsing SQL script: {error_message}")
144147

145148

146-
def _split(code: str) -> List[str]:
149+
def _split(code: str, isolate_top_level_statements: bool = True) -> List[str]:
147150
if not USE_CPP_IMPLEMENTATION:
148151
warnings.warn(
149152
"Can not find C++ version of the parser, Python version will be used instead."
@@ -160,15 +163,24 @@ def _split(code: str) -> List[str]:
160163
tree = parse(InputStream(data=code), "tsql_file", error_listener)
161164
visitor = _TSQLVisitor()
162165

163-
# Our current definition of a 'batch' is a single top-level SQL clause.
166+
# Our current definition of a 'batch' in isolation mode is a single top-level SQL clause.
164167
# Note that this differs from the grammar definition of a batch, which is
165-
# a group of clauses between GO statements
168+
# a group of clauses between GO statements. The latter matches the definition of batches
169+
# in non-isolation mode.
166170
batches = []
167171
for batch in tree.batch():
172+
clauses = []
173+
first_clause_in_batch = True
168174
for sql_clause in batch.sql_clauses().sql_clause():
169-
batch_query = visitor.visit(sql_clause)
170-
if batch_query != "":
171-
batches.append(batch_query)
175+
prepend_dynamics = first_clause_in_batch or isolate_top_level_statements
176+
clause = visitor.visit(sql_clause, prepend_dynamics=prepend_dynamics)
177+
if clause != "":
178+
clauses.append(clause)
179+
first_clause_in_batch = False
180+
if isolate_top_level_statements:
181+
batches.extend(clauses)
182+
else:
183+
batches.append("\n".join(clauses))
172184

173185
logging.info("SQL script parsed successfully.")
174186

@@ -186,6 +198,7 @@ def executes(
186198
code: str,
187199
engine: sqlalchemy.engine.Engine,
188200
parameters: Optional[Dict[str, Any]] = None,
201+
isolate_top_level_statements=True,
189202
) -> None:
190203
"""Execute a given sql string through a sqlalchemy.engine.Engine connection.
191204
@@ -197,6 +210,7 @@ def executes(
197210
code T-SQL string to be executed
198211
engine (sqlalchemy.engine.Engine): established mssql connection
199212
parameters An optional dictionary of parameters to substituted in the sql script
213+
isolate_top_level_statements: whether to execute statements one by one or in whole batches
200214
201215
Returns
202216
-------
@@ -209,7 +223,7 @@ def executes(
209223
# connection is closed. Caveat: sqlalchemy engines can pool connections, so we still have to drop it preemtively.
210224
conn.execute(f"DROP TABLE IF EXISTS {_PRINTS_TABLE}")
211225
conn.execute(f"CREATE TABLE {_PRINTS_TABLE} (p NVARCHAR(4000))")
212-
for batch in _split(parametrized_code):
226+
for batch in _split(parametrized_code, isolate_top_level_statements):
213227
conn.execute(batch)
214228
_fetch_and_clear_prints(conn)
215229

@@ -218,6 +232,7 @@ def execute(
218232
path: Union[str, Path],
219233
engine: sqlalchemy.engine.Engine,
220234
parameters: Optional[Dict[str, Any]] = None,
235+
isolate_top_level_statements=True,
221236
encoding: str = "utf-8",
222237
) -> None:
223238
"""Execute a given sql script through a sqlalchemy.engine.Engine connection.
@@ -227,10 +242,11 @@ def execute(
227242
path (Path or str): Path to the sql file to be executed
228243
engine (sqlalchemy.engine.Engine): established mssql connection
229244
encoding: file encoding of the sql script (default: utf-8)
245+
isolate_top_level_statements: whether to execute statements one by one or in whole batches
230246
231247
Returns
232248
-------
233249
None
234250
235251
"""
236-
executes(_code(path, encoding), engine, parameters)
252+
executes(_code(path, encoding), engine, parameters, isolate_top_level_statements)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
import logging
2+
3+
from pytsql import executes
4+
5+
6+
def test_rowcount(engine, caplog):
7+
caplog.set_level(logging.INFO)
8+
9+
seed = """
10+
USE [tempdb]
11+
GO
12+
DROP TABLE IF EXISTS [test_table]
13+
CREATE TABLE [test_table] (
14+
col VARCHAR(3)
15+
)
16+
GO
17+
INSERT INTO [test_table] (col)
18+
VALUES ('A'), ('AB'), ('ABC')
19+
PRINT('Affected ' + CAST(@@ROWCOUNT AS VARCHAR) + ' rows')
20+
"""
21+
22+
executes(seed, engine, isolate_top_level_statements=False)
23+
24+
assert "Affected 3 rows" in caplog.text
25+
26+
27+
def test_semi_persistent_set(engine, caplog):
28+
caplog.set_level(logging.INFO)
29+
30+
seed = """
31+
DECLARE @A INT = 12
32+
DECLARE @B INT = 34
33+
SET @A = 56
34+
SET @B = 78
35+
PRINT(@A)
36+
GO
37+
PRINT(@B)
38+
"""
39+
40+
executes(seed, engine, isolate_top_level_statements=False)
41+
42+
assert "56" in caplog.text
43+
assert "34" in caplog.text
44+
assert "12" not in caplog.text
45+
assert "78" not in caplog.text

tests/unit/test_dynamics.py

+25
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,28 @@ def test_dont_append_dynamics_on_create_schema():
4444
""",
4545
)
4646
assert splits[3] == "CREATE VIEW y AS SELECT 1"
47+
48+
49+
def test_dynamics_no_isolation():
50+
seed = """
51+
DECLARE @A INT = 5
52+
SELECT @A
53+
GO
54+
SELECT @A
55+
"""
56+
splits = _split(seed, isolate_top_level_statements=False)
57+
assert len(splits) == 2
58+
assert_strings_equal_disregarding_whitespace(
59+
splits[0],
60+
"""
61+
DECLARE @A INT = 5
62+
SELECT @A
63+
""",
64+
)
65+
assert_strings_equal_disregarding_whitespace(
66+
splits[1],
67+
"""
68+
DECLARE @A INT = 5
69+
SELECT @A
70+
""",
71+
)

tests/unit/test_split.py

+12
Original file line numberDiff line numberDiff line change
@@ -313,3 +313,15 @@ def test_linked_server():
313313
FROM linked_server.db.schm.table2
314314
"""
315315
assert len(_split(seed)) == 1
316+
317+
318+
def test_rowcount_no_isolation():
319+
seed = """
320+
USE new_db
321+
PRINT('Using new_db')
322+
GO
323+
SELECT * FROM my_table
324+
PRINT(@@ROWCOUNT)
325+
"""
326+
assert len(_split(seed)) == 4
327+
assert len(_split(seed, isolate_top_level_statements=False)) == 2

0 commit comments

Comments
 (0)