Skip to content

Commit a4d74a9

Browse files
committed
runtime: introduce cache for tokio runtimes
As @Lorak-mmk noted in the review, runtimes are currently not shared between `CassCluster` instances, which leads to possible many tokio runtimes being created in the application, with possibly a lot of threads. This commit introduces a cache for tokio runtimes, which is encapsulated in the global `Runtimes` struct. 1. `CassCluster` now does not store a `tokio::runtime::Runtime` directly, but rather an optional number of threads in the runtime. 2. The `Runtimes` struct is a global cache for tokio runtimes. It allows to get a default runtime or a runtime with a specified number of threads. Upon `cass_session_connect`, if a runtime is not created yet, it will create a new one and cache it for future use. The handling of the cache is fully transparent to the user of the abstraction. `CassSession` since then holds an `Arc<Runtime>`. 3. Once all `CassSession` instances that reference a runtime are dropped, the runtime is also dropped. This is done by storing weak pointers to runtimes in the `Runtimes` struct. Interesting to note: as Weak pointers keep the Arc allocation alive, a workflow that for consecutive `i`s connects a `CassSession` with a runtime with `i` threads and then drops it, will lead to space leaks. This is an artificial case, though. Remember that while the allocation will be still kept alive, the runtime itself will not be running, as it is dropped when the last `CassSession` referencing it is dropped.
1 parent b3d748d commit a4d74a9

File tree

4 files changed

+106
-31
lines changed

4 files changed

+106
-31
lines changed

scylla-rust-wrapper/src/cluster.rs

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use crate::load_balancing::{
77
CassHostFilter, DcRestriction, LoadBalancingConfig, LoadBalancingKind,
88
};
99
use crate::retry_policy::CassRetryPolicy;
10+
use crate::runtime::RUNTIMES;
1011
use crate::ssl::CassSsl;
1112
use crate::timestamp_generator::CassTimestampGen;
1213
use crate::types::*;
@@ -82,7 +83,11 @@ const DRIVER_NAME: &str = "ScyllaDB Cpp-Rust Driver";
8283
const DRIVER_VERSION: &str = env!("CARGO_PKG_VERSION");
8384

