Skip to content
Open
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
/target
artifacts/
*.~undo-tree~
2 changes: 2 additions & 0 deletions wolfssl-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ fn build_wolfssl(wolfssl_src: &Path) -> PathBuf {
.enable("dtls-frag-ch", None)
// Enable setting the D/TLS MTU size
.enable("dtls-mtu", None)
// Enable pre-shared keys
.enable("psk", None)
// Enable Secure Renegotiation
.enable("secure-renegotiation", None)
// Enable single threaded mode
Expand Down
233 changes: 229 additions & 4 deletions wolfssl/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,21 @@ use crate::{
ssl::{Session, SessionConfig},
CurveGroup, Method, NewSessionError, RootCertificate, Secret, SslVerifyMode,
};
use std::os::raw::c_int;
use std::ptr::NonNull;
use std::{
ffi::{c_void, CStr, CString},
fmt::Debug,
os::raw::{c_char, c_int, c_uint},
ptr::NonNull,
sync::Arc,
};
use thiserror::Error;

/// Produces a [`Context`] once built.
#[derive(Debug)]
pub struct ContextBuilder {
ctx: NonNull<wolfssl_sys::WOLFSSL_CTX>,
method: Method,
pre_shared_key_callbacks: Option<Box<dyn PreSharedKeyCallbacks>>,
}

/// Error creating a [`ContextBuilder`] object.
Expand Down Expand Up @@ -49,7 +55,11 @@ impl ContextBuilder {
let ctx = unsafe { wolfssl_sys::wolfSSL_CTX_new(method_fn.as_ptr()) };
let ctx = NonNull::new(ctx).ok_or(NewContextBuilderError::CreateFailed)?;

Ok(Self { ctx, method })
Ok(Self {
ctx,
method,
pre_shared_key_callbacks: None,
})
}

/// When `cond` is True call fallible `func` on `Self`
Expand Down Expand Up @@ -393,6 +403,152 @@ impl ContextBuilder {
}
}

unsafe extern "C" fn psk_server_callback(
ssl: *mut wolfssl_sys::WOLFSSL,
identity_ptr: *const c_char,
key_output_ptr: *mut u8,
max_key_length_c_uint: c_uint,
) -> c_uint {
debug_assert!(!ssl.is_null());
debug_assert!(!identity_ptr.is_null()); // this is never null, it points to an array in an `Arrays` struct
debug_assert!(!key_output_ptr.is_null());

// SAFETY: identity_ptr is in fact a C string
let identity: &CStr = unsafe { CStr::from_ptr(identity_ptr) };
let max_key_length: usize = max_key_length_c_uint.try_into().unwrap();

// SAFETY: `wolfSSL_get_psk_callback_ctx` is undocumented, but the implementation simply
// gets a field out of the WOLFSSL object.
let stored_cbs_ptr_ptr: *const c_void =
unsafe { wolfssl_sys::wolfSSL_get_psk_callback_ctx(ssl) };
// SAFETY: This is written in `Session::new_from_wolfssl_pointer` as a pointer to the
// contents of an Box, so should have stable address. The Box is stored at least until the
// end of the session and hence should be alive.
#[allow(clippy::borrowed_box)]
let stored_cbs: &Box<dyn PreSharedKeyCallbacks> =
unsafe { &*(stored_cbs_ptr_ptr as *const Box<dyn PreSharedKeyCallbacks>) };

let maybe_key = stored_cbs.psk_server_callback(identity, max_key_length);
match maybe_key {
Some(key) => {
assert!(
key.len() <= max_key_length,
"Key length {} returned by server callback was longer than maximum {}",
key.len(),
max_key_length
);
// SAFETY: we've verified that the vec length is <= max_key_length, so we won't overrun
// the buffer provided to us.
unsafe { std::ptr::copy(key.as_ptr(), key_output_ptr, key.len()) };
key.len().try_into().unwrap()
}
None => 0,
}
}

