Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions libsql/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,15 @@ use crate::{Result, TransactionBehavior};

pub type AuthHook = Arc<dyn Fn(&AuthContext) -> Authorization>;

pub type UpdateHook = dyn Fn(Op, &str, &str, i64) + Send + Sync;

#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Op {
Insert = 0,
Delete = 1,
Update = 2,
}

#[async_trait::async_trait]
pub(crate) trait Conn {
async fn execute(&self, sql: &str, params: Params) -> Result<u64>;
Expand Down Expand Up @@ -58,6 +67,10 @@ pub(crate) trait Conn {
fn authorizer(&self, _hook: Option<AuthHook>) -> Result<()> {
Err(crate::Error::AuthorizerNotSupported)
}

fn add_update_hook(&self, _cb: Box<dyn Fn(Op, &str, &str, i64) + Send + Sync>) -> Result<()> {
Err(crate::Error::UpdateHookNotSupported)
}
}

/// A set of rows returned from `execute_batch`/`execute_transactional_batch`. It is essentially
Expand Down Expand Up @@ -285,6 +298,10 @@ impl Connection {
pub fn authorizer(&self, hook: Option<AuthHook>) -> Result<()> {
self.conn.authorizer(hook)
}

pub fn add_update_hook(&self, cb: Box<UpdateHook>) -> Result<()> {
self.conn.add_update_hook(cb)
}
}

impl fmt::Debug for Connection {
Expand Down
2 changes: 2 additions & 0 deletions libsql/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ pub enum Error {
LoadExtensionNotSupported, // Not in rusqlite
#[error("Authorizer is only supported in local databases.")]
AuthorizerNotSupported, // Not in rusqlite
#[error("Update hooks are only supported in local databases.")]
UpdateHookNotSupported, // Not in rusqlite
#[error("Column not found: {0}")]
ColumnNotFound(i32), // Not in rusqlite
#[error("Hrana: `{0}`")]
Expand Down
2 changes: 1 addition & 1 deletion libsql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ cfg_hrana! {

pub use self::{
auth::{AuthAction, AuthContext, Authorization},
connection::{AuthHook, BatchRows, Connection},
connection::{AuthHook, BatchRows, Connection, Op},
database::{Builder, Database},
load_extension_guard::LoadExtensionGuard,
rows::{Column, Row, Rows},
Expand Down
57 changes: 53 additions & 4 deletions libsql/src/local/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ use crate::auth::{AuthAction, AuthContext, Authorization};
use crate::connection::AuthHook;
use crate::local::rows::BatchedRows;
use crate::params::Params;
use crate::{connection::BatchRows, errors};
use crate::{
connection::{BatchRows, Op, UpdateHook},
errors,
};
use std::time::Duration;

use super::{Database, Error, Result, Rows, RowsFuture, Statement, Transaction};
Expand All @@ -15,6 +18,10 @@ use libsql_sys::ffi;
use parking_lot::RwLock;
use std::{ffi::c_int, fmt, path::Path, sync::Arc};

struct Container {
cb: Box<UpdateHook>,
}

/// A connection to a libSQL database.
#[derive(Clone)]
pub struct Connection {
Expand Down Expand Up @@ -400,6 +407,24 @@ impl Connection {
})
}

/// Installs update hook
pub fn add_update_hook(&self, cb: Box<UpdateHook>) {
let c = Box::new(Container { cb });
let ptr: *mut Container = std::ptr::from_mut(Box::leak(c));

let old_data = unsafe {
ffi::sqlite3_update_hook(
self.raw,
Some(update_hook_cb),
ptr as *mut ::std::os::raw::c_void,
)
};

if !old_data.is_null() {
let _ = unsafe { Box::from_raw(old_data as *mut Container) };
}
}

pub fn enable_load_extension(&self, onoff: bool) -> Result<()> {
// SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION configration verb accepts 2 additional parameters: an on/off flag and a pointer to an c_int where new state of the parameter will be written (or NULL if reporting back the setting is not needed)
// See: https://sqlite.org/c3ref/c_dbconfig_defensive.html#sqlitedbconfigenableloadextension
Expand Down Expand Up @@ -464,7 +489,8 @@ impl Connection {

pub fn authorizer(&self, hook: Option<AuthHook>) -> Result<()> {
unsafe {
let rc = libsql_sys::ffi::sqlite3_set_authorizer(self.handle(), None, std::ptr::null_mut());
let rc =
libsql_sys::ffi::sqlite3_set_authorizer(self.handle(), None, std::ptr::null_mut());
if rc != ffi::SQLITE_OK {
return Err(crate::errors::Error::SqliteFailure(
rc as std::ffi::c_int,
Expand All @@ -484,7 +510,8 @@ impl Connection {
None => (None, std::ptr::null_mut()),
};

let rc = unsafe { libsql_sys::ffi::sqlite3_set_authorizer(self.handle(), callback, user_data) };
let rc =
unsafe { libsql_sys::ffi::sqlite3_set_authorizer(self.handle(), callback, user_data) };
if rc != ffi::SQLITE_OK {
return Err(crate::errors::Error::SqliteFailure(
rc as std::ffi::c_int,
Expand Down Expand Up @@ -716,7 +743,7 @@ unsafe extern "C" fn authorizer_callback(

pub(crate) struct WalInsertHandle<'a> {
conn: &'a Connection,
in_session: RwLock<bool>
in_session: RwLock<bool>,
}

impl WalInsertHandle<'_> {
Expand Down Expand Up @@ -761,6 +788,28 @@ impl fmt::Debug for Connection {
}
}

#[no_mangle]
extern "C" fn update_hook_cb(
data: *mut ::std::os::raw::c_void,
op: ::std::os::raw::c_int,
db_name: *const ::std::os::raw::c_char,
table_name: *const ::std::os::raw::c_char,
row_id: i64,
) {
let db = unsafe { std::ffi::CStr::from_ptr(db_name).to_string_lossy() };
let table = unsafe { std::ffi::CStr::from_ptr(table_name).to_string_lossy() };

let c = unsafe { &mut *(data as *mut Container) };
let o = match op {
9 => Op::Delete,
18 => Op::Insert,
23 => Op::Update,
_ => unreachable!("Unknown operation {op}"),
};

(*c.cb)(o, &db, &table, row_id);
}

#[cfg(test)]
mod tests {
use crate::{
Expand Down
9 changes: 6 additions & 3 deletions libsql/src/local/impls.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
use std::sync::Arc;
use std::{fmt, path::Path};
use std::time::Duration;
use std::{fmt, path::Path};

use crate::connection::BatchRows;
use crate::{
connection::{AuthHook, Conn},
connection::{AuthHook, BatchRows, Conn, UpdateHook},
params::Params,
rows::{ColumnsInner, RowInner, RowsInner},
statement::Stmt,
Expand Down Expand Up @@ -100,6 +99,10 @@ impl Conn for LibsqlConnection {
fn authorizer(&self, hook: Option<AuthHook>) -> Result<()> {
self.conn.authorizer(hook)
}

fn add_update_hook(&self, cb: Box<UpdateHook>) -> Result<()> {
Ok(self.conn.add_update_hook(cb))
}
}

impl Drop for LibsqlConnection {
Expand Down
75 changes: 73 additions & 2 deletions libsql/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@ use futures::{StreamExt, TryStreamExt};
use libsql::{
named_params, params,
params::{IntoParams, IntoValue},
AuthAction, Authorization, Connection, Database, Result, Value,
AuthAction, Authorization, Connection, Database, Op, Result, Value,
};
use rand::distributions::Uniform;
use rand::prelude::*;
use std::collections::HashSet;
use std::sync::Arc;
use std::sync::{Arc, Mutex};

async fn setup() -> Connection {
let db = Database::open(":memory:").unwrap();
Expand All @@ -28,6 +28,77 @@ async fn enable_disable_extension() {
conn.load_extension_disable().unwrap();
}

#[tokio::test]
async fn add_update_hook() {
let conn = setup().await;

#[derive(PartialEq, Debug)]
struct Data {
op: Op,
db: String,
table: String,
row_id: i64,
}

let d = Arc::new(Mutex::new(None::<Data>));

let d_clone = d.clone();
conn.add_update_hook(Box::new(move |op, db, table, row_id| {
*d_clone.lock().unwrap() = Some(Data {
op,
db: db.to_string(),
table: table.to_string(),
row_id,
});
}))
.unwrap();

let _ = conn
.execute("INSERT INTO users (id, name) VALUES (2, 'Alice')", ())
.await
.unwrap();

assert_eq!(
*d.lock().unwrap().as_ref().unwrap(),
Data {
op: Op::Insert,
db: "main".to_string(),
table: "users".to_string(),
row_id: 1,
}
);

let _ = conn
.execute("UPDATE users SET name = 'Bob' WHERE id = 2", ())
.await
.unwrap();

assert_eq!(
*d.lock().unwrap().as_ref().unwrap(),
Data {
op: Op::Update,
db: "main".to_string(),
table: "users".to_string(),
row_id: 1,
}
);

let _ = conn
.execute("DELETE FROM users WHERE id = 2", ())
.await
.unwrap();

assert_eq!(
*d.lock().unwrap().as_ref().unwrap(),
Data {
op: Op::Delete,
db: "main".to_string(),
table: "users".to_string(),
row_id: 1,
}
);
}

#[tokio::test]
async fn connection_drops_before_statements() {
let db = Database::open(":memory:").unwrap();
Expand Down
Loading