@@ -78,28 +78,25 @@ from heapq import heappop
78
78
79
79
def apply(loop=None):
80
80
'''Patch asyncio to make its event loop reentrant.'''
81
- loop = loop or asyncio.get_event_loop()
82
- if not isinstance(loop, asyncio.BaseEventLoop):
83
- raise ValueError('Can\'t patch loop of type %s' % type(loop))
84
- if getattr(loop, '_nest_patched', None):
85
- # already patched
86
- return
87
81
_patch_asyncio()
88
- _patch_loop(loop)
89
82
_patch_task()
90
83
_patch_tornado()
91
84
85
+ loop = loop or asyncio.get_event_loop()
86
+ _patch_loop(loop)
87
+
92
88
93
89
def _patch_asyncio():
94
90
'''
95
91
Patch asyncio module to use pure Python tasks and futures,
96
92
use module level _current_tasks, all_tasks and patch run method.
97
93
'''
98
94
def run(main, *, debug=False):
99
- loop = events._get_running_loop()
100
- if not loop:
101
- loop = events.new_event_loop()
102
- events.set_event_loop(loop)
95
+ try:
96
+ loop = asyncio.get_event_loop()
97
+ except RuntimeError:
98
+ loop = asyncio.new_event_loop()
99
+ asyncio.set_event_loop(loop)
103
100
_patch_loop(loop)
104
101
loop.set_debug(debug)
105
102
task = asyncio.ensure_future(main)
@@ -111,6 +108,14 @@ def _patch_asyncio():
111
108
with suppress(asyncio.CancelledError):
112
109
loop.run_until_complete(task)
113
110
111
+ def _get_event_loop(stacklevel=3):
112
+ loop = events._get_running_loop()
113
+ if loop is None:
114
+ loop = events.get_event_loop_policy().get_event_loop()
115
+ return loop
116
+
117
+ if hasattr(asyncio, '_nest_patched'):
118
+ return
114
119
if sys.version_info >= (3, 6, 0):
115
120
asyncio.Task = asyncio.tasks._CTask = asyncio.tasks.Task = \
116
121
asyncio.tasks._PyTask
@@ -119,9 +124,12 @@ def _patch_asyncio():
119
124
if sys.version_info < (3, 7, 0):
120
125
asyncio.tasks._current_tasks = asyncio.tasks.Task._current_tasks
121
126
asyncio.all_tasks = asyncio.tasks.Task.all_tasks
122
- if not hasattr(asyncio, '_run_orig'):
123
- asyncio._run_orig = getattr(asyncio, 'run', None)
124
- asyncio.run = run
127
+ if sys.version_info >= (3, 9, 0):
128
+ events._get_event_loop = events.get_event_loop = \
129
+ asyncio.get_event_loop = _get_event_loop
130
+ _get_event_loop
131
+ asyncio.run = run
132
+ asyncio._nest_patched = True
125
133
126
134
127
135
def _patch_loop(loop):
@@ -229,18 +237,22 @@ def _patch_loop(loop):
229
237
'''Do not throw exception if loop is already running.'''
230
238
pass
231
239
240
+ if hasattr(loop, '_nest_patched'):
241
+ return
242
+ if not isinstance(loop, asyncio.BaseEventLoop):
243
+ raise ValueError('Can\'t patch loop of type %s' % type(loop))
232
244
cls = loop.__class__
233
245
cls.run_forever = run_forever
234
246
cls.run_until_complete = run_until_complete
235
247
cls._run_once = _run_once
236
248
cls._check_running = _check_running
237
249
cls._check_runnung = _check_running # typo in Python 3.7 source
238
- cls._nest_patched = True
239
250
cls._num_runs_pending = 0
240
251
cls._is_proactorloop = (
241
252
os.name == 'nt' and issubclass(cls, asyncio.ProactorEventLoop))
242
253
if sys.version_info < (3, 7, 0):
243
254
cls._set_coroutine_origin_tracking = cls._set_coroutine_wrapper
255
+ cls._nest_patched = True
244
256
245
257
246
258
def _patch_task():
@@ -257,6 +269,8 @@ def _patch_task():
257
269
curr_tasks[task._loop] = curr_task
258
270
259
271
Task = asyncio.Task
272
+ if hasattr(Task, '_nest_patched'):
273
+ return
260
274
if sys.version_info >= (3, 7, 0):
261
275
262
276
def enter_task(loop, task):
@@ -274,6 +288,7 @@ def _patch_task():
274
288
curr_tasks = Task._current_tasks
275
289
step_orig = Task._step
276
290
Task._step = step
291
+ Task._nest_patched = True
277
292
278
293
279
294
def _patch_tornado():
@@ -457,20 +472,20 @@ def _compile_ast(node: ast.AST, filename: str = "<eval>", mode: str = "exec") ->
457
472
ASTWithBody = Union[ast.Module, ast.With, ast.AsyncWith]
458
473
459
474
460
- def _make_stmt_as_return(parent: ASTWithBody, root: ast.AST) -> types.CodeType:
475
+ def _make_stmt_as_return(parent: ASTWithBody, root: ast.AST, filename: str ) -> types.CodeType:
461
476
node = parent.body[-1]
462
477
463
478
if isinstance(node, ast.Expr):
464
479
parent.body[-1] = ast.copy_location(ast.Return(node.value), node)
465
480
466
481
try:
467
- return _compile_ast(root)
482
+ return _compile_ast(root, filename )
468
483
except (SyntaxError, TypeError): # pragma: no cover # TODO: found case to cover except body
469
484
parent.body[-1] = node
470
- return _compile_ast(root)
485
+ return _compile_ast(root, filename )
471
486
472
487
473
- def _transform_to_async(code: str) -> types.CodeType:
488
+ def _transform_to_async(code: str, filename: str ) -> types.CodeType:
474
489
base: ast.Module = ast.parse(_ASYNC_EVAL_CODE_TEMPLATE)
475
490
module: ast.Module = cast(ast.Module, _parse_code(code))
476
491
@@ -483,7 +498,7 @@ def _transform_to_async(code: str) -> types.CodeType:
483
498
while isinstance(parent.body[-1], (ast.AsyncWith, ast.With)):
484
499
parent = cast(ASTWithBody, parent.body[-1])
485
500
486
- return _make_stmt_as_return(parent, base)
501
+ return _make_stmt_as_return(parent, base, filename )
487
502
488
503
489
504
class _AsyncNodeFound(Exception):
@@ -544,7 +559,13 @@ def is_async_code(code: str) -> bool:
544
559
545
560
546
561
# async equivalent of builtin eval function
547
- def async_eval(code: str, _globals: Optional[dict] = None, _locals: Optional[dict] = None) -> Any:
562
+ def async_eval(
563
+ code: str,
564
+ _globals: Optional[dict] = None,
565
+ _locals: Optional[dict] = None,
566
+ *,
567
+ filename: str = "<eval>",
568
+ ) -> Any:
548
569
apply() # double check that loop is patched
549
570
550
571
caller: types.FrameType = inspect.currentframe().f_back # type: ignore
@@ -555,7 +576,7 @@ def async_eval(code: str, _globals: Optional[dict] = None, _locals: Optional[dic
555
576
if _globals is None:
556
577
_globals = caller.f_globals
557
578
558
- code_obj = _transform_to_async(code)
579
+ code_obj = _transform_to_async(code, filename )
559
580
560
581
try:
561
582
exec(code_obj, _globals, _locals)
0 commit comments