@@ -83,16 +83,19 @@ def __init__(self):
83
83
# batches.
84
84
self .dynamics : List [str ] = []
85
85
86
- def visit (self , tree : tsqlParser .Sql_clauseContext ) -> str :
86
+ def visit (
87
+ self , tree : tsqlParser .Sql_clauseContext , prepend_dynamics : bool = True
88
+ ) -> str :
87
89
dynamics = self .dynamics [:]
88
90
89
91
chunks = tree .accept (self )
90
92
91
93
# CREATE SCHEMA/VIEW must be the only statement in a batch
92
- if tree .ddl_clause () is not None and (
94
+ is_create_schema_or_view = tree .ddl_clause () is not None and (
93
95
tree .ddl_clause ().create_schema () is not None
94
96
or tree .ddl_clause ().create_view () is not None
95
- ):
97
+ )
98
+ if not prepend_dynamics or is_create_schema_or_view :
96
99
return " " .join (chunks )
97
100
98
101
return " " .join (dynamics + chunks )
@@ -143,7 +146,7 @@ def syntaxError(
143
146
raise ValueError (f"Error parsing SQL script: { error_message } " )
144
147
145
148
146
- def _split (code : str ) -> List [str ]:
149
+ def _split (code : str , isolate_top_level_statements : bool = True ) -> List [str ]:
147
150
if not USE_CPP_IMPLEMENTATION :
148
151
warnings .warn (
149
152
"Can not find C++ version of the parser, Python version will be used instead."
@@ -160,15 +163,24 @@ def _split(code: str) -> List[str]:
160
163
tree = parse (InputStream (data = code ), "tsql_file" , error_listener )
161
164
visitor = _TSQLVisitor ()
162
165
163
- # Our current definition of a 'batch' is a single top-level SQL clause.
166
+ # Our current definition of a 'batch' in isolation mode is a single top-level SQL clause.
164
167
# Note that this differs from the grammar definition of a batch, which is
165
- # a group of clauses between GO statements
168
+ # a group of clauses between GO statements. The latter matches the definition of batches
169
+ # in non-isolation mode.
166
170
batches = []
167
171
for batch in tree .batch ():
172
+ clauses = []
173
+ first_clause_in_batch = True
168
174
for sql_clause in batch .sql_clauses ().sql_clause ():
169
- batch_query = visitor .visit (sql_clause )
170
- if batch_query != "" :
171
- batches .append (batch_query )
175
+ prepend_dynamics = first_clause_in_batch or isolate_top_level_statements
176
+ clause = visitor .visit (sql_clause , prepend_dynamics = prepend_dynamics )
177
+ if clause != "" :
178
+ clauses .append (clause )
179
+ first_clause_in_batch = False
180
+ if isolate_top_level_statements :
181
+ batches .extend (clauses )
182
+ else :
183
+ batches .append ("\n " .join (clauses ))
172
184
173
185
logging .info ("SQL script parsed successfully." )
174
186
@@ -186,6 +198,7 @@ def executes(
186
198
code : str ,
187
199
engine : sqlalchemy .engine .Engine ,
188
200
parameters : Optional [Dict [str , Any ]] = None ,
201
+ isolate_top_level_statements = True ,
189
202
) -> None :
190
203
"""Execute a given sql string through a sqlalchemy.engine.Engine connection.
191
204
@@ -197,6 +210,7 @@ def executes(
197
210
code T-SQL string to be executed
198
211
engine (sqlalchemy.engine.Engine): established mssql connection
199
212
parameters An optional dictionary of parameters to substituted in the sql script
213
+ isolate_top_level_statements: whether to execute statements one by one or in whole batches
200
214
201
215
Returns
202
216
-------
@@ -209,7 +223,7 @@ def executes(
209
223
# connection is closed. Caveat: sqlalchemy engines can pool connections, so we still have to drop it preemtively.
210
224
conn .execute (f"DROP TABLE IF EXISTS { _PRINTS_TABLE } " )
211
225
conn .execute (f"CREATE TABLE { _PRINTS_TABLE } (p NVARCHAR(4000))" )
212
- for batch in _split (parametrized_code ):
226
+ for batch in _split (parametrized_code , isolate_top_level_statements ):
213
227
conn .execute (batch )
214
228
_fetch_and_clear_prints (conn )
215
229
@@ -218,6 +232,7 @@ def execute(
218
232
path : Union [str , Path ],
219
233
engine : sqlalchemy .engine .Engine ,
220
234
parameters : Optional [Dict [str , Any ]] = None ,
235
+ isolate_top_level_statements = True ,
221
236
encoding : str = "utf-8" ,
222
237
) -> None :
223
238
"""Execute a given sql script through a sqlalchemy.engine.Engine connection.
@@ -227,10 +242,11 @@ def execute(
227
242
path (Path or str): Path to the sql file to be executed
228
243
engine (sqlalchemy.engine.Engine): established mssql connection
229
244
encoding: file encoding of the sql script (default: utf-8)
245
+ isolate_top_level_statements: whether to execute statements one by one or in whole batches
230
246
231
247
Returns
232
248
-------
233
249
None
234
250
235
251
"""
236
- executes (_code (path , encoding ), engine , parameters )
252
+ executes (_code (path , encoding ), engine , parameters , isolate_top_level_statements )
0 commit comments