unsafe extern "C" fn psk_client_callback(
ssl: *mut wolfssl_sys::WOLFSSL,
_hint: *const c_char,
identity_output: *mut c_char,
max_identity_length_c_uint: c_uint,
key_output: *mut u8,
max_key_length_c_uint: c_uint,
) -> c_uint {
debug_assert!(!ssl.is_null());
debug_assert!(!identity_output.is_null());
debug_assert!(!key_output.is_null());

let max_identity_length: usize = max_identity_length_c_uint.try_into().unwrap();
let max_key_length: usize = max_key_length_c_uint.try_into().unwrap();

// SAFETY: See `psk_server_callback`
let stored_cbs_ptr_ptr: *const c_void =
unsafe { wolfssl_sys::wolfSSL_get_psk_callback_ctx(ssl) };
// SAFETY: See `psk_server_callback`
#[allow(clippy::borrowed_box)]
let stored_cbs: &Box<dyn PreSharedKeyCallbacks> =
unsafe { &*(stored_cbs_ptr_ptr as *const Box<dyn PreSharedKeyCallbacks>) };

let maybe_result = stored_cbs.psk_client_callback(max_identity_length, max_key_length);
match maybe_result {
Some(PreSharedKeyClientCallbackResult { identity, key }) => {
assert!(
identity.count_bytes() <= max_identity_length,
"Identity length {} was not less than maximum {}",
identity.count_bytes(),
max_identity_length
);
assert!(
key.len() <= max_key_length,
"Key length {} was not less than maximum {}",
key.len(),
max_key_length
);

// SAFETY: See `psk_server_callback`.
unsafe { std::ptr::copy(key.as_ptr(), key_output, key.len()) };
// SAFETY: See immediately above. +1 to account for nul terminator.
// `max_identity_length` is not including the nul terminator (the definition of the
// `client_identity` field in the `Arrays` struct in wolfssl `internal.h` has length
// `MAX_PSK_ID_LEN + NULL_TERM_LEN`, and `MAX_PSK_ID_LEN` is what is passed as the
// `max_identity_length`)
unsafe {
std::ptr::copy(
identity.as_ptr(),
identity_output,
identity.count_bytes() + 1,
)
};

key.len().try_into().unwrap()
}
None => 0,
}
}

/// Use a fixed pre-shared key for authentication
///
/// See also: [with_pre_shared_key_callbacks]
pub fn with_pre_shared_key(self, psk: &[u8]) -> Self {
self.with_pre_shared_key_callbacks(Box::new(FixedPskCallbacks::new(psk)))
}

/// Use pre-shared key callbacks for authentication
///
/// Install custom client and server callbacks for pre-shared-key authentication. Calls either
/// `wolfSSL_CTX_set_psk_server_callback` or `wolfSSL_CTX_set_psk_client_callback` appropriately
/// using fixed callbacks provided by wolfssl-rs. Later, during session constrtuction, calls
/// `wolfSSL_set_psk_callback_ctx` to point to make the user-provided safe callbacks accessible
/// in the fixed callback. The fixed callback does the unsafe work and delegates the interesting
/// logic to the safe user-provided callback.
pub fn with_pre_shared_key_callbacks(self, callbacks: Box<dyn PreSharedKeyCallbacks>) -> Self {
if self.method.is_server() {
// SAFETY: `wolfSSL_CTX_set_psk_server_callback` isn't properly documented. It seems the
// only requirement is that the context is valid and the callback will be alive
// throughout the lifetime of the context and any created sessions; our callbacks are
// &'static.
unsafe {
wolfssl_sys::wolfSSL_CTX_set_psk_server_callback(
self.ctx.as_ptr(),
Some(Self::psk_server_callback),
);
};
} else {
// SAFETY: See above.
unsafe {
wolfssl_sys::wolfSSL_CTX_set_psk_client_callback(
self.ctx.as_ptr(),
Some(Self::psk_client_callback),
);
};
};

Self {
pre_shared_key_callbacks: Some(callbacks),
..self
}
}

