Skip to content

Commit 2fc1298

Browse files
committed
Make set_print return the previous callback
1 parent 4c95542 commit 2fc1298

File tree

4 files changed

+63
-31
lines changed

4 files changed

+63
-31
lines changed

libbpf-rs/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@ vsprintf = "2.0"
3131
libc = "0.2"
3232
plain = "0.2.3"
3333
scopeguard = "1.1"
34+
serial_test = "0.5"

libbpf-rs/src/object.rs

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -29,30 +29,14 @@ impl ObjectBuilder {
2929
}
3030

3131
/// Option to print debug output to stderr.
32+
///
33+
/// Note: This function interacts poorly with [`set_print`].
3234
pub fn debug(&mut self, dbg: bool) -> &mut Self {
33-
extern "C" fn cb(
34-
_level: libbpf_sys::libbpf_print_level,
35-
fmtstr: *const c_char,
36-
va_list: *mut libbpf_sys::__va_list_tag,
37-
) -> i32 {
38-
match unsafe { vsprintf::vsprintf(fmtstr, va_list) } {
39-
Ok(s) => {
40-
print!("{}", s);
41-
0
42-
}
43-
Err(e) => {
44-
eprintln!("Failed to parse libbpf output: {}", e);
45-
1
46-
}
47-
}
48-
}
49-
5035
if dbg {
51-
unsafe { libbpf_sys::libbpf_set_print(Some(cb)) };
36+
set_print(|_, s| print!("{}", s));
5237
} else {
53-
unsafe { libbpf_sys::libbpf_set_print(None) };
38+
set_print(|_, _| ());
5439
}
55-
5640
self
5741
}
5842

libbpf-rs/src/print.rs

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::*;
2+
use std::io::{self, Write};
23
use std::os::raw::c_char;
34
use std::sync::atomic::{AtomicPtr, Ordering};
45

@@ -24,8 +25,29 @@ impl From<libbpf_sys::libbpf_print_level> for PrintLevel {
2425

2526
pub type PrintCallback = fn(PrintLevel, &str);
2627

28+
/// Mimic the default print functionality of libbpf. This way if the user calls `set_print` when no
29+
/// previous callback had been set and then tries to restore it, it will appear to work correctly.
30+
fn default_callback(lvl: PrintLevel, msg: &str) {
31+
if lvl == PrintLevel::Debug {
32+
return;
33+
}
34+
35+
let _ = io::stderr().write(msg.as_bytes());
36+
}
37+
38+
// not allowed to use function pointers in const functions, so it needs to be a macro
39+
macro_rules! to_ptr {
40+
($x:expr) => {
41+
unsafe { std::mem::transmute::<PrintCallback, *mut ()>($x) }
42+
};
43+
}
44+
45+
fn from_ptr(ptr: *mut ()) -> PrintCallback {
46+
unsafe { std::mem::transmute::<*mut (), PrintCallback>(ptr) }
47+
}
48+
2749
// There is no AtomicFnPtr. This is a workaround.
28-
static PRINT_CB: AtomicPtr<()> = AtomicPtr::new(std::ptr::null_mut());
50+
static PRINT_CB: AtomicPtr<()> = AtomicPtr::new(to_ptr!(default_callback));
2951

3052
/// libbpf's default cb uses vfprintf's return code...which is ignored everywhere. Mimic that for
3153
/// completeness
@@ -34,27 +56,33 @@ extern "C" fn outer_print_cb(
3456
fmtstr: *const c_char,
3557
va_list: *mut libbpf_sys::__va_list_tag,
3658
) -> i32 {
37-
// can't be null: we cant't get here until print_cb has been initialized with a function
38-
// pointer.
39-
let cb = PRINT_CB.load(Ordering::Relaxed);
40-
let cb = unsafe { std::mem::transmute::<*const (), PrintCallback>(cb) };
59+
let cb = from_ptr(PRINT_CB.load(Ordering::Relaxed));
4160
match unsafe { vsprintf::vsprintf(fmtstr, va_list) } {
4261
Ok(s) => {
4362
cb(level.into(), &s);
4463
s.len() as i32
4564
}
46-
Err(_) => {
47-
cb(PrintLevel::Warn, "Could not format libbpf log string");
65+
Err(e) => {
66+
cb(
67+
PrintLevel::Warn,
68+
&format!("Failed to parse libbpf output: {}", e),
69+
);
4870
-1
4971
}
5072
}
5173
}
5274

53-
/// Set a callback to receive log messages from libbpf, instead of printing them to stderr
75+
/// Set a callback to receive log messages from libbpf, instead of printing them to stderr. Returns
76+
/// the previous callback.
77+
///
78+
/// # Arguments
79+
///
80+
/// * `func` - The callback
5481
///
5582
/// This overrides (and is overridden by) [`ObjectBuilder::debug`]
56-
pub fn set_print(func: PrintCallback) {
57-
let cb = unsafe { std::mem::transmute::<PrintCallback, *mut ()>(func) };
58-
PRINT_CB.store(cb, Ordering::Relaxed);
83+
pub fn set_print(func: PrintCallback) -> PrintCallback {
84+
let cb = to_ptr!(func);
85+
let prev = PRINT_CB.swap(cb, Ordering::Relaxed);
5986
unsafe { libbpf_sys::libbpf_set_print(Some(outer_print_cb)) };
87+
from_ptr(prev)
6088
}

libbpf-rs/tests/test_print.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@
22
//! set_print() and ObjectBuilder::debug() sets global state. The default is to run multiple tests
33
//! in different threads, so this test will always race with the others unless its isolated to a
44
//! different process.
5+
//!
6+
//! For the same reason, all tests here must run serially.
57
68
use libbpf_rs::{set_print, ObjectBuilder, PrintLevel};
9+
use serial_test::serial;
710
use std::sync::atomic::{AtomicBool, Ordering};
811

912
#[test]
13+
#[serial]
1014
fn test_set_print() {
1115
static CORRECT_LEVEL: AtomicBool = AtomicBool::new(false);
1216
static CORRECT_MESSAGE: AtomicBool = AtomicBool::new(false);
@@ -31,3 +35,18 @@ fn test_set_print() {
3135
assert!(correct_level, "Did not capture a warning");
3236
assert!(correct_message, "Did not capture the correct message");
3337
}
38+
39+
#[test]
40+
#[serial]
41+
fn test_set_restore_print() {
42+
fn callback1(_: PrintLevel, _: &str) {
43+
println!("one");
44+
}
45+
fn callback2(_: PrintLevel, _: &str) {
46+
println!("two");
47+
}
48+
49+
set_print(callback1);
50+
assert_eq!(callback1 as usize, set_print(callback2) as usize);
51+
assert_eq!(callback2 as usize, set_print(callback1) as usize);
52+
}

0 commit comments

Comments
 (0)