Skip to content

Commit e1f3946

Browse files
committed
fix: Change thread local context to allow overlapped scopes
1 parent 5d4e15f commit e1f3946

File tree

1 file changed

+179
-37
lines changed

1 file changed

+179
-37
lines changed

opentelemetry/src/context.rs

Lines changed: 179 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use crate::otel_warn;
12
#[cfg(feature = "trace")]
23
use crate::trace::context::SynchronizedSpan;
34
use std::any::{Any, TypeId};
@@ -9,7 +10,7 @@ use std::marker::PhantomData;
910
use std::sync::Arc;
1011

1112
thread_local! {
12-
static CURRENT_CONTEXT: RefCell<Context> = RefCell::new(Context::default());
13+
static CURRENT_CONTEXT: RefCell<ContextStack> = RefCell::new(ContextStack::default());
1314
}
1415

1516
/// An execution-scoped collection of values.
@@ -122,7 +123,7 @@ impl Context {
122123
/// Note: This function will panic if you attempt to attach another context
123124
/// while the current one is still borrowed.
124125
pub fn map_current<T>(f: impl FnOnce(&Context) -> T) -> T {
125-
CURRENT_CONTEXT.with(|cx| f(&cx.borrow()))
126+
CURRENT_CONTEXT.with(|cx| cx.borrow().map_current_cx(f))
126127
}
127128

128129
/// Returns a clone of the current thread's context with the given value.
@@ -298,12 +299,10 @@ impl Context {
298299
/// assert_eq!(Context::current().get::<ValueA>(), None);
299300
/// ```
300301
pub fn attach(self) -> ContextGuard {
301-
let previous_cx = CURRENT_CONTEXT
302-
.try_with(|current| current.replace(self))
303-
.ok();
302+
let cx_id = CURRENT_CONTEXT.with(|cx| cx.borrow_mut().push(self));
304303

305304
ContextGuard {
306-
previous_cx,
305+
cx_pos: cx_id,
307306
_marker: PhantomData,
308307
}
309308
}
@@ -344,17 +343,19 @@ impl fmt::Debug for Context {
344343
}
345344

346345
/// A guard that resets the current context to the prior context when dropped.
347-
#[allow(missing_debug_implementations)]
346+
#[derive(Debug)]
348347
pub struct ContextGuard {
349-
previous_cx: Option<Context>,
350-
// ensure this type is !Send as it relies on thread locals
348+
// The position of the context in the stack. This is used to pop the context.
349+
cx_pos: u16,
350+
// Ensure this type is !Send as it relies on thread locals
351351
_marker: PhantomData<*const ()>,
352352
}
353353

354354
impl Drop for ContextGuard {
355355
fn drop(&mut self) {
356-
if let Some(previous_cx) = self.previous_cx.take() {
357-
let _ = CURRENT_CONTEXT.try_with(|current| current.replace(previous_cx));
356+
let id = self.cx_pos;
357+
if id > ContextStack::BASE_POS && id < ContextStack::MAX_POS {
358+
CURRENT_CONTEXT.with(|context_stack| context_stack.borrow_mut().pop_id(id));
358359
}
359360
}
360361
}
@@ -381,17 +382,112 @@ impl Hasher for IdHasher {
381382
}
382383
}
383384

385+
/// A stack for keeping track of the [`Context`] instances that have been attached
386+
/// to a thread.
387+
///
388+
/// The stack allows for popping of contexts by position, which is used to do out
389+
/// of order dropping of [`ContextGuard`] instances. Only when the top of the
390+
/// stack is popped, the topmost [`Context`] is actually restored.
391+
///
392+
/// The stack relies on the fact that it is thread local and that the
393+
/// [`ContextGuard`] instances that are constructed using it can't be shared with
394+
/// other threads.
395+
struct ContextStack {
396+
/// This is the current [`Context`] that is active on this thread, and the top
397+
/// of the [`ContextStack`]. It is always present, and if the `stack` is empty
398+
/// it's an empty [`Context`].
399+
///
400+
/// Having this here allows for fast access to the current [`Context`].
401+
current_cx: Context,
402+
/// A `stack` of the other contexts that have been attached to the thread.
403+
stack: Vec<Option<Context>>,
404+
/// Ensure this type is !Send as it relies on thread locals
405+
_marker: PhantomData<*const ()>,
406+
}
407+
408+
impl ContextStack {
409+
const BASE_POS: u16 = 0;
410+
const MAX_POS: u16 = u16::MAX;
411+
const INITIAL_CAPACITY: usize = 8;
412+
413+
#[inline(always)]
414+
fn push(&mut self, cx: Context) -> u16 {
415+
// The next id is the length of the `stack`, plus one since we have the
416+
// top of the [`ContextStack`] as the `current_cx`.
417+
let next_id = self.stack.len() + 1;
418+
if next_id < ContextStack::MAX_POS.into() {
419+
let current_cx = std::mem::replace(&mut self.current_cx, cx);
420+
self.stack.push(Some(current_cx));
421+
next_id as u16
422+
} else {
423+
// This is an overflow, log it and ignore it.
424+
otel_warn!(
425+
name: "Context.AttachFailed",
426+
message = format!("Too many contexts. Max limit is {}", ContextStack::MAX_POS)
427+
);
428+
ContextStack::MAX_POS
429+
}
430+
}
431+
432+
#[inline(always)]
433+
fn pop_id(&mut self, pos: u16) {
434+
if pos == ContextStack::BASE_POS || pos == ContextStack::MAX_POS {
435+
// The empty context is always at the bottom of the [`ContextStack`]
436+
// and cannot be popped, and the overflow position is invalid, so do
437+
// nothing.
438+
return;
439+
}
440+
let len: u16 = self.stack.len() as u16;
441+
// Are we at the top of the [`ContextStack`]?
442+
if pos == len {
443+
// Shrink the stack if possible to clear out any out of order pops.
444+
while let Some(None) = self.stack.last() {
445+
_ = self.stack.pop();
446+
}
447+
// Restore the previous context. This will always happen since the
448+
// empty context is always at the bottom of the stack if the
449+
// [`ContextStack`] is not empty.
450+
if let Some(Some(next_cx)) = self.stack.pop() {
451+
self.current_cx = next_cx;
452+
}
453+
} else {
454+
// This is an out of order pop.
455+
if pos >= len {
456+
// This is an invalid id, ignore it.
457+
return;
458+
}
459+
// Clear out the entry at the given id.
460+
_ = self.stack[pos as usize].take();
461+
}
462+
}
463+
464+
#[inline(always)]
465+
fn map_current_cx<T>(&self, f: impl FnOnce(&Context) -> T) -> T {
466+
f(&self.current_cx)
467+
}
468+
}
469+
470+
impl Default for ContextStack {
471+
fn default() -> Self {
472+
ContextStack {
473+
current_cx: Context::default(),
474+
stack: Vec::with_capacity(ContextStack::INITIAL_CAPACITY),
475+
_marker: PhantomData,
476+
}
477+
}
478+
}
479+
384480
#[cfg(test)]
385481
mod tests {
386482
use super::*;
387483

484+
#[derive(Debug, PartialEq)]
485+
struct ValueA(u64);
486+
#[derive(Debug, PartialEq)]
487+
struct ValueB(u64);
488+
388489
#[test]
389490
fn context_immutable() {
390-
#[derive(Debug, PartialEq)]
391-
struct ValueA(u64);
392-
#[derive(Debug, PartialEq)]
393-
struct ValueB(u64);
394-
395491
// start with Current, which should be an empty context
396492
let cx = Context::current();
397493
assert_eq!(cx.get::<ValueA>(), None);
@@ -424,66 +520,56 @@ mod tests {
424520

425521
#[test]
426522
fn nested_contexts() {
427-
#[derive(Debug, PartialEq)]
428-
struct ValueA(&'static str);
429-
#[derive(Debug, PartialEq)]
430-
struct ValueB(u64);
431-
let _outer_guard = Context::new().with_value(ValueA("a")).attach();
523+
let _outer_guard = Context::new().with_value(ValueA(1)).attach();
432524

433525
// Only value `a` is set
434526
let current = Context::current();
435-
assert_eq!(current.get(), Some(&ValueA("a")));
527+
assert_eq!(current.get(), Some(&ValueA(1)));
436528
assert_eq!(current.get::<ValueB>(), None);
437529

438530
{
439531
let _inner_guard = Context::current_with_value(ValueB(42)).attach();
440532
// Both values are set in inner context
441533
let current = Context::current();
442-
assert_eq!(current.get(), Some(&ValueA("a")));
534+
assert_eq!(current.get(), Some(&ValueA(1)));
443535
assert_eq!(current.get(), Some(&ValueB(42)));
444536

445537
assert!(Context::map_current(|cx| {
446-
assert_eq!(cx.get(), Some(&ValueA("a")));
538+
assert_eq!(cx.get(), Some(&ValueA(1)));
447539
assert_eq!(cx.get(), Some(&ValueB(42)));
448540
true
449541
}));
450542
}
451543

452544
// Resets to only value `a` when inner guard is dropped
453545
let current = Context::current();
454-
assert_eq!(current.get(), Some(&ValueA("a")));
546+
assert_eq!(current.get(), Some(&ValueA(1)));
455547
assert_eq!(current.get::<ValueB>(), None);
456548

457549
assert!(Context::map_current(|cx| {
458-
assert_eq!(cx.get(), Some(&ValueA("a")));
550+
assert_eq!(cx.get(), Some(&ValueA(1)));
459551
assert_eq!(cx.get::<ValueB>(), None);
460552
true
461553
}));
462554
}
463555

464556
#[test]
465-
#[ignore = "overlapping contexts are not supported yet"]
466557
fn overlapping_contexts() {
467-
#[derive(Debug, PartialEq)]
468-
struct ValueA(&'static str);
469-
#[derive(Debug, PartialEq)]
470-
struct ValueB(u64);
471-
472-
let outer_guard = Context::new().with_value(ValueA("a")).attach();
558+
let outer_guard = Context::new().with_value(ValueA(1)).attach();
473559

474560
// Only value `a` is set
475561
let current = Context::current();
476-
assert_eq!(current.get(), Some(&ValueA("a")));
562+
assert_eq!(current.get(), Some(&ValueA(1)));
477563
assert_eq!(current.get::<ValueB>(), None);
478564

479565
let inner_guard = Context::current_with_value(ValueB(42)).attach();
480566
// Both values are set in inner context
481567
let current = Context::current();
482-
assert_eq!(current.get(), Some(&ValueA("a")));
568+
assert_eq!(current.get(), Some(&ValueA(1)));
483569
assert_eq!(current.get(), Some(&ValueB(42)));
484570

485571
assert!(Context::map_current(|cx| {
486-
assert_eq!(cx.get(), Some(&ValueA("a")));
572+
assert_eq!(cx.get(), Some(&ValueA(1)));
487573
assert_eq!(cx.get(), Some(&ValueB(42)));
488574
true
489575
}));
@@ -492,7 +578,7 @@ mod tests {
492578

493579
// `inner_guard` is still alive so both `ValueA` and `ValueB` should still be accessible
494580
let current = Context::current();
495-
assert_eq!(current.get(), Some(&ValueA("a")));
581+
assert_eq!(current.get(), Some(&ValueA(1)));
496582
assert_eq!(current.get(), Some(&ValueB(42)));
497583

498584
drop(inner_guard);
@@ -502,4 +588,60 @@ mod tests {
502588
assert_eq!(current.get::<ValueA>(), None);
503589
assert_eq!(current.get::<ValueB>(), None);
504590
}
591+
592+
#[test]
593+
fn too_many_contexts() {
594+
let mut guards: Vec<ContextGuard> = Vec::with_capacity(ContextStack::MAX_POS as usize);
595+
let stack_max_pos = ContextStack::MAX_POS as u64;
596+
// Fill the stack up until the last position
597+
for i in 1..stack_max_pos {
598+
let cx_guard = Context::current().with_value(ValueB(i)).attach();
599+
assert_eq!(Context::current().get(), Some(&ValueB(i)));
600+
assert_eq!(cx_guard.cx_pos, i as u16);
601+
guards.push(cx_guard);
602+
}
603+
// Let's overflow the stack a couple of times
604+
for _ in 0..16 {
605+
let cx_guard = Context::current().with_value(ValueA(1)).attach();
606+
assert_eq!(cx_guard.cx_pos, ContextStack::MAX_POS);
607+
assert_eq!(Context::current().get::<ValueA>(), None);
608+
assert_eq!(Context::current().get(), Some(&ValueB(stack_max_pos - 1)));
609+
guards.push(cx_guard);
610+
}
611+
// Drop the overflow contexts
612+
for _ in 0..16 {
613+
guards.pop();
614+
assert_eq!(Context::current().get::<ValueA>(), None);
615+
assert_eq!(Context::current().get(), Some(&ValueB(stack_max_pos - 1)));
616+
}
617+
// Drop one more so we can add a new one
618+
guards.pop();
619+
assert_eq!(Context::current().get::<ValueA>(), None);
620+
assert_eq!(Context::current().get(), Some(&ValueB(stack_max_pos - 2)));
621+
// Push a new context and see that it works
622+
let cx_guard = Context::current().with_value(ValueA(2)).attach();
623+
assert_eq!(cx_guard.cx_pos, ContextStack::MAX_POS - 1);
624+
assert_eq!(Context::current().get(), Some(&ValueA(2)));
625+
assert_eq!(Context::current().get(), Some(&ValueB(stack_max_pos - 2)));
626+
guards.push(cx_guard);
627+
// Let's overflow the stack a couple of times again
628+
for _ in 0..16 {
629+
let cx_guard = Context::current().with_value(ValueA(1)).attach();
630+
assert_eq!(cx_guard.cx_pos, ContextStack::MAX_POS);
631+
assert_eq!(Context::current().get(), Some(&ValueA(2)));
632+
assert_eq!(Context::current().get(), Some(&ValueB(stack_max_pos - 2)));
633+
guards.push(cx_guard);
634+
}
635+
}
636+
637+
#[test]
638+
fn context_stack_pop_id() {
639+
// This is to get full line coverage of the `pop_id` function.
640+
// In real life the `Drop`` implementation of `ContextGuard` ensures that
641+
// the ids are valid and inside the bounds.
642+
let mut stack = ContextStack::default();
643+
stack.pop_id(ContextStack::BASE_POS);
644+
stack.pop_id(ContextStack::MAX_POS);
645+
stack.pop_id(4711);
646+
}
505647
}

0 commit comments

Comments
 (0)