Skip to content

Commit 2b243ef

Browse files
authored
Merge pull request #62 from iksteen/task-locals-arc
Don't attach to runtime when cloning TaskLocals.
2 parents 5ce66d1 + b921141 commit 2b243ef

File tree

5 files changed

+50
-46
lines changed

5 files changed

+50
-46
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@ To see unreleased changes, please see the CHANGELOG on the main branch.
1010

1111
<!-- towncrier release notes start -->
1212

13+
- Avoid attaching to the runtime when cloning TaskLocals by using std::sync::Arc. [#62](https://github.com/PyO3/pyo3-async-runtimes/pull/62)
14+
15+
## [0.26.0] - 2025-09-02
16+
17+
- Bump to pyo3 0.26. [#54](https://github.com/PyO3/pyo3-async-runtimes/pull/54)
18+
1319
## [0.25.0] - 2025-05-14
1420

1521
- Bump to pyo3 0.25. [#41](https://github.com/PyO3/pyo3-async-runtimes/pull/41)

src/async_std.rs

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,7 @@ impl ContextExt for AsyncStdRuntime {
9191

9292
fn get_task_locals() -> Option<TaskLocals> {
9393
TASK_LOCALS
94-
.try_with(|c| {
95-
c.borrow()
96-
.as_ref()
97-
.map(|locals| Python::attach(|py| locals.clone_ref(py)))
98-
})
94+
.try_with(|c| c.borrow().as_ref().map(|locals| locals.clone()))
9995
.unwrap_or_default()
10096
}
10197
}

src/generic.rs

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ where
9191
R: ContextExt,
9292
{
9393
if let Some(locals) = R::get_task_locals() {
94-
Ok(locals.event_loop.into_bound(py))
94+
Ok(locals.0.event_loop.clone_ref(py).into_bound(py))
9595
} else {
9696
get_running_loop(py)
9797
}
@@ -585,7 +585,7 @@ where
585585
{
586586
let (cancel_tx, cancel_rx) = oneshot::channel();
587587

588-
let py_fut = create_future(locals.event_loop.bind(py).clone())?;
588+
let py_fut = create_future(locals.0.event_loop.bind(py).clone())?;
589589
py_fut.call_method1(
590590
"add_done_callback",
591591
(PyDoneCallback {
@@ -597,11 +597,11 @@ where
597597
let future_tx2 = future_tx1.clone_ref(py);
598598

599599
R::spawn(async move {
600-
let locals2 = Python::attach(|py| locals.clone_ref(py));
600+
let locals2 = locals.clone();
601601

602602
if let Err(e) = R::spawn(async move {
603603
let result = R::scope(
604-
Python::attach(|py| locals2.clone_ref(py)),
604+
locals2.clone(),
605605
Cancellable::new_with_cancel_rx(fut, cancel_rx),
606606
)
607607
.await;
@@ -638,7 +638,7 @@ where
638638
get_panic_message(&e.into_panic())
639639
);
640640
let _ = set_result(
641-
locals.event_loop.bind(py),
641+
locals.0.event_loop.bind(py),
642642
future_tx2.bind(py),
643643
Err(RustPanic::new_err(panic_message)),
644644
)
@@ -990,7 +990,7 @@ where
990990
{
991991
let (cancel_tx, cancel_rx) = oneshot::channel();
992992

993-
let py_fut = create_future(locals.event_loop.clone_ref(py).into_bound(py))?;
993+
let py_fut = create_future(locals.0.event_loop.clone_ref(py).into_bound(py))?;
994994
py_fut.call_method1(
995995
"add_done_callback",
996996
(PyDoneCallback {
@@ -1002,11 +1002,11 @@ where
10021002
let future_tx2 = future_tx1.clone_ref(py);
10031003

10041004
R::spawn_local(async move {
1005-
let locals2 = Python::attach(|py| locals.clone_ref(py));
1005+
let locals2 = locals.clone();
10061006

10071007
if let Err(e) = R::spawn_local(async move {
10081008
let result = R::scope_local(
1009-
Python::attach(|py| locals2.clone_ref(py)),
1009+
locals2.clone(),
10101010
Cancellable::new_with_cancel_rx(fut, cancel_rx),
10111011
)
10121012
.await;
@@ -1020,7 +1020,7 @@ where
10201020
}
10211021

10221022
let _ = set_result(
1023-
locals2.event_loop.bind(py),
1023+
locals2.0.event_loop.bind(py),
10241024
future_tx1.bind(py),
10251025
result.and_then(|val| val.into_py_any(py)),
10261026
)
@@ -1043,7 +1043,7 @@ where
10431043
get_panic_message(&e.into_panic())
10441044
);
10451045
let _ = set_result(
1046-
locals.event_loop.bind(py),
1046+
locals.0.event_loop.bind(py),
10471047
future_tx2.bind(py),
10481048
Err(RustPanic::new_err(panic_message)),
10491049
)
@@ -1506,12 +1506,7 @@ struct SenderGlue {
15061506
#[pymethods]
15071507
impl SenderGlue {
15081508
pub fn send(&mut self, item: Py<PyAny>) -> PyResult<Py<PyAny>> {
1509-
Python::attach(|py| {
1510-
self.tx
1511-
.lock()
1512-
.unwrap()
1513-
.send(py, self.locals.clone_ref(py), item)
1514-
})
1509+
Python::attach(|py| self.tx.lock().unwrap().send(py, self.locals.clone(), item))
15151510
}
15161511
pub fn close(&mut self) -> PyResult<()> {
15171512
self.tx.lock().unwrap().close()

src/lib.rs

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@
9393
//! let locals = pyo3_async_runtimes::TaskLocals::with_running_loop(py)?.copy_context(py)?;
9494
//!
9595
//! // Convert the async move { } block to a Python awaitable
96-
//! pyo3_async_runtimes::tokio::future_into_py_with_locals(py, locals.clone_ref(py), async move {
96+
//! pyo3_async_runtimes::tokio::future_into_py_with_locals(py, locals.clone(), async move {
9797
//! let py_sleep = Python::attach(|py| {
9898
//! // Sometimes we need to call other async Python functions within
9999
//! // this future. In order for this to work, we need to track the
@@ -162,9 +162,9 @@
162162
//!
163163
//! pyo3_async_runtimes::tokio::future_into_py_with_locals(
164164
//! py,
165-
//! locals.clone_ref(py),
165+
//! locals.clone(),
166166
//! // Store the current locals in task-local data
167-
//! pyo3_async_runtimes::tokio::scope(locals.clone_ref(py), async move {
167+
//! pyo3_async_runtimes::tokio::scope(locals.clone(), async move {
168168
//! let py_sleep = Python::attach(|py| {
169169
//! pyo3_async_runtimes::into_future_with_locals(
170170
//! // Now we can get the current locals through task-local data
@@ -189,9 +189,9 @@
189189
//!
190190
//! pyo3_async_runtimes::tokio::future_into_py_with_locals(
191191
//! py,
192-
//! locals.clone_ref(py),
192+
//! locals.clone(),
193193
//! // Store the current locals in task-local data
194-
//! pyo3_async_runtimes::tokio::scope(locals.clone_ref(py), async move {
194+
//! pyo3_async_runtimes::tokio::scope(locals.clone(), async move {
195195
//! let py_sleep = Python::attach(|py| {
196196
//! pyo3_async_runtimes::into_future_with_locals(
197197
//! &pyo3_async_runtimes::tokio::get_current_locals(py)?,
@@ -395,6 +395,7 @@ pub mod doc_test {
395395
}
396396

397397
use std::future::Future;
398+
use std::sync::Arc;
398399

399400
use futures::channel::oneshot;
400401
use pyo3::{call::PyCallArgs, prelude::*, sync::PyOnceLock, types::PyDict};
@@ -468,22 +469,26 @@ fn copy_context(py: Python) -> PyResult<Bound<PyAny>> {
468469
contextvars(py)?.call_method0("copy_context")
469470
}
470471

471-
/// Task-local data to store for Python conversions.
472+
/// Task-local inner structure.
472473
#[derive(Debug)]
473-
pub struct TaskLocals {
474+
struct TaskLocalsInner {
474475
/// Track the event loop of the Python task
475476
event_loop: Py<PyAny>,
476477
/// Track the contextvars of the Python task
477478
context: Py<PyAny>,
478479
}
479480

481+
/// Task-local data to store for Python conversions.
482+
#[derive(Debug)]
483+
pub struct TaskLocals(Arc<TaskLocalsInner>);
484+
480485
impl TaskLocals {
481486
/// At a minimum, TaskLocals must store the event loop.
482487
pub fn new(event_loop: Bound<PyAny>) -> Self {
483-
Self {
488+
Self(Arc::new(TaskLocalsInner {
484489
context: event_loop.py().None(),
485490
event_loop: event_loop.into(),
486-
}
491+
}))
487492
}
488493

489494
/// Construct TaskLocals with the event loop returned by `get_running_loop`
@@ -493,10 +498,10 @@ impl TaskLocals {
493498

494499
/// Manually provide the contextvars for the current task.
495500
pub fn with_context(self, context: Bound<PyAny>) -> Self {
496-
Self {
501+
Self(Arc::new(TaskLocalsInner {
502+
event_loop: self.0.event_loop.clone_ref(context.py()),
497503
context: context.into(),
498-
..self
499-
}
504+
}))
500505
}
501506

502507
/// Capture the current task's contextvars
@@ -506,21 +511,26 @@ impl TaskLocals {
506511

507512
/// Get a reference to the event loop
508513
pub fn event_loop<'p>(&self, py: Python<'p>) -> Bound<'p, PyAny> {
509-
self.event_loop.clone_ref(py).into_bound(py)
514+
self.0.event_loop.clone_ref(py).into_bound(py)
510515
}
511516

512517
/// Get a reference to the python context
513518
pub fn context<'p>(&self, py: Python<'p>) -> Bound<'p, PyAny> {
514-
self.context.clone_ref(py).into_bound(py)
519+
self.0.context.clone_ref(py).into_bound(py)
515520
}
516521

517-
/// Create a clone of the TaskLocals by incrementing the reference counters of the event loop and
518-
/// contextvars.
519-
pub fn clone_ref(&self, py: Python<'_>) -> Self {
520-
Self {
521-
event_loop: self.event_loop.clone_ref(py),
522-
context: self.context.clone_ref(py),
523-
}
522+
/// Create a clone of the TaskLocals. No longer uses the runtime, use `clone` instead.
523+
#[deprecated(note = "please use `clone` instead")]
524+
pub fn clone_ref(&self, _py: Python<'_>) -> Self {
525+
self.clone()
526+
}
527+
}
528+
529+
impl Clone for TaskLocals {
530+
/// Create a clone of the TaskLocals by incrementing the reference counter of the inner
531+
/// structure.
532+
fn clone(&self) -> Self {
533+
Self(self.0.clone())
524534
}
525535
}
526536

src/tokio.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,7 @@ impl ContextExt for TokioRuntime {
109109

110110
fn get_task_locals() -> Option<TaskLocals> {
111111
TASK_LOCALS
112-
.try_with(|c| {
113-
c.get()
114-
.map(|locals| Python::attach(|py| locals.clone_ref(py)))
115-
})
112+
.try_with(|c| c.get().map(|locals| locals.clone()))
116113
.unwrap_or_default()
117114
}
118115
}

0 commit comments

Comments
 (0)