Skip to content

Commit 9c76991

Browse files
ignatzPThorpe92
authored andcommitted
Add new Connection API to install SQLite update hooks.
1 parent 19b1c7a commit 9c76991

File tree

6 files changed

+152
-10
lines changed

6 files changed

+152
-10
lines changed

libsql/src/connection.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@ use crate::{Result, TransactionBehavior};
1313

1414
pub type AuthHook = Arc<dyn Fn(&AuthContext) -> Authorization>;
1515

16+
pub type UpdateHook = dyn Fn(Op, &str, &str, i64) + Send + Sync;
17+
18+
#[derive(Clone, Copy, Debug, PartialEq)]
19+
pub enum Op {
20+
Insert = 0,
21+
Delete = 1,
22+
Update = 2,
23+
}
24+
1625
#[async_trait::async_trait]
1726
pub(crate) trait Conn {
1827
async fn execute(&self, sql: &str, params: Params) -> Result<u64>;
@@ -58,6 +67,10 @@ pub(crate) trait Conn {
5867
fn authorizer(&self, _hook: Option<AuthHook>) -> Result<()> {
5968
Err(crate::Error::AuthorizerNotSupported)
6069
}
70+
71+
fn add_update_hook(&self, _cb: Box<dyn Fn(Op, &str, &str, i64) + Send + Sync>) -> Result<()> {
72+
Err(crate::Error::UpdateHookNotSupported)
73+
}
6174
}
6275

6376
/// A set of rows returned from `execute_batch`/`execute_transactional_batch`. It is essentially
@@ -285,6 +298,10 @@ impl Connection {
285298
pub fn authorizer(&self, hook: Option<AuthHook>) -> Result<()> {
286299
self.conn.authorizer(hook)
287300
}
301+
302+
pub fn add_update_hook(&self, cb: Box<UpdateHook>) -> Result<()> {
303+
self.conn.add_update_hook(cb)
304+
}
288305
}
289306

290307
impl fmt::Debug for Connection {

libsql/src/errors.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ pub enum Error {
2323
LoadExtensionNotSupported, // Not in rusqlite
2424
#[error("Authorizer is only supported in local databases.")]
2525
AuthorizerNotSupported, // Not in rusqlite
26+
#[error("Update hooks are only supported in local databases.")]
27+
UpdateHookNotSupported, // Not in rusqlite
2628
#[error("Column not found: {0}")]
2729
ColumnNotFound(i32), // Not in rusqlite
2830
#[error("Hrana: `{0}`")]

libsql/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ cfg_hrana! {
180180

181181
pub use self::{
182182
auth::{AuthAction, AuthContext, Authorization},
183-
connection::{AuthHook, BatchRows, Connection},
183+
connection::{AuthHook, BatchRows, Connection, Op},
184184
database::{Builder, Database},
185185
load_extension_guard::LoadExtensionGuard,
186186
rows::{Column, Row, Rows},

libsql/src/local/connection.rs

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ use crate::auth::{AuthAction, AuthContext, Authorization};
44
use crate::connection::AuthHook;
55
use crate::local::rows::BatchedRows;
66
use crate::params::Params;
7-
use crate::{connection::BatchRows, errors};
7+
use crate::{
8+
connection::{BatchRows, Op, UpdateHook},
9+
errors,
10+
};
811
use std::time::Duration;
912

1013
use super::{Database, Error, Result, Rows, RowsFuture, Statement, Transaction};
@@ -15,6 +18,10 @@ use libsql_sys::ffi;
1518
use parking_lot::RwLock;
1619
use std::{ffi::c_int, fmt, path::Path, sync::Arc};
1720

21+
struct Container {
22+
cb: Box<UpdateHook>,
23+
}
24+
1825
/// A connection to a libSQL database.
1926
#[derive(Clone)]
2027
pub struct Connection {
@@ -400,6 +407,24 @@ impl Connection {
400407
})
401408
}
402409

410+
/// Installs update hook
411+
pub fn add_update_hook(&self, cb: Box<UpdateHook>) {
412+
let c = Box::new(Container { cb });
413+
let ptr: *mut Container = std::ptr::from_mut(Box::leak(c));
414+
415+
let old_data = unsafe {
416+
ffi::sqlite3_update_hook(
417+
self.raw,
418+
Some(update_hook_cb),
419+
ptr as *mut ::std::os::raw::c_void,
420+
)
421+
};
422+
423+
if !old_data.is_null() {
424+
let _ = unsafe { Box::from_raw(old_data as *mut Container) };
425+
}
426+
}
427+
403428
pub fn enable_load_extension(&self, onoff: bool) -> Result<()> {
404429
// SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION configration verb accepts 2 additional parameters: an on/off flag and a pointer to an c_int where new state of the parameter will be written (or NULL if reporting back the setting is not needed)
405430
// See: https://sqlite.org/c3ref/c_dbconfig_defensive.html#sqlitedbconfigenableloadextension
@@ -464,7 +489,8 @@ impl Connection {
464489

465490
pub fn authorizer(&self, hook: Option<AuthHook>) -> Result<()> {
466491
unsafe {
467-
let rc = libsql_sys::ffi::sqlite3_set_authorizer(self.handle(), None, std::ptr::null_mut());
492+
let rc =
493+
libsql_sys::ffi::sqlite3_set_authorizer(self.handle(), None, std::ptr::null_mut());
468494
if rc != ffi::SQLITE_OK {
469495
return Err(crate::errors::Error::SqliteFailure(
470496
rc as std::ffi::c_int,
@@ -484,7 +510,8 @@ impl Connection {
484510
None => (None, std::ptr::null_mut()),
485511
};
486512

487-
let rc = unsafe { libsql_sys::ffi::sqlite3_set_authorizer(self.handle(), callback, user_data) };
513+
let rc =
514+
unsafe { libsql_sys::ffi::sqlite3_set_authorizer(self.handle(), callback, user_data) };
488515
if rc != ffi::SQLITE_OK {
489516
return Err(crate::errors::Error::SqliteFailure(
490517
rc as std::ffi::c_int,
@@ -716,7 +743,7 @@ unsafe extern "C" fn authorizer_callback(
716743

717744
pub(crate) struct WalInsertHandle<'a> {
718745
conn: &'a Connection,
719-
in_session: RwLock<bool>
746+
in_session: RwLock<bool>,
720747
}
721748

722749
impl WalInsertHandle<'_> {
@@ -761,6 +788,28 @@ impl fmt::Debug for Connection {
761788
}
762789
}
763790

791+
#[no_mangle]
792+
extern "C" fn update_hook_cb(
793+
data: *mut ::std::os::raw::c_void,
794+
op: ::std::os::raw::c_int,
795+
db_name: *const ::std::os::raw::c_char,
796+
table_name: *const ::std::os::raw::c_char,
797+
row_id: i64,
798+
) {
799+
let db = unsafe { std::ffi::CStr::from_ptr(db_name).to_string_lossy() };
800+
let table = unsafe { std::ffi::CStr::from_ptr(table_name).to_string_lossy() };
801+
802+
let c = unsafe { &mut *(data as *mut Container) };
803+
let o = match op {
804+
9 => Op::Delete,
805+
18 => Op::Insert,
806+
23 => Op::Update,
807+
_ => unreachable!("Unknown operation {op}"),
808+
};
809+
810+
(*c.cb)(o, &db, &table, row_id);
811+
}
812+
764813
#[cfg(test)]
765814
mod tests {
766815
use crate::{

libsql/src/local/impls.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
use std::sync::Arc;
2-
use std::{fmt, path::Path};
32
use std::time::Duration;
3+
use std::{fmt, path::Path};
44

5-
use crate::connection::BatchRows;
65
use crate::{
7-
connection::{AuthHook, Conn},
6+
connection::{AuthHook, BatchRows, Conn, UpdateHook},
87
params::Params,
98
rows::{ColumnsInner, RowInner, RowsInner},
109
statement::Stmt,
@@ -100,6 +99,10 @@ impl Conn for LibsqlConnection {
10099
fn authorizer(&self, hook: Option<AuthHook>) -> Result<()> {
101100
self.conn.authorizer(hook)
102101
}
102+
103+
fn add_update_hook(&self, cb: Box<UpdateHook>) -> Result<()> {
104+
Ok(self.conn.add_update_hook(cb))
105+
}
103106
}
104107

105108
impl Drop for LibsqlConnection {

libsql/tests/integration_tests.rs

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@ use futures::{StreamExt, TryStreamExt};
44
use libsql::{
55
named_params, params,
66
params::{IntoParams, IntoValue},
7-
AuthAction, Authorization, Connection, Database, Result, Value,
7+
AuthAction, Authorization, Connection, Database, Op, Result, Value,
88
};
99
use rand::distributions::Uniform;
1010
use rand::prelude::*;
1111
use std::collections::HashSet;
12-
use std::sync::Arc;
12+
use std::sync::{Arc, Mutex};
1313

1414
async fn setup() -> Connection {
1515
let db = Database::open(":memory:").unwrap();
@@ -28,6 +28,77 @@ async fn enable_disable_extension() {
2828
conn.load_extension_disable().unwrap();
2929
}
3030

31+
#[tokio::test]
32+
async fn add_update_hook() {
33+
let conn = setup().await;
34+
35+
#[derive(PartialEq, Debug)]
36+
struct Data {
37+
op: Op,
38+
db: String,
39+
table: String,
40+
row_id: i64,
41+
}
42+
43+
let d = Arc::new(Mutex::new(None::<Data>));
44+
45+
let d_clone = d.clone();
46+
conn.add_update_hook(Box::new(move |op, db, table, row_id| {
47+
*d_clone.lock().unwrap() = Some(Data {
48+
op,
49+
db: db.to_string(),
50+
table: table.to_string(),
51+
row_id,
52+
});
53+
}))
54+
.unwrap();
55+
56+
let _ = conn
57+
.execute("INSERT INTO users (id, name) VALUES (2, 'Alice')", ())
58+
.await
59+
.unwrap();
60+
61+
assert_eq!(
62+
*d.lock().unwrap().as_ref().unwrap(),
63+
Data {
64+
op: Op::Insert,
65+
db: "main".to_string(),
66+
table: "users".to_string(),
67+
row_id: 1,
68+
}
69+
);
70+
71+
let _ = conn
72+
.execute("UPDATE users SET name = 'Bob' WHERE id = 2", ())
73+
.await
74+
.unwrap();
75+
76+
assert_eq!(
77+
*d.lock().unwrap().as_ref().unwrap(),
78+
Data {
79+
op: Op::Update,
80+
db: "main".to_string(),
81+
table: "users".to_string(),
82+
row_id: 1,
83+
}
84+
);
85+
86+
let _ = conn
87+
.execute("DELETE FROM users WHERE id = 2", ())
88+
.await
89+
.unwrap();
90+
91+
assert_eq!(
92+
*d.lock().unwrap().as_ref().unwrap(),
93+
Data {
94+
op: Op::Delete,
95+
db: "main".to_string(),
96+
table: "users".to_string(),
97+
row_id: 1,
98+
}
99+
);
100+
}
101+
31102
#[tokio::test]
32103
async fn connection_drops_before_statements() {
33104
let db = Database::open(":memory:").unwrap();

0 commit comments

Comments
 (0)