Skip to content

Commit 8c3756e

Browse files
committed
🔧 Update PYDEVD_ASYNC_PLUGIN
1 parent a4d9db7 commit 8c3756e

File tree

1 file changed

+43
-22
lines changed

1 file changed

+43
-22
lines changed

src/main/java/com/uriyyo/evaluate_async_code/AsyncPyDebugUtils.kt

+43-22
Original file line numberDiff line numberDiff line change
@@ -78,28 +78,25 @@ from heapq import heappop
7878
7979
def apply(loop=None):
8080
'''Patch asyncio to make its event loop reentrant.'''
81-
loop = loop or asyncio.get_event_loop()
82-
if not isinstance(loop, asyncio.BaseEventLoop):
83-
raise ValueError('Can\'t patch loop of type %s' % type(loop))
84-
if getattr(loop, '_nest_patched', None):
85-
# already patched
86-
return
8781
_patch_asyncio()
88-
_patch_loop(loop)
8982
_patch_task()
9083
_patch_tornado()
9184
85+
loop = loop or asyncio.get_event_loop()
86+
_patch_loop(loop)
87+
9288
9389
def _patch_asyncio():
9490
'''
9591
Patch asyncio module to use pure Python tasks and futures,
9692
use module level _current_tasks, all_tasks and patch run method.
9793
'''
9894
def run(main, *, debug=False):
99-
loop = events._get_running_loop()
100-
if not loop:
101-
loop = events.new_event_loop()
102-
events.set_event_loop(loop)
95+
try:
96+
loop = asyncio.get_event_loop()
97+
except RuntimeError:
98+
loop = asyncio.new_event_loop()
99+
asyncio.set_event_loop(loop)
103100
_patch_loop(loop)
104101
loop.set_debug(debug)
105102
task = asyncio.ensure_future(main)
@@ -111,6 +108,14 @@ def _patch_asyncio():
111108
with suppress(asyncio.CancelledError):
112109
loop.run_until_complete(task)
113110
111+
def _get_event_loop(stacklevel=3):
112+
loop = events._get_running_loop()
113+
if loop is None:
114+
loop = events.get_event_loop_policy().get_event_loop()
115+
return loop
116+
117+
if hasattr(asyncio, '_nest_patched'):
118+
return
114119
if sys.version_info >= (3, 6, 0):
115120
asyncio.Task = asyncio.tasks._CTask = asyncio.tasks.Task = \
116121
asyncio.tasks._PyTask
@@ -119,9 +124,12 @@ def _patch_asyncio():
119124
if sys.version_info < (3, 7, 0):
120125
asyncio.tasks._current_tasks = asyncio.tasks.Task._current_tasks
121126
asyncio.all_tasks = asyncio.tasks.Task.all_tasks
122-
if not hasattr(asyncio, '_run_orig'):
123-
asyncio._run_orig = getattr(asyncio, 'run', None)
124-
asyncio.run = run
127+
if sys.version_info >= (3, 9, 0):
128+
events._get_event_loop = events.get_event_loop = \
129+
asyncio.get_event_loop = _get_event_loop
130+
_get_event_loop
131+
asyncio.run = run
132+
asyncio._nest_patched = True
125133
126134
127135
def _patch_loop(loop):
@@ -229,18 +237,22 @@ def _patch_loop(loop):
229237
'''Do not throw exception if loop is already running.'''
230238
pass
231239
240+
if hasattr(loop, '_nest_patched'):
241+
return
242+
if not isinstance(loop, asyncio.BaseEventLoop):
243+
raise ValueError('Can\'t patch loop of type %s' % type(loop))
232244
cls = loop.__class__
233245
cls.run_forever = run_forever
234246
cls.run_until_complete = run_until_complete
235247
cls._run_once = _run_once
236248
cls._check_running = _check_running
237249
cls._check_runnung = _check_running # typo in Python 3.7 source
238-
cls._nest_patched = True
239250
cls._num_runs_pending = 0
240251
cls._is_proactorloop = (
241252
os.name == 'nt' and issubclass(cls, asyncio.ProactorEventLoop))
242253
if sys.version_info < (3, 7, 0):
243254
cls._set_coroutine_origin_tracking = cls._set_coroutine_wrapper
255+
cls._nest_patched = True
244256
245257
246258
def _patch_task():
@@ -257,6 +269,8 @@ def _patch_task():
257269
curr_tasks[task._loop] = curr_task
258270
259271
Task = asyncio.Task
272+
if hasattr(Task, '_nest_patched'):
273+
return
260274
if sys.version_info >= (3, 7, 0):
261275
262276
def enter_task(loop, task):
@@ -274,6 +288,7 @@ def _patch_task():
274288
curr_tasks = Task._current_tasks
275289
step_orig = Task._step
276290
Task._step = step
291+
Task._nest_patched = True
277292
278293
279294
def _patch_tornado():
@@ -457,20 +472,20 @@ def _compile_ast(node: ast.AST, filename: str = "<eval>", mode: str = "exec") ->
457472
ASTWithBody = Union[ast.Module, ast.With, ast.AsyncWith]
458473
459474
460-
def _make_stmt_as_return(parent: ASTWithBody, root: ast.AST) -> types.CodeType:
475+
def _make_stmt_as_return(parent: ASTWithBody, root: ast.AST, filename: str) -> types.CodeType:
461476
node = parent.body[-1]
462477
463478
if isinstance(node, ast.Expr):
464479
parent.body[-1] = ast.copy_location(ast.Return(node.value), node)
465480
466481
try:
467-
return _compile_ast(root)
482+
return _compile_ast(root, filename)
468483
except (SyntaxError, TypeError): # pragma: no cover # TODO: found case to cover except body
469484
parent.body[-1] = node
470-
return _compile_ast(root)
485+
return _compile_ast(root, filename)
471486
472487
473-
def _transform_to_async(code: str) -> types.CodeType:
488+
def _transform_to_async(code: str, filename: str) -> types.CodeType:
474489
base: ast.Module = ast.parse(_ASYNC_EVAL_CODE_TEMPLATE)
475490
module: ast.Module = cast(ast.Module, _parse_code(code))
476491
@@ -483,7 +498,7 @@ def _transform_to_async(code: str) -> types.CodeType:
483498
while isinstance(parent.body[-1], (ast.AsyncWith, ast.With)):
484499
parent = cast(ASTWithBody, parent.body[-1])
485500
486-
return _make_stmt_as_return(parent, base)
501+
return _make_stmt_as_return(parent, base, filename)
487502
488503
489504
class _AsyncNodeFound(Exception):
@@ -544,7 +559,13 @@ def is_async_code(code: str) -> bool:
544559
545560
546561
# async equivalent of builtin eval function
547-
def async_eval(code: str, _globals: Optional[dict] = None, _locals: Optional[dict] = None) -> Any:
562+
def async_eval(
563+
code: str,
564+
_globals: Optional[dict] = None,
565+
_locals: Optional[dict] = None,
566+
*,
567+
filename: str = "<eval>",
568+
) -> Any:
548569
apply() # double check that loop is patched
549570
550571
caller: types.FrameType = inspect.currentframe().f_back # type: ignore
@@ -555,7 +576,7 @@ def async_eval(code: str, _globals: Optional[dict] = None, _locals: Optional[dic
555576
if _globals is None:
556577
_globals = caller.f_globals
557578
558-
code_obj = _transform_to_async(code)
579+
code_obj = _transform_to_async(code, filename)
559580
560581
try:
561582
exec(code_obj, _globals, _locals)

0 commit comments

Comments
 (0)