8485
pub struct CassCluster {
85-
runtime: Arc<tokio::runtime::Runtime>,
86+
/// Number of threads in the tokio runtime thread pool.
87+
///
88+
/// Specified with `cass_cluster_set_num_threads_io`.
89+
/// If not set, the default tokio runtime is used.
90+
num_threads_io: Option<usize>,
8691

8792
session_builder: SessionBuilder,
8893
default_execution_profile_builder: ExecutionProfileBuilder,
@@ -101,8 +106,20 @@ pub struct CassCluster {
101106
}
102107

103108
impl CassCluster {
104-
pub(crate) fn get_runtime(&self) -> &Arc<tokio::runtime::Runtime> {
105-
&self.runtime
109+
/// Gets the runtime that has been set for the cluster.
110+
/// If no runtime has been set yet, it creates a default runtime
111+
/// and makes it cached in the global `Runtimes` instance.
112+
pub(crate) fn get_runtime(&self) -> Arc<tokio::runtime::Runtime> {
113+
let mut runtimes = RUNTIMES.lock().unwrap();
114+
115+
if let Some(num_threads_io) = self.num_threads_io {
116+
// If the number of threads is set, we create a runtime with that number of threads.
117+
runtimes.n_thread_runtime(num_threads_io)
118+
} else {
119+
// Otherwise, we use the default runtime.
120+
runtimes.default_runtime()
121+
}
122+
.unwrap_or_else(|err| panic!("Failed to create an async runtime: {err}"))
106123
}
107124

108125
pub(crate) fn execution_profile_map(&self) -> &HashMap<ExecProfileName, CassExecProfile> {
@@ -184,12 +201,6 @@ impl CassCluster {
184201

185202
#[unsafe(no_mangle)]
186203
pub unsafe extern "C" fn cass_cluster_new() -> CassOwnedExclusivePtr<CassCluster, CMut> {
187-
let Ok(default_runtime) = tokio::runtime::Runtime::new()
188-
.inspect_err(|e| tracing::error!("Failed to create async runtime: {}", e))
189-
else {
190-
return CassPtr::null_mut();
191-
};
192-
193204
let default_execution_profile_builder = ExecutionProfileBuilder::default()
194205
.consistency(DEFAULT_CONSISTENCY)
195206
.serial_consistency(DEFAULT_SERIAL_CONSISTENCY)
@@ -321,7 +332,7 @@ pub unsafe extern "C" fn cass_cluster_new() -> CassOwnedExclusivePtr<CassCluster
321332
};
322333

323334
BoxFFI::into_ptr(Box::new(CassCluster {
324-
runtime: Arc::new(default_runtime),
335+
num_threads_io: None,
325336

326337
session_builder: default_session_builder,
327338
port: 9042,
@@ -1554,25 +1565,7 @@ pub unsafe extern "C" fn cass_cluster_set_num_threads_io(
15541565
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
15551566
};
15561567

1557-
let runtime_res = match num_threads {
1558-
0 => tokio::runtime::Builder::new_current_thread()
1559-
.enable_all()
1560-
.build(),
1561-
n => tokio::runtime::Builder::new_multi_thread()
1562-
.worker_threads(n as usize)
1563-
.enable_all()
1564-
.build(),
1565-
};
1566-
1567-
let runtime = match runtime_res {
1568-
Ok(runtime) => runtime,
1569-
Err(err) => {
1570-
tracing::error!("Failed to create async runtime: {}", err);
1571-
return CassError::CASS_ERROR_LIB_BAD_PARAMS;
1572-
}
1573-
};
1574-
1575-
cluster.runtime = Arc::new(runtime);
1568+
cluster.num_threads_io = Some(num_threads as usize);
15761569

15771570
CassError::CASS_OK
15781571
}

scylla-rust-wrapper/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ pub(crate) mod misc;
3030
pub(crate) mod prepared;
3131
pub(crate) mod query_result;
3232
pub(crate) mod retry_policy;
33+
pub(crate) mod runtime;
3334
#[cfg(test)]
3435
mod ser_de_tests;
3536
pub(crate) mod session;

scylla-rust-wrapper/src/runtime.rs

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
//! Manages tokio runtimes for the application.
2+
//!
3+
//! Runtime is per-cluster and can be changed with `cass_cluster_set_num_threads_io`.
4+
5+
use std::{
6+
collections::HashMap,
7+
sync::{Arc, Weak},
8+
};
9+
10+
use tokio::runtime::Runtime;
11+
12+
/// Manages tokio runtimes for the application.
13+
///
14+
/// Runtime is per-cluster and can be changed with `cass_cluster_set_num_threads_io`.
15+
/// Once a runtime is created, it is cached for future use.
16+
/// Once all `CassSession` instances that reference the runtime are dropped,
17+
/// the runtime is also dropped.
18+
pub(crate) struct Runtimes {
19+
// Weak pointers are used to make runtimes dropped once all `CassSession` instances
20+
// that reference them are freed.
21+
default_runtime: Option<Weak<Runtime>>,
22+
// This is Option to allow creating a static instance of Runtimes.
23+
// (`HashMap::new` is not `const`).
24+
n_thread_runtimes: Option<HashMap<usize, Weak<Runtime>>>,
25+
}
26+
27+
pub(crate) static RUNTIMES: std::sync::Mutex<Runtimes> = {
28+
std::sync::Mutex::new(Runtimes {
29+
default_runtime: None,
30+
n_thread_runtimes: None,
31+
})
32+
};
33+
34+
impl Runtimes {
35+
fn cached_or_new_runtime(
36+
weak_runtime: &mut Weak<Runtime>,
37+
create_runtime: impl FnOnce() -> Result<Arc<Runtime>, std::io::Error>,
38+
) -> Result<Arc<Runtime>, std::io::Error> {
39+
match weak_runtime.upgrade() {
40+
Some(cached_runtime) => Ok(cached_runtime),
41+
None => {
42+
let runtime = create_runtime()?;
43+
*weak_runtime = Arc::downgrade(&runtime);
44+
Ok(runtime)
45+
}
46+
}
47+
}
48+
49+
/// Returns a default tokio runtime.
50+
///
51+
/// If it's not created yet, it will create a new one with the default configuration
52+
/// and cache it for future use.
53+
pub(crate) fn default_runtime(&mut self) -> Result<Arc<Runtime>, std::io::Error> {
54+
let default_runtime_slot = self.default_runtime.get_or_insert_with(Weak::new);
55+
Self::cached_or_new_runtime(default_runtime_slot, || Runtime::new().map(Arc::new))
56+
}
57+
58+
/// Returns a tokio runtime with `n_threads` worker threads.
59+
///
60+
/// If it's not created yet, it will create a new one and cache it for future use.
61+
pub(crate) fn n_thread_runtime(
62+
&mut self,
63+
n_threads: usize,
64+
) -> Result<Arc<Runtime>, std::io::Error> {
65+
let n_thread_runtimes = self.n_thread_runtimes.get_or_insert_with(HashMap::new);
66+
let n_thread_runtime_slot = n_thread_runtimes.entry(n_threads).or_default();
67+
68+
Self::cached_or_new_runtime(n_thread_runtime_slot, || {
69+
match n_threads {
70+
0 => tokio::runtime::Builder::new_current_thread()
71+
.enable_all()
72+
.build(),
73+
n => tokio::runtime::Builder::new_multi_thread()
74+
.worker_threads(n)
75+
.enable_all()
76+
.build(),
77+
}
78+
.map(Arc::new)
79+
})
80+
}
81+
}

scylla-rust-wrapper/src/session.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ impl CassConnectedSession {
7979
let cluster_client_id = cluster.get_client_id();
8080

8181
let fut = Self::connect_fut(
82-
Arc::clone(cluster.get_runtime()),
82+
cluster.get_runtime(),
8383
session,
8484
session_builder,
8585
cluster_client_id,
@@ -89,7 +89,7 @@ impl CassConnectedSession {
8989
);
9090

9191
CassFuture::make_raw(
92-
Arc::clone(cluster.get_runtime()),
92+
cluster.get_runtime(),
9393
fut,
9494
#[cfg(cpp_integration_testing)]
9595
None,

0 commit comments

Comments
 (0)