/// Wraps `wolfSSL_CTX_UseSecureRenegotiation`
///
/// NOTE: No official documentation available for this api from wolfssl
Expand Down Expand Up @@ -438,6 +594,7 @@ impl ContextBuilder {
Context {
method: self.method,
ctx: ContextPointer(self.ctx),
pre_shared_key_callbacks: self.pre_shared_key_callbacks.map(Arc::new),
}
}
}
Expand Down Expand Up @@ -513,6 +670,7 @@ unsafe impl Send for WolfsslPointer {}
pub struct Context {
method: Method,
ctx: ContextPointer,
pre_shared_key_callbacks: Option<Arc<Box<dyn PreSharedKeyCallbacks>>>,
}

impl Context {
Expand All @@ -534,7 +692,7 @@ impl Context {

let ssl = WolfsslPointer(NonNull::new(ptr).ok_or(NewSessionError::CreateFailed)?);

Session::new_from_wolfssl_pointer(ssl, config)
Session::new_from_wolfssl_pointer(ssl, config, self.pre_shared_key_callbacks.clone())
}
}

Expand All @@ -555,6 +713,73 @@ impl Drop for Context {
}
}

/// Returned from the client callback in [PreSharedKeyCallbacks]
pub struct PreSharedKeyClientCallbackResult {
/// Should be an empty string if you don't need multiple identities. Else, an arbitrary string
/// that the server will be able to read to determine which PSK to use.
pub identity: CString,
/// The pre-shared key itself.
pub key: Vec<u8>,
}

/// Callbacks that are used to provide a pre-shared key to wolfSSL.
pub trait PreSharedKeyCallbacks: Debug {
/// Called on the client before starting the connection.
///
/// The installed wolfSSL callback will return 0 if None is returned from the Rust callback,
/// which means "fail". The wolfSSL docs are unclear what happens when the callback fails in
/// this way.
fn psk_client_callback(
&self,
max_identity_length: usize,
max_key_length: usize,
) -> Option<PreSharedKeyClientCallbackResult>;

/// Called on the server after receiving the client hello.
///
/// Receives the identity set in the client callback. Return the key, or None on failure.
fn psk_server_callback(&self, identity: &CStr, max_key_length: usize) -> Option<Vec<u8>>;
}

/// An implementation of PreSharedKeyCallbacks that uses a fixed buffer as the pre-shared key, which
/// is the most common usecase for pre shared keys.
#[derive(Debug)]
struct FixedPskCallbacks {
key: Vec<u8>,
}

impl FixedPskCallbacks {
/// Construct a FixedPskCallbacks object that will always use the given key, ignoring identity.
fn new<T: Into<Vec<u8>>>(key: T) -> FixedPskCallbacks {
FixedPskCallbacks { key: key.into() }
}
}

impl PreSharedKeyCallbacks for FixedPskCallbacks {
fn psk_client_callback(
&self,
_max_identity_length: usize,
max_key_length: usize,
) -> Option<PreSharedKeyClientCallbackResult> {
if self.key.len() > max_key_length {
return None;
}

Some(PreSharedKeyClientCallbackResult {
identity: c"".into(),
key: self.key.clone(),
})
}

fn psk_server_callback(&self, _identity: &CStr, max_key_length: usize) -> Option<Vec<u8>> {
if self.key.len() > max_key_length {
return None;
}

Some(self.key.clone())
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
18 changes: 18 additions & 0 deletions wolfssl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,24 @@ impl Method {

NonNull::new(ptr)
}

/// Returns true if this method is a server method
fn is_server(self) -> bool {
match self {
Self::DtlsClient => false,
Self::DtlsClientV1_2 => false,
Self::DtlsClientV1_3 => false,
Self::TlsClient => false,
Self::TlsClientV1_2 => false,
Self::TlsClientV1_3 => false,
Self::DtlsServer => true,
Self::DtlsServerV1_2 => true,
Self::DtlsServerV1_3 => true,
Self::TlsServer => true,
Self::TlsServerV1_2 => true,
Self::TlsServerV1_3 => true,
}
}
}

/// Corresponds to the various defined `WOLFSSL_*` curves
Expand Down
Loading