Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(cache): add in-memory lru cache for signature verification #1625

Merged
merged 13 commits into from
Feb 18, 2025
29 changes: 29 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions mls_validation_service/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
[package]
build = "build.rs"
edition = "2021"
license.workspace = true
name = "mls_validation_service"
version = "0.1.4"
build = "build.rs"
license.workspace = true

[[bin]] # Bin to run the Validation Service
name = "mls-validation-service"
Expand All @@ -13,17 +13,20 @@ path = "src/main.rs"
vergen-git2 = { workspace = true, features = ["build"] }

[dependencies]
async-trait.workspace = true
clap = { version = "4.4.6", features = ["derive"] }
ethers = { workspace = true }
futures = { workspace = true }
hex = { workspace = true }
lru = "0.13.0"
openmls = { workspace = true }
openmls_rust_crypto = { workspace = true }
thiserror.workspace = true
tokio = { workspace = true, features = ["signal", "rt-multi-thread"] }
tonic = { workspace = true }
tracing.workspace = true
tracing-subscriber = { workspace = true, features = ["env-filter", "ansi"] }
url.workspace = true
warp = "0.3.6"
xmtp_cryptography = { path = "../xmtp_cryptography" }
xmtp_id.workspace = true
Expand All @@ -34,9 +37,9 @@ xmtp_proto = { path = "../xmtp_proto", features = ["proto_full", "convert"] }
anyhow.workspace = true
ethers.workspace = true
rand = { workspace = true }
xmtp_common = { workspace = true, features = ["test-utils"] }
xmtp_id = { workspace = true, features = ["test-utils"] }
xmtp_mls = { workspace = true, features = ["test-utils"] }
xmtp_common = { workspace = true, features = ["test-utils"] }

[features]
test-utils = ["xmtp_id/test-utils"]
187 changes: 187 additions & 0 deletions mls_validation_service/src/cached_signature_verifier.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
use lru::LruCache;
use std::num::NonZeroUsize;
use tokio::sync::Mutex;

use ethers::types::{BlockNumber, Bytes};
use xmtp_id::associations::AccountId;
use xmtp_id::scw_verifier::{SmartContractSignatureVerifier, ValidationResponse, VerifierError};

#[derive(Hash, Eq, PartialEq, Clone, Debug)]
pub struct CacheKey {
pub address: String,
pub chain_id: String,
pub hash: [u8; 32],
pub signature: Vec<u8>,
pub block_number: Option<u64>,
}

impl CacheKey {
pub fn new(
account_id: &AccountId,
hash: [u8; 32],
signature: &Bytes,
block_number: Option<BlockNumber>,
) -> Self {
let block_number_u64 = block_number.and_then(|bn| bn.as_number().map(|n| n.as_u64()));
let address = account_id.get_account_address().to_string();
let chain_id = account_id.get_chain_id().to_string();

Self {
chain_id,
address,
hash,
signature: signature.to_vec(),
block_number: block_number_u64,
}
}
}

/// A cached smart contract verifier.
///
/// This wraps MultiSmartContractSignatureVerifier (or any other verifier
/// implementing SmartContractSignatureVerifier) and adds an in-memory LRU cache.
pub struct CachedSmartContractSignatureVerifier {
verifier: Box<dyn SmartContractSignatureVerifier>,
cache: Mutex<LruCache<CacheKey, ValidationResponse>>,
}

impl CachedSmartContractSignatureVerifier {
pub fn new(
verifier: impl SmartContractSignatureVerifier + 'static,
cache_size: usize,
) -> Result<Self, VerifierError> {
if cache_size == 0 {
return Err(VerifierError::InvalidCacheSize(cache_size));
}
Ok(Self {
verifier: Box::new(verifier),
cache: Mutex::new(LruCache::new(NonZeroUsize::new(cache_size).unwrap())),
})
}
}

