Skip to content

Commit f837aeb

Browse files
committed
Introduce async callbacks
We introduce tokio_boring::SslContextBuilderExt, with 2 methods: * set_async_select_certificate_callback * set_async_private_key_method
1 parent ca47aaf commit f837aeb

8 files changed

Lines changed: 580 additions & 3 deletions

File tree

boring/src/ssl/mod.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,9 @@ pub struct SelectCertError(ffi::ssl_select_cert_result_t);
482482
impl SelectCertError {
483483
/// A fatal error occured and the handshake should be terminated.
484484
pub const ERROR: Self = Self(ffi::ssl_select_cert_result_t::ssl_select_cert_error);
485+
486+
/// The operation could not be completed and should be retried later.
487+
pub const RETRY: Self = Self(ffi::ssl_select_cert_result_t::ssl_select_cert_retry);
485488
}
486489

487490
/// Extension types, to be used with `ClientHello::get_extension`.
@@ -3280,6 +3283,11 @@ impl<S> MidHandshakeSslStream<S> {
32803283
self.stream.ssl()
32813284
}
32823285

3286+
/// Returns a mutable reference to the `Ssl` of the stream.
3287+
pub fn ssl_mut(&mut self) -> &mut SslRef {
3288+
self.stream.ssl_mut()
3289+
}
3290+
32833291
/// Returns the underlying error which interrupted this handshake.
32843292
pub fn error(&self) -> &Error {
32853293
&self.error
@@ -3585,6 +3593,11 @@ impl<S> SslStream<S> {
35853593
pub fn ssl(&self) -> &SslRef {
35863594
&self.ssl
35873595
}
3596+
3597+
/// Returns a mutable reference to the `Ssl` object associated with this stream.
3598+
pub fn ssl_mut(&mut self) -> &mut SslRef {
3599+
&mut self.ssl
3600+
}
35883601
}
35893602

35903603
impl<S: Read + Write> Read for SslStream<S> {

boring/src/ssl/test/private_key_method.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,8 @@ fn test_sign_retry_complete_failure() {
189189
ErrorCode::WANT_PRIVATE_KEY_OPERATION
190190
);
191191

192-
let HandshakeError::WouldBlock(mid_handshake) = mid_handshake.handshake().unwrap_err() else {
192+
let HandshakeError::WouldBlock(mid_handshake) = mid_handshake.handshake().unwrap_err()
193+
else {
193194
panic!("should be WouldBlock");
194195
};
195196

tokio-boring/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ no-patches = ["boring/no-patches"]
3939
[dependencies]
4040
boring = { workspace = true }
4141
boring-sys = { workspace = true }
42+
once_cell = { workspace = true }
4243
tokio = { workspace = true }
4344

4445
[dev-dependencies]
Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
use boring::ex_data::Index;
2+
use boring::ssl::{self, ClientHello, PrivateKeyMethod, Ssl, SslContextBuilder};
3+
use once_cell::sync::Lazy;
4+
use std::future::Future;
5+
use std::pin::Pin;
6+
use std::task::{ready, Context, Poll, Waker};
7+
8+
/// The type of futures to pass to [`SslContextBuilderExt::set_async_select_certificate_callback`].
9+
pub type BoxSelectCertFuture = ExDataFuture<Result<BoxSelectCertFinish, AsyncSelectCertError>>;
10+
11+
/// The type of callbacks returned by [`BoxSelectCertFuture`] methods.
12+
pub type BoxSelectCertFinish = Box<dyn FnOnce(ClientHello<'_>) -> Result<(), AsyncSelectCertError>>;
13+
14+
/// The type of futures returned by [`AsyncPrivateKeyMethod`] methods.
15+
pub type BoxPrivateKeyMethodFuture =
16+
ExDataFuture<Result<BoxPrivateKeyMethodFinish, AsyncPrivateKeyMethodError>>;
17+
18+
/// The type of callbacks returned by [`BoxPrivateKeyMethodFuture`].
19+
pub type BoxPrivateKeyMethodFinish =
20+
Box<dyn FnOnce(&mut ssl::SslRef, &mut [u8]) -> Result<usize, AsyncPrivateKeyMethodError>>;
21+
22+
/// Convenience alias for futures stored in [`Ssl`] ex data by [`SslContextBuilderExt`] methods.
23+
///
24+
/// Public for documentation purposes.
25+
pub type ExDataFuture<T> = Pin<Box<dyn Future<Output = T> + Send + Sync>>;
26+
27+
pub(crate) static TASK_WAKER_INDEX: Lazy<Index<Ssl, Option<Waker>>> =
28+
Lazy::new(|| Ssl::new_ex_index().unwrap());
29+
pub(crate) static SELECT_CERT_FUTURE_INDEX: Lazy<Index<Ssl, BoxSelectCertFuture>> =
30+
Lazy::new(|| Ssl::new_ex_index().unwrap());
31+
pub(crate) static SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX: Lazy<
32+
Index<Ssl, BoxPrivateKeyMethodFuture>,
33+
> = Lazy::new(|| Ssl::new_ex_index().unwrap());
34+
35+
/// Extensions to [`SslContextBuilder`].
36+
///
37+
/// This trait provides additional methods to use async callbacks with boring.
38+
pub trait SslContextBuilderExt: private::Sealed {
39+
/// Sets a callback that is called before most [`ClientHello`] processing
40+
/// and before the decision whether to resume a session is made. The
41+
/// callback may inspect the [`ClientHello`] and configure the connection.
42+
///
43+
/// This method uses a function that returns a future whose output is
44+
/// itself a closure that will be passed [`ClientHello`] to configure
45+
/// the connection based on the computations done in the future.
46+
///
47+
/// See [`SslContextBuilder::set_select_certificate_callback`] for the sync
48+
/// setter of this callback.
49+
fn set_async_select_certificate_callback<F>(&mut self, callback: F)
50+
where
51+
F: Fn(&mut ClientHello<'_>) -> Result<BoxSelectCertFuture, AsyncSelectCertError>
52+
+ Send
53+
+ Sync
54+
+ 'static;
55+
56+
/// Configures a custom private key method on the context.
57+
///
58+
/// See [`AsyncPrivateKeyMethod`] for more details.
59+
fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod);
60+
}
61+
62+
impl SslContextBuilderExt for SslContextBuilder {
63+
fn set_async_select_certificate_callback<F>(&mut self, callback: F)
64+
where
65+
F: Fn(&mut ClientHello<'_>) -> Result<BoxSelectCertFuture, AsyncSelectCertError>
66+
+ Send
67+
+ Sync
68+
+ 'static,
69+
{
70+
self.set_select_certificate_callback(move |mut client_hello| {
71+
let fut_poll_result = with_ex_data_future(
72+
&mut client_hello,
73+
*SELECT_CERT_FUTURE_INDEX,
74+
ClientHello::ssl_mut,
75+
&callback,
76+
);
77+
78+
let fut_result = match fut_poll_result {
79+
Poll::Ready(fut_result) => fut_result,
80+
Poll::Pending => return Err(ssl::SelectCertError::RETRY),
81+
};
82+
83+
let finish = fut_result.or(Err(ssl::SelectCertError::ERROR))?;
84+
85+
finish(client_hello).or(Err(ssl::SelectCertError::ERROR))
86+
})
87+
}
88+
89+
fn set_async_private_key_method(&mut self, method: impl AsyncPrivateKeyMethod) {
90+
self.set_private_key_method(AsyncPrivateKeyMethodBridge(Box::new(method)));
91+
}
92+
}
93+
94+
/// A fatal error to be returned from async select certificate callbacks.
95+
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
96+
pub struct AsyncSelectCertError;
97+
98+
/// Describes async private key hooks. This is used to off-load signing
99+
/// operations to a custom, potentially asynchronous, backend. Metadata about the
100+
/// key such as the type and size are parsed out of the certificate.
101+
///
102+
/// See [`PrivateKeyMethod`] for the sync version of those hooks.
103+
///
104+
/// [`ssl_private_key_method_st`]: https://commondatastorage.googleapis.com/chromium-boringssl-docs/ssl.h.html#ssl_private_key_method_st
105+
pub trait AsyncPrivateKeyMethod: Send + Sync + 'static {
106+
/// Signs the message `input` using the specified signature algorithm.
107+
///
108+
/// This method uses a function that returns a future whose output is
109+
/// itself a closure that will be passed `ssl` and `output`
110+
/// to finish writing the signature.
111+
///
112+
/// See [`PrivateKeyMethod::sign`] for the sync version of this method.
113+
fn sign(
114+
&self,
115+
ssl: &mut ssl::SslRef,
116+
input: &[u8],
117+
signature_algorithm: ssl::SslSignatureAlgorithm,
118+
output: &mut [u8],
119+
) -> Result<BoxPrivateKeyMethodFuture, AsyncPrivateKeyMethodError>;
120+
121+
/// Decrypts `input`.
122+
///
123+
/// This method uses a function that returns a future whose output is
124+
/// itself a closure that will be passed `ssl` and `output`
125+
/// to finish decrypting the input.
126+
///
127+
/// See [`PrivateKeyMethod::decrypt`] for the sync version of this method.
128+
fn decrypt(
129+
&self,
130+
ssl: &mut ssl::SslRef,
131+
input: &[u8],
132+
output: &mut [u8],
133+
) -> Result<BoxPrivateKeyMethodFuture, AsyncPrivateKeyMethodError>;
134+
}
135+
136+
/// A fatal error to be returned from async private key methods.
137+
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
138+
pub struct AsyncPrivateKeyMethodError;
139+
140+
struct AsyncPrivateKeyMethodBridge(Box<dyn AsyncPrivateKeyMethod>);
141+
142+
impl PrivateKeyMethod for AsyncPrivateKeyMethodBridge {
143+
fn sign(
144+
&self,
145+
ssl: &mut ssl::SslRef,
146+
input: &[u8],
147+
signature_algorithm: ssl::SslSignatureAlgorithm,
148+
output: &mut [u8],
149+
) -> Result<usize, ssl::PrivateKeyMethodError> {
150+
with_private_key_method(ssl, output, |ssl, output| {
151+
<dyn AsyncPrivateKeyMethod>::sign(&*self.0, ssl, input, signature_algorithm, output)
152+
})
153+
}
154+
155+
fn decrypt(
156+
&self,
157+
ssl: &mut ssl::SslRef,
158+
input: &[u8],
159+
output: &mut [u8],
160+
) -> Result<usize, ssl::PrivateKeyMethodError> {
161+
with_private_key_method(ssl, output, |ssl, output| {
162+
<dyn AsyncPrivateKeyMethod>::decrypt(&*self.0, ssl, input, output)
163+
})
164+
}
165+
166+
fn complete(
167+
&self,
168+
ssl: &mut ssl::SslRef,
169+
output: &mut [u8],
170+
) -> Result<usize, ssl::PrivateKeyMethodError> {
171+
with_private_key_method(ssl, output, |_, _| {
172+
// This should never be reached, if it does, that's a bug on boring's side,
173+
// which called `complete` without having been returned to with a pending
174+
// future from `sign` or `decrypt`.
175+
176+
if cfg!(debug_assertions) {
177+
panic!("BUG: boring called complete without a pending operation");
178+
}
179+
180+
Err(AsyncPrivateKeyMethodError)
181+
})
182+
}
183+
}
184+
185+
/// Creates and drives a private key method future.
186+
///
187+
/// This is a convenience function for the three methods of impl `PrivateKeyMethod``
188+
/// for `dyn AsyncPrivateKeyMethod`. It relies on [`with_ex_data_future`] to
189+
/// drive the future and then immediately calls the final [`BoxPrivateKeyMethodFinish`]
190+
/// when the future is ready.
191+
fn with_private_key_method(
192+
ssl: &mut ssl::SslRef,
193+
output: &mut [u8],
194+
create_fut: impl FnOnce(
195+
&mut ssl::SslRef,
196+
&mut [u8],
197+
) -> Result<BoxPrivateKeyMethodFuture, AsyncPrivateKeyMethodError>,
198+
) -> Result<usize, ssl::PrivateKeyMethodError> {
199+
let fut_poll_result = with_ex_data_future(
200+
ssl,
201+
*SELECT_PRIVATE_KEY_METHOD_FUTURE_INDEX,
202+
|ssl| ssl,
203+
|ssl| create_fut(ssl, output),
204+
);
205+
206+
let fut_result = match fut_poll_result {
207+
Poll::Ready(fut_result) => fut_result,
208+
Poll::Pending => return Err(ssl::PrivateKeyMethodError::RETRY),
209+
};
210+
211+
let finish = fut_result.or(Err(ssl::PrivateKeyMethodError::FAILURE))?;
212+
213+
finish(ssl, output).or(Err(ssl::PrivateKeyMethodError::FAILURE))
214+
}
215+
216+
/// Creates and drives a future stored in `ssl_handle`'s `Ssl` at ex data index `index`.
217+
///
218+
/// This function won't even bother storing the future in `index` if the future
219+
/// created by `create_fut` returns `Poll::Ready(_)` on the first poll call.
220+
fn with_ex_data_future<H, T, E>(
221+
ssl_handle: &mut H,
222+
index: Index<ssl::Ssl, ExDataFuture<Result<T, E>>>,
223+
get_ssl_mut: impl Fn(&mut H) -> &mut ssl::SslRef,
224+
create_fut: impl FnOnce(&mut H) -> Result<ExDataFuture<Result<T, E>>, E>,
225+
) -> Poll<Result<T, E>> {
226+
let ssl = get_ssl_mut(ssl_handle);
227+
let waker = ssl
228+
.ex_data(*TASK_WAKER_INDEX)
229+
.cloned()
230+
.flatten()
231+
.expect("task waker should be set");
232+
233+
let mut ctx = Context::from_waker(&waker);
234+
235+
match ssl.ex_data_mut(index) {
236+
Some(fut) => {
237+
let fut_result = ready!(fut.as_mut().poll(&mut ctx));
238+
239+
// NOTE(nox): For memory usage concerns, maybe we should implement
240+
// a way to remove the stored future from the `Ssl` value here?
241+
242+
Poll::Ready(fut_result)
243+
}
244+
None => {
245+
let mut fut = create_fut(ssl_handle)?;
246+
247+
match fut.as_mut().poll(&mut ctx) {
248+
Poll::Ready(fut_result) => Poll::Ready(fut_result),
249+
Poll::Pending => {
250+
get_ssl_mut(ssl_handle).set_ex_data(index, fut);
251+
252+
Poll::Pending
253+
}
254+
}
255+
}
256+
}
257+
}
258+
259+
mod private {
260+
pub trait Sealed {}
261+
}
262+
263+
impl private::Sealed for SslContextBuilder {}

tokio-boring/src/bridge.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
//! Bridge between sync IO traits and async tokio IO traits.
2-
32
use std::fmt;
43
use std::io;
54
use std::pin::Pin;
@@ -35,7 +34,7 @@ impl<S> AsyncStreamBridge<S> {
3534
F: FnOnce(&mut Context<'_>, Pin<&mut S>) -> R,
3635
{
3736
let mut ctx =
38-
Context::from_waker(self.waker.as_ref().expect("missing task context pointer"));
37+
Context::from_waker(self.waker.as_ref().expect("BUG: missing waker in bridge"));
3938

4039
f(&mut ctx, Pin::new(&mut self.stream))
4140
}

tokio-boring/src/lib.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,15 @@ use std::pin::Pin;
2727
use std::task::{Context, Poll};
2828
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
2929

30+
mod async_callbacks;
3031
mod bridge;
3132

33+
use self::async_callbacks::TASK_WAKER_INDEX;
34+
pub use self::async_callbacks::{
35+
AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError,
36+
BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish, BoxSelectCertFuture,
37+
ExDataFuture, SslContextBuilderExt,
38+
};
3239
use self::bridge::AsyncStreamBridge;
3340

3441
/// Asynchronously performs a client-side TLS handshake over the provided stream.
@@ -90,6 +97,11 @@ impl<S> SslStream<S> {
9097
self.0.ssl()
9198
}
9299

100+
/// Returns a mutable reference to the `Ssl` object associated with this stream.
101+
pub fn ssl_mut(&mut self) -> &mut SslRef {
102+
self.0.ssl_mut()
103+
}
104+
93105
/// Returns a shared reference to the underlying stream.
94106
pub fn get_ref(&self) -> &S {
95107
&self.0.get_ref().stream
@@ -285,15 +297,20 @@ where
285297
let mut mid_handshake = self.0.take().expect("future polled after completion");
286298

287299
mid_handshake.get_mut().set_waker(Some(ctx));
300+
mid_handshake
301+
.ssl_mut()
302+
.set_ex_data(*TASK_WAKER_INDEX, Some(ctx.waker().clone()));
288303

289304
match mid_handshake.handshake() {
290305
Ok(mut stream) => {
291306
stream.get_mut().set_waker(None);
307+
stream.ssl_mut().set_ex_data(*TASK_WAKER_INDEX, None);
292308

293309
Poll::Ready(Ok(SslStream(stream)))
294310
}
295311
Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => {
296312
mid_handshake.get_mut().set_waker(None);
313+
mid_handshake.ssl_mut().set_ex_data(*TASK_WAKER_INDEX, None);
297314

298315
self.0 = Some(mid_handshake);
299316

0 commit comments

Comments
 (0)