Skip to content

Commit 0a9a491

Browse files
committed
fix: Change thread local context to allow overlapped scopes
1 parent 5d4e15f commit 0a9a491

File tree

1 file changed

+164
-21
lines changed

1 file changed

+164
-21
lines changed

opentelemetry/src/context.rs

Lines changed: 164 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::marker::PhantomData;
99
use std::sync::Arc;
1010

1111
thread_local! {
12-
static CURRENT_CONTEXT: RefCell<Context> = RefCell::new(Context::default());
12+
static CURRENT_CONTEXT: RefCell<ContextStack> = RefCell::new(ContextStack::default());
1313
}
1414

1515
/// An execution-scoped collection of values.
@@ -122,7 +122,7 @@ impl Context {
122122
/// Note: This function will panic if you attempt to attach another context
123123
/// while the current one is still borrowed.
124124
pub fn map_current<T>(f: impl FnOnce(&Context) -> T) -> T {
125-
CURRENT_CONTEXT.with(|cx| f(&cx.borrow()))
125+
CURRENT_CONTEXT.with(|cx| cx.borrow().map_current_cx(f))
126126
}
127127

128128
/// Returns a clone of the current thread's context with the given value.
@@ -298,12 +298,10 @@ impl Context {
298298
/// assert_eq!(Context::current().get::<ValueA>(), None);
299299
/// ```
300300
pub fn attach(self) -> ContextGuard {
301-
let previous_cx = CURRENT_CONTEXT
302-
.try_with(|current| current.replace(self))
303-
.ok();
301+
let cx_id = CURRENT_CONTEXT.with(|cx| cx.borrow_mut().push(self));
304302

305303
ContextGuard {
306-
previous_cx,
304+
cx_pos: cx_id,
307305
_marker: PhantomData,
308306
}
309307
}
@@ -344,17 +342,19 @@ impl fmt::Debug for Context {
344342
}
345343

346344
/// A guard that resets the current context to the prior context when dropped.
347-
#[allow(missing_debug_implementations)]
345+
#[derive(Debug)]
348346
pub struct ContextGuard {
349-
previous_cx: Option<Context>,
350-
// ensure this type is !Send as it relies on thread locals
347+
// The position of the context in the stack. This is used to pop the context.
348+
cx_pos: u16,
349+
// Ensure this type is !Send as it relies on thread locals
351350
_marker: PhantomData<*const ()>,
352351
}
353352

354353
impl Drop for ContextGuard {
355354
fn drop(&mut self) {
356-
if let Some(previous_cx) = self.previous_cx.take() {
357-
let _ = CURRENT_CONTEXT.try_with(|current| current.replace(previous_cx));
355+
let id = self.cx_pos;
356+
if id > ContextStack::BASE_POS && id < ContextStack::MAX_POS {
357+
CURRENT_CONTEXT.with(|context_stack| context_stack.borrow_mut().pop_id(id));
358358
}
359359
}
360360
}
@@ -381,10 +381,107 @@ impl Hasher for IdHasher {
381381
}
382382
}
383383

384+
/// A stack for keeping track of the [`Context`] instances that have been attached
385+
/// to a thread.
386+
///
387+
/// The stack allows for popping of contexts by position, which is used to do out
388+
/// of order dropping of [`ContextGuard`] instances. Only when the top of the
389+
/// stack is popped, the topmost [`Context`] is actually restored.
390+
///
391+
/// The stack relies on the fact that it is thread local and that the
392+
/// [`ContextGuard`] instances that are constructed using it can't be shared with
393+
/// other threads.
394+
struct ContextStack {
395+
/// This is the current [`Context`] that is active on this thread, and the top
396+
/// of the [`ContextStack`]. It is always present, and if the `stack` is empty
397+
/// it's an empty [`Context`].
398+
///
399+
/// Having this here allows for fast access to the current [`Context`].
400+
current_cx: Context,
401+
/// A `stack` of the other contexts that have been attached to the thread.
402+
stack: Vec<Option<Context>>,
403+
/// Ensure this type is !Send as it relies on thread locals
404+
_marker: PhantomData<*const ()>,
405+
}
406+
407+
impl ContextStack {
408+
const BASE_POS: u16 = 0;
409+
const MAX_POS: u16 = u16::MAX;
410+
const INITIAL_CAPACITY: usize = 8;
411+
412+
#[inline(always)]
413+
fn push(&mut self, cx: Context) -> u16 {
414+
// The next id is the length of the `stack`, plus one since we have the
415+
// top of the [`ContextStack`] as the `current_cx`.
416+
let next_id = self.stack.len() + 1;
417+
if next_id < ContextStack::MAX_POS.into() {
418+
let current_cx = std::mem::replace(&mut self.current_cx, cx);
419+
self.stack.push(Some(current_cx));
420+
next_id as u16
421+
} else {
422+
// This is an overflow, log it and ignore it.
423+
// TODO:ban add logging here
424+
ContextStack::MAX_POS
425+
}
426+
}
427+
428+
#[inline(always)]
429+
fn pop_id(&mut self, pos: u16) {
430+
if pos == ContextStack::BASE_POS || pos == ContextStack::MAX_POS {
431+
// The empty context is always at the bottom of the [`ContextStack`]
432+
// and cannot be popped, and the overflow position is invalid, so do
433+
// nothing.
434+
return;
435+
}
436+
let len: u16 = self.stack.len() as u16;
437+
// Are we at the top of the [`ContextStack`]?
438+
if pos == len {
439+
// Shrink the stack if possible to clear out any out of order pops.
440+
while let Some(None) = self.stack.last() {
441+
_ = self.stack.pop();
442+
}
443+
// Restore the previous context. This will always happen since the
444+
// empty context is always at the bottom of the stack if the
445+
// [`ContextStack`] is not empty.
446+
if let Some(Some(next_cx)) = self.stack.pop() {
447+
self.current_cx = next_cx;
448+
}
449+
} else {
450+
// This is an out of order pop.
451+
if pos >= len {
452+
// This is an invalid id, ignore it.
453+
return;
454+
}
455+
// Clear out the entry at the given id.
456+
_ = self.stack[pos as usize].take();
457+
}
458+
}
459+
460+
#[inline(always)]
461+
fn map_current_cx<T>(&self, f: impl FnOnce(&Context) -> T) -> T {
462+
f(&self.current_cx)
463+
}
464+
}
465+
466+
impl Default for ContextStack {
467+
fn default() -> Self {
468+
ContextStack {
469+
current_cx: Context::default(),
470+
stack: Vec::with_capacity(ContextStack::INITIAL_CAPACITY),
471+
_marker: PhantomData,
472+
}
473+
}
474+
}
475+
384476
#[cfg(test)]
385477
mod tests {
386478
use super::*;
387479

480+
#[derive(Debug, PartialEq)]
481+
struct ValueA(&'static str);
482+
#[derive(Debug, PartialEq)]
483+
struct ValueB(u64);
484+
388485
#[test]
389486
fn context_immutable() {
390487
#[derive(Debug, PartialEq)]
@@ -424,10 +521,6 @@ mod tests {
424521

425522
#[test]
426523
fn nested_contexts() {
427-
#[derive(Debug, PartialEq)]
428-
struct ValueA(&'static str);
429-
#[derive(Debug, PartialEq)]
430-
struct ValueB(u64);
431524
let _outer_guard = Context::new().with_value(ValueA("a")).attach();
432525

433526
// Only value `a` is set
@@ -462,13 +555,7 @@ mod tests {
462555
}
463556

464557
#[test]
465-
#[ignore = "overlapping contexts are not supported yet"]
466558
fn overlapping_contexts() {
467-
#[derive(Debug, PartialEq)]
468-
struct ValueA(&'static str);
469-
#[derive(Debug, PartialEq)]
470-
struct ValueB(u64);
471-
472559
let outer_guard = Context::new().with_value(ValueA("a")).attach();
473560

474561
// Only value `a` is set
@@ -502,4 +589,60 @@ mod tests {
502589
assert_eq!(current.get::<ValueA>(), None);
503590
assert_eq!(current.get::<ValueB>(), None);
504591
}
592+
593+
#[test]
594+
fn too_many_contexts() {
595+
let mut guards: Vec<ContextGuard> = Vec::with_capacity(ContextStack::MAX_POS as usize);
596+
let stack_max_pos = ContextStack::MAX_POS as u64;
597+
// Fill the stack up until the last position
598+
for i in 1..stack_max_pos {
599+
let cx_guard = Context::current().with_value(ValueB(i)).attach();
600+
assert_eq!(Context::current().get(), Some(&ValueB(i)));
601+
assert_eq!(cx_guard.cx_pos, i as u16);
602+
guards.push(cx_guard);
603+
}
604+
// Let's overflow the stack a couple of times
605+
for _ in 0..16 {
606+
let cx_guard = Context::current().with_value(ValueA("overflow")).attach();
607+
assert_eq!(cx_guard.cx_pos, ContextStack::MAX_POS);
608+
assert_eq!(Context::current().get::<ValueA>(), None);
609+
assert_eq!(Context::current().get(), Some(&ValueB(stack_max_pos - 1)));
610+
guards.push(cx_guard);
611+
}
612+
// Drop the overflow contexts
613+
for _ in 0..16 {
614+
guards.pop();
615+
assert_eq!(Context::current().get::<ValueA>(), None);
616+
assert_eq!(Context::current().get(), Some(&ValueB(stack_max_pos - 1)));
617+
}
618+
// Drop one more so we can add a new one
619+
guards.pop();
620+
assert_eq!(Context::current().get::<ValueA>(), None);
621+
assert_eq!(Context::current().get(), Some(&ValueB(stack_max_pos - 2)));
622+
// Push a new context and see that it works
623+
let cx_guard = Context::current().with_value(ValueA("last")).attach();
624+
assert_eq!(cx_guard.cx_pos, ContextStack::MAX_POS - 1);
625+
assert_eq!(Context::current().get(), Some(&ValueA("last")));
626+
assert_eq!(Context::current().get(), Some(&ValueB(stack_max_pos - 2)));
627+
guards.push(cx_guard);
628+
// Let's overflow the stack a couple of times again
629+
for _ in 0..16 {
630+
let cx_guard = Context::current().with_value(ValueA("overflow")).attach();
631+
assert_eq!(cx_guard.cx_pos, ContextStack::MAX_POS);
632+
assert_eq!(Context::current().get(), Some(&ValueA("last")));
633+
assert_eq!(Context::current().get(), Some(&ValueB(stack_max_pos - 2)));
634+
guards.push(cx_guard);
635+
}
636+
}
637+
638+
#[test]
639+
fn context_stack_pop_id() {
640+
// This is to get full line coverage of the `pop_id` function.
641+
// In real life the `Drop`` implementation of `ContextGuard` ensures that
642+
// the ids are valid and inside the bounds.
643+
let mut stack = ContextStack::default();
644+
stack.pop_id(ContextStack::BASE_POS);
645+
stack.pop_id(ContextStack::MAX_POS);
646+
stack.pop_id(4711);
647+
}
505648
}

0 commit comments

Comments
 (0)