#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl SmartContractSignatureVerifier for CachedSmartContractSignatureVerifier {
async fn is_valid_signature(
&self,
account_id: AccountId,
hash: [u8; 32],
signature: Bytes,
block_number: Option<BlockNumber>,
) -> Result<ValidationResponse, VerifierError> {
let key = CacheKey::new(&account_id, hash, &signature, block_number);

if let Some(cached_response) = {
let mut cache = self.cache.lock().await;
cache.get(&key).cloned()
} {
return Ok(cached_response);
}

let response = self
.verifier
.is_valid_signature(account_id, hash, signature, block_number)
.await?;

let mut cache = self.cache.lock().await;
cache.put(key, response.clone());

Ok(response)
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
use url::Url;
use xmtp_id::scw_verifier::{
MultiSmartContractSignatureVerifier, SmartContractSignatureVerifier, ValidationResponse,
VerifierError,
};

#[test]
fn test_cache_eviction() {
let mut cache: LruCache<CacheKey, ValidationResponse> =
LruCache::new(NonZeroUsize::new(1).unwrap());

let account_id1 = AccountId::new(String::from("chain1"), String::from("account1"));
let account_id2 = AccountId::new(String::from("chain1"), String::from("account2"));
let hash = [0u8; 32];
let bytes = Bytes::from(vec![1, 2, 3]);
let block_number = Some(BlockNumber::Number(1.into()));

let key1 = CacheKey::new(&account_id1, hash, &bytes, block_number);
let key2 = CacheKey::new(&account_id2, hash, &bytes, block_number);
assert_ne!(key1, key2);

let val1: ValidationResponse = ValidationResponse {
is_valid: true,
block_number: Some(1),
error: None,
};
let val2: ValidationResponse = ValidationResponse {
is_valid: true,
block_number: Some(2),
error: None,
};

cache.put(key1.clone(), val1.clone());
let response = cache.get(&key1).unwrap();
assert_eq!(response.is_valid, val1.is_valid);
assert_eq!(response.block_number, val1.block_number);

cache.put(key2.clone(), val2.clone());
assert!(cache.get(&key1).is_none());

// And key2 is correctly cached.
let response2 = cache.get(&key2).unwrap();
assert_eq!(response2.is_valid, val2.is_valid);
assert_eq!(response2.block_number, val2.block_number);
}

#[test]
fn test_invalid_cache_size() {
let urls: HashMap<String, Url> = HashMap::new();
let scw_verifier = MultiSmartContractSignatureVerifier::new(urls)
.expect("Failed to create MultiSmartContractSignatureVerifier");

let err = CachedSmartContractSignatureVerifier::new(scw_verifier, 0);
if let Err(VerifierError::InvalidCacheSize(size)) = err {
assert_eq!(size, 0);
} else {
panic!("Expected a VerifierError::InvalidCacheSize");
}
}

#[tokio::test]
async fn test_missing_verifier() {
//
let verifiers = std::collections::HashMap::new();
let multi_verifier = MultiSmartContractSignatureVerifier::new(verifiers).unwrap();
let cached_verifier = CachedSmartContractSignatureVerifier::new(multi_verifier, 1).unwrap();

let account_id = AccountId::new("missing".to_string(), "account1".to_string());
let hash = [0u8; 32];
let signature = Bytes::from(vec![1, 2, 3]);
let block_number = Some(BlockNumber::Number(1.into()));

let result = cached_verifier
.is_valid_signature(account_id, hash, signature, block_number)
.await;
assert!(result.is_err());

match result {
Err(VerifierError::Provider(provider_error)) => {
assert_eq!(
provider_error.to_string(),
"custom error: Verifier not present"
);
}
_ => {
panic!("Expected a VerifierError::Provider error.");
}
}
}
}
4 changes: 4 additions & 0 deletions mls_validation_service/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,8 @@ pub(crate) struct Args {
// A path to a json file in the same format as chain_urls_default.json in the codebase.
#[arg(long)]
pub(crate) chain_urls: Option<String>,

// The size of the cache to use for the smart contract signature verifier.
#[arg(long, default_value_t = 20)]
pub(crate) cache_size: usize,
}
10 changes: 8 additions & 2 deletions mls_validation_service/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
mod cached_signature_verifier;
mod config;
mod handlers;
mod health_check;
mod version;

use crate::cached_signature_verifier::CachedSmartContractSignatureVerifier;
use crate::version::get_version;
use clap::Parser;
use config::Args;
Expand Down Expand Up @@ -39,17 +41,21 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let addr = format!("0.0.0.0:{}", args.port).parse()?;
info!("Starting validation service on port {:?}", args.port);
info!("Starting health check on port {:?}", args.health_check_port);
info!("Cache size: {:?}", args.cache_size);

let health_server = health_check_server(args.health_check_port as u16);

let scw_verifier = match args.chain_urls {
let verifier = match args.chain_urls {
Some(path) => MultiSmartContractSignatureVerifier::new_from_file(path)?,
None => MultiSmartContractSignatureVerifier::new_from_env()?,
};

let cached_verifier: CachedSmartContractSignatureVerifier =
CachedSmartContractSignatureVerifier::new(verifier, args.cache_size)?;

let grpc_server = Server::builder()
.add_service(ValidationApiServer::new(ValidationService::new(
scw_verifier,
cached_verifier,
)))
.serve_with_shutdown(addr, async {
wait_for_quit().await;
Expand Down
9 changes: 5 additions & 4 deletions xmtp_id/src/scw_verifier/mod.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
mod chain_rpc_verifier;
mod remote_signature_verifier;

use crate::associations::AccountId;
pub use chain_rpc_verifier::*;
use ethers::{
providers::{Http, Provider, ProviderError},
types::{BlockNumber, Bytes},
};
pub use remote_signature_verifier::*;
use std::{collections::HashMap, fs, path::Path, sync::Arc};
use thiserror::Error;
use tracing::info;
use url::Url;

pub use chain_rpc_verifier::*;
pub use remote_signature_verifier::*;

static DEFAULT_CHAIN_URLS: &str = include_str!("chain_urls_default.json");

#[derive(Debug, Error)]
Expand All @@ -38,6 +36,8 @@ pub enum VerifierError {
MalformedEipUrl,
#[error(transparent)]
Api(#[from] xmtp_api::Error),
#[error("invalid cache size: {0}")]
InvalidCacheSize(usize),
}

#[cfg(not(target_arch = "wasm32"))]
Expand Down Expand Up @@ -121,6 +121,7 @@ where
}
}

#[derive(Clone)]
pub struct ValidationResponse {
pub is_valid: bool,
pub block_number: Option<u64>,
Expand Down
Loading