Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-128415: store current task on loop #128416

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Include/internal/pycore_global_objects_fini_generated.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Include/internal/pycore_global_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ struct _Py_global_strings {
STRUCT_FOR_ID(_blksize)
STRUCT_FOR_ID(_bootstrap)
STRUCT_FOR_ID(_check_retval_)
STRUCT_FOR_ID(_current_task)
STRUCT_FOR_ID(_dealloc_warn)
STRUCT_FOR_ID(_feature_version)
STRUCT_FOR_ID(_field_types)
Expand Down
1 change: 1 addition & 0 deletions Include/internal/pycore_runtime_init_generated.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions Include/internal/pycore_unicodeobject_generated.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

49 changes: 27 additions & 22 deletions Lib/asyncio/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ def current_task(loop=None):
"""Return a currently executed task."""
if loop is None:
loop = events.get_running_loop()
return _current_tasks.get(loop)

try:
return loop._current_task
except AttributeError:
return None


def all_tasks(loop=None):
Expand Down Expand Up @@ -1024,10 +1028,6 @@ def factory(loop, coro, *, name=None, context=None):
_scheduled_tasks = weakref.WeakSet()
_eager_tasks = set()

# Dictionary containing tasks that are currently active in
# all running event loops. {EventLoop: Task}
_current_tasks = {}


def _register_task(task):
"""Register an asyncio Task scheduled to run on an event loop."""
Expand All @@ -1040,29 +1040,34 @@ def _register_eager_task(task):


def _enter_task(loop, task):
current_task = _current_tasks.get(loop)
if current_task is not None:
raise RuntimeError(f"Cannot enter into task {task!r} while another "
f"task {current_task!r} is being executed.")
_current_tasks[loop] = task
try:
if loop._current_task is not None:
raise RuntimeError(f"Cannot enter into task {task!r} while another "
f"task {loop._current_task!r} is being executed.")
Comment on lines +1045 to +1046
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
raise RuntimeError(f"Cannot enter into task {task!r} while another "
f"task {loop._current_task!r} is being executed.")
raise RuntimeError(f"Cannot enter into task {task!r} while another "
f"task {loop._current_task!r} is being executed.")

except AttributeError:
pass
loop._current_task = task


def _leave_task(loop, task):
current_task = _current_tasks.get(loop)
if current_task is not task:
try:
if loop._current_task is not task:
raise RuntimeError(f"Leaving task {task!r} does not match "
f"the current task {loop._current_task!r}.")
Comment on lines +1055 to +1056
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
raise RuntimeError(f"Leaving task {task!r} does not match "
f"the current task {loop._current_task!r}.")
raise RuntimeError(f"Leaving task {task!r} does not match "
f"the current task {loop._current_task!r}.")

except AttributeError:
raise RuntimeError(f"Leaving task {task!r} does not match "
f"the current task {current_task!r}.")
del _current_tasks[loop]
f"the current task.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you align this one?

else:
loop._current_task = None


def _swap_current_task(loop, task):
prev_task = _current_tasks.get(loop)
if task is None:
del _current_tasks[loop]
else:
_current_tasks[loop] = task
return prev_task

try:
prev_task = loop._current_task
loop._current_task = task
return prev_task
except AttributeError:
loop._current_task = task

def _unregister_task(task):
"""Unregister a completed, scheduled Task."""
Expand All @@ -1088,7 +1093,7 @@ def _unregister_eager_task(task):
from _asyncio import (_register_task, _register_eager_task,
_unregister_task, _unregister_eager_task,
_enter_task, _leave_task, _swap_current_task,
_scheduled_tasks, _eager_tasks, _current_tasks,
_scheduled_tasks, _eager_tasks,
current_task, all_tasks)
except ImportError:
pass
Expand Down
21 changes: 17 additions & 4 deletions Lib/test/test_asyncio/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3002,7 +3002,9 @@ def done(self):

def test__enter_task(self):
task = mock.Mock()
loop = mock.Mock()
class LoopLike:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe a module-level class would be better?

pass
loop = LoopLike()
self.assertIsNone(asyncio.current_task(loop))
self._enter_task(loop, task)
self.assertIs(asyncio.current_task(loop), task)
Expand All @@ -3011,7 +3013,9 @@ def test__enter_task(self):
def test__enter_task_failure(self):
task1 = mock.Mock()
task2 = mock.Mock()
loop = mock.Mock()
class LoopLike:
pass
loop = LoopLike()
self._enter_task(loop, task1)
with self.assertRaises(RuntimeError):
self._enter_task(loop, task2)
Expand All @@ -3021,14 +3025,20 @@ def test__enter_task_failure(self):
def test__leave_task(self):
task = mock.Mock()
loop = mock.Mock()
class LoopLike:
pass
loop = LoopLike()
self._enter_task(loop, task)
self._leave_task(loop, task)
self.assertIsNone(asyncio.current_task(loop))

def test__leave_task_failure1(self):
task1 = mock.Mock()
task2 = mock.Mock()
loop = mock.Mock()

class LoopLike:
pass
loop = LoopLike()
self._enter_task(loop, task1)
with self.assertRaises(RuntimeError):
self._leave_task(loop, task2)
Expand All @@ -3037,7 +3047,10 @@ def test__leave_task_failure1(self):

def test__leave_task_failure2(self):
task = mock.Mock()
loop = mock.Mock()

class LoopLike:
pass
loop = LoopLike()
with self.assertRaises(RuntimeError):
self._leave_task(loop, task)
self.assertIsNone(asyncio.current_task(loop))
Expand Down
122 changes: 45 additions & 77 deletions Modules/_asynciomodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,6 @@ typedef struct {
PyObject *asyncio_mod;
PyObject *context_kwname;

/* Dictionary containing tasks that are currently active in
all running event loops. {EventLoop: Task} */
PyObject *current_tasks;

/* WeakSet containing scheduled 3rd party tasks which don't
inherit from native asyncio.Task */
PyObject *non_asyncio_tasks;
Expand Down Expand Up @@ -1926,11 +1922,10 @@ static int
enter_task(asyncio_state *state, PyObject *loop, PyObject *task)
{
PyObject *item;
int res = PyDict_SetDefaultRef(state->current_tasks, loop, task, &item);
if (res < 0) {
if (PyObject_GetOptionalAttr(loop, &_Py_ID(_current_task), &item) < 0) {
return -1;
}
else if (res == 1) {
if (item != NULL && item != Py_None) {
PyErr_Format(
PyExc_RuntimeError,
"Cannot enter into task %R while another " \
Expand All @@ -1939,84 +1934,63 @@ enter_task(asyncio_state *state, PyObject *loop, PyObject *task)
Py_DECREF(item);
return -1;
}
Py_DECREF(item);
return 0;
}

static int
err_leave_task(PyObject *item, PyObject *task)
{
PyErr_Format(
PyExc_RuntimeError,
"Leaving task %R does not match the current task %R.",
task, item);
return -1;
}

static int
leave_task_predicate(PyObject *item, void *task)
{
if (item != task) {
return err_leave_task(item, (PyObject *)task);
if (PyObject_SetAttr(loop, &_Py_ID(_current_task), task) < 0) {
return -1;
}
return 1;
return 0;
}

static int
leave_task(asyncio_state *state, PyObject *loop, PyObject *task)
/*[clinic end generated code: output=0ebf6db4b858fb41 input=51296a46313d1ad8]*/
{
int res = _PyDict_DelItemIf(state->current_tasks, loop,
leave_task_predicate, task);
if (res == 0) {
// task was not found
return err_leave_task(Py_None, task);
PyObject *item;
if (PyObject_GetOptionalAttr(loop, &_Py_ID(_current_task), &item) < 0) {
return -1;
}
return res;
}

static PyObject *
swap_current_task_lock_held(PyDictObject *current_tasks, PyObject *loop,
Py_hash_t hash, PyObject *task)
{
PyObject *prev_task;
if (_PyDict_GetItemRef_KnownHash_LockHeld(current_tasks, loop, hash, &prev_task) < 0) {
return NULL;
if (item == NULL || item == Py_None) {
// current task is not set
PyErr_Format(PyExc_RuntimeError,
"Leaving task %R does not match the current task %R.",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having "None" in the message could be a bit confusing. Perhaps something like "Leaving task %R does not match the current task (None)" or "Unexpected leaving task %R"?

task, Py_None);
return -1;
}
if (_PyDict_SetItem_KnownHash_LockHeld(current_tasks, loop, task, hash) < 0) {
Py_XDECREF(prev_task);
return NULL;

if (item != task) {
// different task
PyErr_Format(PyExc_RuntimeError,
"Leaving task %R does not match the current task %R.",
task, item);
Py_DECREF(item);
return -1;
}
if (prev_task == NULL) {
Py_RETURN_NONE;
Py_DECREF(item);

if (PyObject_SetAttr(loop, &_Py_ID(_current_task), Py_None) < 0) {
return -1;
}
return prev_task;
return 0;
}


static PyObject *
swap_current_task(asyncio_state *state, PyObject *loop, PyObject *task)
{
PyObject *prev_task;

if (task == Py_None) {
if (PyDict_Pop(state->current_tasks, loop, &prev_task) < 0) {
return NULL;
}
if (prev_task == NULL) {
Py_RETURN_NONE;
}
return prev_task;
if (PyObject_GetOptionalAttr(loop, &_Py_ID(_current_task), &prev_task) < 0) {
return NULL;
}

Py_hash_t hash = PyObject_Hash(loop);
if (hash == -1) {
if (PyObject_SetAttr(loop, &_Py_ID(_current_task), task) < 0) {
Py_XDECREF(prev_task);
return NULL;
}

PyDictObject *current_tasks = (PyDictObject *)state->current_tasks;
Py_BEGIN_CRITICAL_SECTION(current_tasks);
prev_task = swap_current_task_lock_held(current_tasks, loop, hash, task);
Py_END_CRITICAL_SECTION();
if (prev_task == NULL) {
Py_RETURN_NONE;
}
return prev_task;
}

Expand Down Expand Up @@ -3503,9 +3477,6 @@ static PyObject *
_asyncio_current_task_impl(PyObject *module, PyObject *loop)
/*[clinic end generated code: output=fe15ac331a7f981a input=58910f61a5627112]*/
{
PyObject *ret;
asyncio_state *state = get_asyncio_state(module);

if (loop == Py_None) {
loop = _asyncio_get_running_loop_impl(module);
if (loop == NULL) {
Expand All @@ -3515,12 +3486,19 @@ _asyncio_current_task_impl(PyObject *module, PyObject *loop)
Py_INCREF(loop);
}

int rc = PyDict_GetItemRef(state->current_tasks, loop, &ret);
PyObject *item;
if (PyObject_GetOptionalAttr(loop, &_Py_ID(_current_task), &item) < 0) {
Py_DECREF(loop);
return NULL;
}

Py_DECREF(loop);
if (rc == 0) {

if (item == NULL) {
Py_RETURN_NONE;
}
return ret;

return item;
}


Expand Down Expand Up @@ -3675,7 +3653,6 @@ module_traverse(PyObject *mod, visitproc visit, void *arg)

Py_VISIT(state->non_asyncio_tasks);
Py_VISIT(state->eager_tasks);
Py_VISIT(state->current_tasks);
Py_VISIT(state->iscoroutine_typecache);

Py_VISIT(state->context_kwname);
Expand Down Expand Up @@ -3706,7 +3683,6 @@ module_clear(PyObject *mod)

Py_CLEAR(state->non_asyncio_tasks);
Py_CLEAR(state->eager_tasks);
Py_CLEAR(state->current_tasks);
Py_CLEAR(state->iscoroutine_typecache);

Py_CLEAR(state->context_kwname);
Expand Down Expand Up @@ -3735,10 +3711,6 @@ module_init(asyncio_state *state)
goto fail;
}

state->current_tasks = PyDict_New();
if (state->current_tasks == NULL) {
goto fail;
}

state->iscoroutine_typecache = PySet_New(NULL);
if (state->iscoroutine_typecache == NULL) {
Expand Down Expand Up @@ -3872,10 +3844,6 @@ module_exec(PyObject *mod)
return -1;
}

if (PyModule_AddObjectRef(mod, "_current_tasks", state->current_tasks) < 0) {
return -1;
}


return 0;
}
Expand Down
Loading