Skip to content

gh-126907: Use a list for atexit callbacks #127935

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions Include/internal/pycore_atexit.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,29 @@ typedef struct atexit_callback {
struct atexit_callback *next;
} atexit_callback;

typedef struct {
PyObject *func;
PyObject *args;
PyObject *kwargs;
} atexit_py_callback;

struct atexit_state {
#ifdef Py_GIL_DISABLED
PyMutex ll_callbacks_lock;
#endif
atexit_callback *ll_callbacks;

// XXX The rest of the state could be moved to the atexit module state
// and a low-level callback added for it during module exec.
// For the moment we leave it here.
atexit_py_callback **callbacks;
int ncallbacks;
int callback_len;

// List containing tuples with callback information.
// e.g. [(func, args, kwargs), ...]
PyObject *callbacks;
};

#ifdef Py_GIL_DISABLED
# define _PyAtExit_LockCallbacks(state) PyMutex_Lock(&state->ll_callbacks_lock);
# define _PyAtExit_UnlockCallbacks(state) PyMutex_Unlock(&state->ll_callbacks_lock);
#else
# define _PyAtExit_LockCallbacks(state)
# define _PyAtExit_UnlockCallbacks(state)
#endif

// Export for '_interpchannels' shared extension
PyAPI_FUNC(int) _Py_AtExit(
PyInterpreterState *interp,
Expand Down
35 changes: 34 additions & 1 deletion Lib/test/test_atexit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import unittest
from test import support
from test.support import script_helper

from test.support import threading_helper

class GeneralTest(unittest.TestCase):
def test_general(self):
Expand Down Expand Up @@ -46,6 +46,39 @@ def test_atexit_instances(self):
self.assertEqual(res.out.decode().splitlines(), ["atexit2", "atexit1"])
self.assertFalse(res.err)

@threading_helper.requires_working_threading()
@support.requires_resource("cpu")
@unittest.skipUnless(support.Py_GIL_DISABLED, "only meaningful without the GIL")
def test_atexit_thread_safety(self):
# GH-126907: atexit was not thread safe on the free-threaded build
source = """
from threading import Thread

def dummy():
pass


def thready():
for _ in range(100):
atexit.register(dummy)
atexit._clear()
atexit.register(dummy)
atexit.unregister(dummy)
atexit._run_exitfuncs()


threads = [Thread(target=thready) for _ in range(10)]
for thread in threads:
thread.start()

for thread in threads:
thread.join()
"""

# atexit._clear() has some evil side effects, and we don't
# want them to affect the rest of the tests.
script_helper.assert_python_ok("-c", textwrap.dedent(source))


@support.cpython_only
class SubinterpreterTest(unittest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Fix crash when using :mod:`atexit` concurrently on the :term:`free-threaded
<free threading>` build.
153 changes: 75 additions & 78 deletions Modules/atexitmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ PyUnstable_AtExit(PyInterpreterState *interp,
callback->next = NULL;

struct atexit_state *state = &interp->atexit;
_PyAtExit_LockCallbacks(state);
atexit_callback *top = state->ll_callbacks;
if (top == NULL) {
state->ll_callbacks = callback;
Expand All @@ -49,36 +50,16 @@ PyUnstable_AtExit(PyInterpreterState *interp,
callback->next = top;
state->ll_callbacks = callback;
}
_PyAtExit_UnlockCallbacks(state);
return 0;
}


static void
atexit_delete_cb(struct atexit_state *state, int i)
{
atexit_py_callback *cb = state->callbacks[i];
state->callbacks[i] = NULL;

Py_DECREF(cb->func);
Py_DECREF(cb->args);
Py_XDECREF(cb->kwargs);
PyMem_Free(cb);
}


/* Clear all callbacks without calling them */
static void
atexit_cleanup(struct atexit_state *state)
{
atexit_py_callback *cb;
for (int i = 0; i < state->ncallbacks; i++) {
cb = state->callbacks[i];
if (cb == NULL)
continue;

atexit_delete_cb(state, i);
}
state->ncallbacks = 0;
PyList_Clear(state->callbacks);
}


Expand All @@ -89,23 +70,21 @@ _PyAtExit_Init(PyInterpreterState *interp)
// _PyAtExit_Init() must only be called once
assert(state->callbacks == NULL);

state->callback_len = 32;
state->ncallbacks = 0;
state->callbacks = PyMem_New(atexit_py_callback*, state->callback_len);
state->callbacks = PyList_New(0);
if (state->callbacks == NULL) {
return _PyStatus_NO_MEMORY();
}
return _PyStatus_OK();
}


void
_PyAtExit_Fini(PyInterpreterState *interp)
{
// In theory, there shouldn't be any threads left by now, so we
// won't lock this.
struct atexit_state *state = &interp->atexit;
atexit_cleanup(state);
PyMem_Free(state->callbacks);
state->callbacks = NULL;
Py_CLEAR(state->callbacks);

atexit_callback *next = state->ll_callbacks;
state->ll_callbacks = NULL;
Expand All @@ -120,35 +99,44 @@ _PyAtExit_Fini(PyInterpreterState *interp)
}
}


static void
atexit_callfuncs(struct atexit_state *state)
{
assert(!PyErr_Occurred());
assert(state->callbacks != NULL);
assert(PyList_CheckExact(state->callbacks));

if (state->ncallbacks == 0) {
// Create a copy of the list for thread safety
PyObject *copy = PyList_GetSlice(state->callbacks, 0, PyList_GET_SIZE(state->callbacks));
if (copy == NULL)
{
PyErr_WriteUnraisable(NULL);
return;
}

for (int i = state->ncallbacks - 1; i >= 0; i--) {
atexit_py_callback *cb = state->callbacks[i];
if (cb == NULL) {
continue;
}
for (Py_ssize_t i = 0; i < PyList_GET_SIZE(copy); ++i) {
// We don't have to worry about evil borrowed references, because
// no other threads can access this list.
PyObject *tuple = PyList_GET_ITEM(copy, i);
assert(PyTuple_CheckExact(tuple));

PyObject *func = PyTuple_GET_ITEM(tuple, 0);
PyObject *args = PyTuple_GET_ITEM(tuple, 1);
PyObject *kwargs = PyTuple_GET_ITEM(tuple, 2);

// bpo-46025: Increment the refcount of cb->func as the call itself may unregister it
PyObject* the_func = Py_NewRef(cb->func);
PyObject *res = PyObject_Call(cb->func, cb->args, cb->kwargs);
PyObject *res = PyObject_Call(func,
args,
kwargs == Py_None ? NULL : kwargs);
if (res == NULL) {
PyErr_FormatUnraisable(
"Exception ignored in atexit callback %R", the_func);
"Exception ignored in atexit callback %R", func);
}
else {
Py_DECREF(res);
}
Py_DECREF(the_func);
}

Py_DECREF(copy);
atexit_cleanup(state);

assert(!PyErr_Occurred());
Expand Down Expand Up @@ -194,33 +182,27 @@ atexit_register(PyObject *module, PyObject *args, PyObject *kwargs)
"the first argument must be callable");
return NULL;
}
PyObject *func_args = PyTuple_GetSlice(args, 1, PyTuple_GET_SIZE(args));
PyObject *func_kwargs = kwargs;

struct atexit_state *state = get_atexit_state();
if (state->ncallbacks >= state->callback_len) {
atexit_py_callback **r;
state->callback_len += 16;
size_t size = sizeof(atexit_py_callback*) * (size_t)state->callback_len;
r = (atexit_py_callback**)PyMem_Realloc(state->callbacks, size);
if (r == NULL) {
return PyErr_NoMemory();
}
state->callbacks = r;
if (func_kwargs == NULL)
{
func_kwargs = Py_None;
}

atexit_py_callback *callback = PyMem_Malloc(sizeof(atexit_py_callback));
if (callback == NULL) {
return PyErr_NoMemory();
PyObject *callback = PyTuple_Pack(3, func, func_args, func_kwargs);
if (callback == NULL)
{
return NULL;
}

callback->args = PyTuple_GetSlice(args, 1, PyTuple_GET_SIZE(args));
if (callback->args == NULL) {
PyMem_Free(callback);
struct atexit_state *state = get_atexit_state();
// atexit callbacks go in a LIFO order
if (PyList_Insert(state->callbacks, 0, callback) < 0)
{
Py_DECREF(callback);
return NULL;
}
callback->func = Py_NewRef(func);
callback->kwargs = Py_XNewRef(kwargs);

state->callbacks[state->ncallbacks++] = callback;
Py_DECREF(callback);

return Py_NewRef(func);
}
Expand Down Expand Up @@ -264,7 +246,33 @@ static PyObject *
atexit_ncallbacks(PyObject *module, PyObject *unused)
{
struct atexit_state *state = get_atexit_state();
return PyLong_FromSsize_t(state->ncallbacks);
assert(state->callbacks != NULL);
assert(PyList_CheckExact(state->callbacks));
return PyLong_FromSsize_t(PyList_GET_SIZE(state->callbacks));
}

static int
atexit_unregister_locked(PyObject *callbacks, PyObject *func)
{
for (Py_ssize_t i = 0; i < PyList_GET_SIZE(callbacks); ++i) {
PyObject *tuple = PyList_GET_ITEM(callbacks, i);
assert(PyTuple_CheckExact(tuple));
PyObject *to_compare = PyTuple_GET_ITEM(tuple, 0);
int cmp = PyObject_RichCompareBool(func, to_compare, Py_EQ);
if (cmp < 0)
{
return -1;
}
if (cmp == 1) {
// We found a callback!
if (PyList_SetSlice(callbacks, i, i + 1, NULL) < 0) {
return -1;
}
--i;
}
}

return 0;
}

PyDoc_STRVAR(atexit_unregister__doc__,
Expand All @@ -280,22 +288,11 @@ static PyObject *
atexit_unregister(PyObject *module, PyObject *func)
{
struct atexit_state *state = get_atexit_state();
for (int i = 0; i < state->ncallbacks; i++)
{
atexit_py_callback *cb = state->callbacks[i];
if (cb == NULL) {
continue;
}

int eq = PyObject_RichCompareBool(cb->func, func, Py_EQ);
if (eq < 0) {
return NULL;
}
if (eq) {
atexit_delete_cb(state, i);
}
}
Py_RETURN_NONE;
int result;
Py_BEGIN_CRITICAL_SECTION(state->callbacks);
result = atexit_unregister_locked(state->callbacks, func);
Py_END_CRITICAL_SECTION();
return result < 0 ? NULL : Py_None;
}


Expand Down
Loading