Skip to content

Commit

Permalink
feat(pkarr): add a batch to LmdbCache to close #100 and fix lru bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Nuhvi committed Dec 1, 2024
1 parent caa3cb5 commit 9d62e69
Showing 1 changed file with 166 additions and 27 deletions.
193 changes: 166 additions & 27 deletions pkarr/src/extra/lmdb_cache.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
//! Persistent [crate::base::cache::Cache] implementation using LMDB's bindings [heed]
use std::{borrow::Cow, fs, path::Path, time::Duration};
use std::{
borrow::Cow,
fs,
path::Path,
sync::{Arc, RwLock},
time::Duration,
};

use byteorder::LittleEndian;
use heed::{types::U64, BoxedError, BytesDecode, BytesEncode, Database, Env, EnvOpenOptions};
use heed::{
types::U64, BoxedError, BytesDecode, BytesEncode, Database, Env, EnvOpenOptions, RwTxn,
};
use libc::{sysconf, _SC_PAGESIZE};

use tracing::debug;
Expand Down Expand Up @@ -71,6 +79,7 @@ pub struct LmdbCache {
signed_packets_table: SignedPacketsTable,
key_to_time_table: KeyToTimeTable,
time_to_key_table: TimeToKeyTable,
batch: Arc<RwLock<Vec<CacheKey>>>,
}

