Skip to content

Commit 8e34664

Browse files
committed
Improve get|set auxdata. credit: asg017#22
1 parent 2c5c049 commit 8e34664

File tree

2 files changed

+73
-11
lines changed

2 files changed

+73
-11
lines changed

src/api.rs

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -372,23 +372,26 @@ pub fn result_pointer<T>(context: *mut sqlite3_context, name: &[u8], object: T)
372372
};
373373
}
374374

375-
// TODO maybe take in a Box<T>?
376375
/// [`sqlite3_set_auxdata`](https://www.sqlite.org/c3ref/get_auxdata.html)
377-
pub fn auxdata_set(
378-
context: *mut sqlite3_context,
379-
col: i32,
380-
p: *mut c_void,
381-
d: Option<unsafe extern "C" fn(*mut c_void)>,
382-
) {
376+
pub fn auxdata_set<T>(context: *mut sqlite3_context, col: i32, p: Box<T>) {
377+
unsafe extern "C" fn cleanup<U>(p: *mut c_void) {
378+
drop(Box::from_raw(p.cast::<U>()));
379+
}
380+
381+
let raw = Box::into_raw(p).cast::<c_void>();
383382
unsafe {
384-
sqlite3ext_set_auxdata(context, col, p, d);
383+
sqlite3ext_set_auxdata(context, col, raw, Some(cleanup::<T>));
385384
}
386385
}
387386

388-
// TODO maybe return a Box<T>?
389387
/// [`sqlite3_get_auxdata`](https://www.sqlite.org/c3ref/get_auxdata.html)
390-
pub fn auxdata_get(context: *mut sqlite3_context, col: i32) -> *mut c_void {
391-
unsafe { sqlite3ext_get_auxdata(context, col) }
388+
pub fn auxdata_get<'a, T>(context: *mut sqlite3_context, col: i32) -> Option<&'a mut T> {
389+
let ptr = unsafe { sqlite3ext_get_auxdata(context, col).cast::<T>() };
390+
if ptr.is_null() {
391+
None
392+
} else {
393+
Some(unsafe { &mut *ptr })
394+
}
392395
}
393396

394397
pub fn context_db_handle(context: *mut sqlite3_context) -> *mut sqlite3 {

tests/test_auxdata.rs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
use sqlite_loadable::prelude::*;
2+
use sqlite_loadable::{api, define_scalar_function, Result};
3+
4+
pub fn check_auxdata(context: *mut sqlite3_context, values: &[*mut sqlite3_value]) -> Result<()> {
5+
let label = api::value_text(values.first().unwrap()).unwrap();
6+
let value = api::value_text(values.get(1).unwrap()).unwrap();
7+
8+
assert!(api::auxdata_get::<String>(context, 1).is_none());
9+
10+
let b = Box::new(String::from(value));
11+
api::auxdata_set::<String>(context, 1, b);
12+
13+
let entry = api::auxdata_get::<String>(context, 1).unwrap();
14+
assert!(entry == value);
15+
16+
api::result_text(context, format!("{label}={value}")).unwrap();
17+
18+
Ok(())
19+
}
20+
21+
#[sqlite_entrypoint]
22+
pub fn sqlite3_test_auxdata_init(db: *mut sqlite3) -> Result<()> {
23+
define_scalar_function(db, "check_auxdata", 2, check_auxdata, FunctionFlags::UTF8)?;
24+
Ok(())
25+
}
26+
27+
#[cfg(test)]
28+
mod tests {
29+
use super::*;
30+
31+
use rusqlite::{ffi::sqlite3_auto_extension, Connection};
32+
33+
#[test]
34+
fn test_rusqlite_auto_extension() {
35+
unsafe {
36+
sqlite3_auto_extension(Some(std::mem::transmute(
37+
sqlite3_test_auxdata_init as *const (),
38+
)));
39+
}
40+
41+
let conn = Connection::open_in_memory().unwrap();
42+
43+
// NOTE: even nested expressions are evaluated in different contexts leading to an
44+
// auxdata_get miss. auxdata_get/set is not suitable for naive caching across function
45+
// evaluations.
46+
let result: String = conn
47+
.query_row(
48+
"SELECT (check_auxdata(?1, check_auxdata(?2, ?3)))",
49+
("outer_label", "inner_label", "value"),
50+
|row| {
51+
println!("ROW {row:?}");
52+
row.get(0)
53+
},
54+
)
55+
.unwrap();
56+
57+
assert_eq!(result, "outer_label=inner_label=value");
58+
}
59+
}

0 commit comments

Comments
 (0)