impl LmdbCache {
Expand Down Expand Up @@ -115,6 +124,7 @@ impl LmdbCache {
signed_packets_table,
key_to_time_table,
time_to_key_table,
batch: Arc::new(RwLock::new(vec![])),
};

let clone = instance.clone();
Expand Down Expand Up @@ -149,12 +159,15 @@ impl LmdbCache {
let key_to_time = self.key_to_time_table;
let time_to_key = self.time_to_key_table;

let batch = self.batch.read().expect("LmdbCache::batch.read()");
update_lru(&mut wtxn, packets, key_to_time, time_to_key, &batch)?;

let len = packets.len(&wtxn)? as usize;

if len >= self.capacity {
debug!(?len, ?self.capacity, "Reached cache capacity, deleting extra item.");

let mut iter = time_to_key.rev_iter(&wtxn)?;
let mut iter = time_to_key.iter(&wtxn)?;

if let Some((time, key)) = iter.next().transpose()? {
drop(iter);
Expand Down Expand Up @@ -182,30 +195,12 @@ impl LmdbCache {
}

pub fn internal_get(&self, key: &CacheKey) -> Result<Option<SignedPacket>, heed::Error> {
let mut wtxn = self.env.write_txn()?;

let packets = self.signed_packets_table;
let key_to_time = self.key_to_time_table;
let time_to_key = self.time_to_key_table;

if let Some(signed_packet) = packets.get(&wtxn, key)? {
if let Some(time) = key_to_time.get(&wtxn, key)? {
time_to_key.delete(&mut wtxn, &time)?;
};

let new_time = Timestamp::now();
self.batch
.write()
.expect("LmdbCache::batch.write()")
.push(*key);

time_to_key.put(&mut wtxn, &new_time.as_u64(), key)?;
key_to_time.put(&mut wtxn, key, &new_time.as_u64())?;

wtxn.commit()?;

return Ok(Some(signed_packet));
}

wtxn.commit()?;

Ok(None)
self.internal_get_read_only(key)
}

pub fn internal_get_read_only(
Expand All @@ -224,6 +219,29 @@ impl LmdbCache {
}
}

/// Refreshes the LRU bookkeeping for every key in `to_update` inside the
/// caller's write transaction.
///
/// For each key that still has an entry in `packets`:
/// - if `key_to_time` holds a timestamp for the key, the matching entry in
///   `time_to_key` is deleted;
/// - a fresh `Timestamp::now()` is then written to both `time_to_key` and
///   `key_to_time` for that key.
///
/// Keys that are not present in `packets` are skipped without touching either
/// index table. Nothing is committed here; that is left to the owner of `wtxn`.
///
/// # Errors
///
/// Returns the first [heed::Error] produced by any table read or write.
fn update_lru(
    wtxn: &mut RwTxn,
    packets: SignedPacketsTable,
    key_to_time: KeyToTimeTable,
    time_to_key: TimeToKeyTable,
    to_update: &[CacheKey],
) -> Result<(), heed::Error> {
    for key in to_update {
        // Only keys that still have a stored packet get their access time refreshed.
        if packets.get(wtxn, key)?.is_none() {
            continue;
        }

        // Drop the previous time -> key entry so the key is indexed only once.
        if let Some(previous) = key_to_time.get(wtxn, key)? {
            time_to_key.delete(wtxn, &previous)?;
        }

        let now = Timestamp::now();
        let stamp = now.as_u64();

        time_to_key.put(wtxn, &stamp, key)?;
        key_to_time.put(wtxn, key, &stamp)?;
    }

    Ok(())
}

impl Cache for LmdbCache {
fn len(&self) -> usize {
match self.internal_len() {
Expand Down Expand Up @@ -278,7 +296,7 @@ pub enum Error {

#[cfg(test)]
mod tests {
use std::usize;
use crate::Keypair;

use super::*;

Expand All @@ -288,4 +306,125 @@ mod tests {

LmdbCache::new(&env_path, usize::MAX).unwrap();
}

/// With a capacity of two, inserting a third packet must drop the entry that
/// was inserted first and keep the two more recent ones.
#[test]
fn lru_capacity() {
    let env_path = std::env::temp_dir().join(Timestamp::now().to_string());
    let cache = LmdbCache::new(&env_path, 2).unwrap();

    // Fill the cache up to its capacity; entries are stored in insertion order.
    let entries: Vec<_> = (0..2)
        .map(|ttl| {
            let signed_packet = SignedPacket::builder()
                .txt("foo".try_into().unwrap(), "bar".try_into().unwrap(), ttl)
                .sign(&Keypair::random())
                .unwrap();
            let key = CacheKey::from(signed_packet.public_key());
            cache.put(&key, &signed_packet);

            (key, signed_packet)
        })
        .collect();

    let (oldest_key, oldest_packet) = &entries[0];
    let (newer_key, newer_packet) = &entries[1];

    assert_eq!(
        cache.get_read_only(oldest_key).unwrap(),
        *oldest_packet,
        "first key saved"
    );
    assert_eq!(
        cache.get_read_only(newer_key).unwrap(),
        *newer_packet,
        "second key saved"
    );

    assert_eq!(cache.len(), 2);

    // A third insertion exceeds the capacity and should push out the oldest entry.
    let newest_packet = SignedPacket::builder()
        .txt("foo".try_into().unwrap(), "bar".try_into().unwrap(), 3)
        .sign(&Keypair::random())
        .unwrap();
    let newest_key = CacheKey::from(newest_packet.public_key());
    cache.put(&newest_key, &newest_packet);

    assert_eq!(cache.len(), 2);

    assert!(
        cache.get_read_only(oldest_key).is_none(),
        "oldest key dropped"
    );
    assert_eq!(
        cache.get_read_only(newer_key).unwrap(),
        *newer_packet,
        "more recent key survived"
    );
    assert_eq!(
        cache.get_read_only(&newest_key).unwrap(),
        newest_packet,
        "most recent key survived"
    );
}

/// With a capacity of two, reading the first-inserted entry through `get`
/// before a third insertion must make the second-inserted entry the one that
/// gets dropped.
#[test]
fn lru_capacity_refresh_oldest() {
    let env_path = std::env::temp_dir().join(Timestamp::now().to_string());
    let cache = LmdbCache::new(&env_path, 2).unwrap();

    // Fill the cache up to its capacity; entries are stored in insertion order.
    let entries: Vec<_> = (0..2)
        .map(|ttl| {
            let signed_packet = SignedPacket::builder()
                .txt("foo".try_into().unwrap(), "bar".try_into().unwrap(), ttl)
                .sign(&Keypair::random())
                .unwrap();
            let key = CacheKey::from(signed_packet.public_key());
            cache.put(&key, &signed_packet);

            (key, signed_packet)
        })
        .collect();

    let (refreshed_key, refreshed_packet) = &entries[0];
    let (stale_key, stale_packet) = &entries[1];

    assert_eq!(
        cache.get_read_only(refreshed_key).unwrap(),
        *refreshed_packet,
        "first key saved"
    );
    assert_eq!(
        cache.get_read_only(stale_key).unwrap(),
        *stale_packet,
        "second key saved"
    );

    // Read the first-inserted entry through the non-read-only getter so it
    // counts as recently used.
    cache.get(refreshed_key).unwrap();

    assert_eq!(cache.len(), 2);

    // A third insertion exceeds the capacity and should push out the entry
    // that was not refreshed.
    let newest_packet = SignedPacket::builder()
        .txt("foo".try_into().unwrap(), "bar".try_into().unwrap(), 3)
        .sign(&Keypair::random())
        .unwrap();
    let newest_key = CacheKey::from(newest_packet.public_key());
    cache.put(&newest_key, &newest_packet);

    assert_eq!(cache.len(), 2);

    assert!(
        cache.get_read_only(stale_key).is_none(),
        "oldest key dropped"
    );
    assert_eq!(
        cache.get_read_only(refreshed_key).unwrap(),
        *refreshed_packet,
        "refreshed key survived"
    );
    assert_eq!(
        cache.get_read_only(&newest_key).unwrap(),
        newest_packet,
        "most recent key survived"
    );
}
}

0 comments on commit 9d62e69

Please sign in to comment.