Skip to content

Async support #226

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 14 commits into from
334 changes: 325 additions & 9 deletions Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions ct.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ if [ "$TRAVIS_RUST_VERSION" == "stable" ] || [ "$TRAVIS_RUST_VERSION" == "beta"
cargo test --features pkcs12 --target $TARGET
cargo test --features pkcs12_rc2 --target $TARGET
cargo test --features dsa --target $TARGET
cargo test --test async_session --features=async-rt --target $TARGET

# If zlib is installed, test the zlib feature
if [ -n "$ZLIB_INSTALLED" ]; then
Expand Down
5 changes: 2 additions & 3 deletions mbedtls-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,8 @@ quote = "1.0.9"
# * strstr/strlen/strncpy/strncmp/strcmp/snprintf
# * memmove/memcpy/memcmp/memset
# * rand/printf (used only for self tests. optionally use custom_printf)
default = ["std", "debug", "threading", "zlib", "time", "aesni", "padlock", "legacy_protocols"]
std = ["debug"] # deprecated automatic enabling of debug, can be removed on major version bump
debug = []
default = ["std", "threading", "zlib", "time", "aesni", "padlock", "legacy_protocols"]
std = []
custom_printf = []
custom_has_support = []
aes_alt = []
Expand Down
15 changes: 13 additions & 2 deletions mbedtls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@ bit-vec = { version = "0.5", optional = true }
cbc = { version = "0.1.2", optional = true }
rc2 = { version = "0.8.1", optional = true }
cfg-if = "1.0.0"
tokio = { version = "1.16.1", optional = true }

[target.x86_64-fortanix-unknown-sgx.dependencies]
rs-libc = "0.2.0"
chrono = "0.4"

[dependencies.mbedtls-sys-auto]
version = "2.25.0"
version = "2.28.0"
default-features = false
features = ["custom_printf", "trusted_cert_callback", "threading"]
path = "../mbedtls-sys"
Expand All @@ -47,6 +48,9 @@ serde_cbor = "0.6"
hex = "0.3"
matches = "0.1.8"
hyper = { version = "0.10.16", default-features = false }
async-stream = "0.3.0"
futures = "0.3"
tracing = "0.1"

[build-dependencies]
cc = "1.0"
Expand All @@ -55,7 +59,7 @@ cc = "1.0"
# Features are documented in the README
default = ["std", "aesni", "time", "padlock"]
std = ["byteorder/std", "mbedtls-sys-auto/std", "serde/std", "yasna"]
debug = ["mbedtls-sys-auto/debug"]
debug = []
no_std_deps = ["spin", "serde/alloc"]
force_aesni_support = ["mbedtls-sys-auto/custom_has_support", "mbedtls-sys-auto/aes_alt", "aesni"]
mpi_force_c_code = ["mbedtls-sys-auto/mpi_force_c_code"]
Expand All @@ -68,6 +72,8 @@ dsa = ["std", "yasna", "num-bigint", "bit-vec"]
pkcs12 = ["std", "yasna"]
pkcs12_rc2 = ["pkcs12", "rc2", "cbc"]
legacy_protocols = ["mbedtls-sys-auto/legacy_protocols"]
async = ["std", "tokio","tokio/net","tokio/io-util", "tokio/macros"]
async-rt = ["async", "tokio/rt", "tokio/sync", "tokio/rt-multi-thread"]

[[example]]
name = "client"
Expand Down Expand Up @@ -100,3 +106,8 @@ required-features = ["std"]
[[test]]
name = "hyper"
required-features = ["std"]

[[test]]
name = "async_session"
path = "tests/async_session.rs"
required-features = ["async-rt"]
10 changes: 5 additions & 5 deletions mbedtls/src/pk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ define!(
// B. Verifying thread safety.
//
// 1. Calls towards the specific Pk implementation are done via function pointers.
//
//
// - Example call towards Pk:
// ../../../mbedtls-sys/vendor/library/ssl_srv.c:3707 - mbedtls_pk_decrypt( private_key, p, len, ...
// - This calls a generic function pointer via:
Expand All @@ -174,7 +174,7 @@ define!(
// - The function pointers are defined via function:
// ../../../mbedtls-sys/vendor/crypto/library/pk.c:115 - mbedtls_pk_info_from_type
// - They are as follows: mbedtls_rsa_info / mbedtls_eckey_info / mbedtls_ecdsa_info
// - These are defined in:
// - These are defined in:
// ../../../mbedtls-sys/vendor/crypto/library/pk_wrap.c:196
//
// C. Checking types one by one.
Expand Down Expand Up @@ -222,7 +222,7 @@ define!(
// mbedtls_ecp_mul_restartable: ../../../mbedtls-sys/vendor/crypto/library/ecp.c:2351
// MBEDTLS_ECP_INTERNAL_ALT is not defined. (otherwise it might not be safe depending on ecp_init/ecp_free) ../../../mbedtls-sys/build/config.rs:131
// Passes as const to: mbedtls_ecp_check_privkey / mbedtls_ecp_check_pubkey / mbedtls_ecp_get_type( grp
//
//
// - Ignored due to not defined: ecdsa_verify_rs_wrap, ecdsa_sign_rs_wrap, ecdsa_rs_alloc, ecdsa_rs_free
// (Undefined - MBEDTLS_ECP_RESTARTABLE - ../../../mbedtls-sys/build/config.rs:173)
//
Expand Down Expand Up @@ -927,7 +927,7 @@ impl Pk {
if hash.len() == 0 || sig.len() == 0 {
return Err(Error::PkBadInputData)
}

unsafe {
pk_verify(
&mut self.inner,
Expand Down Expand Up @@ -1297,7 +1297,7 @@ iy6KC991zzvaWY/Ys+q/84Afqa+0qJKQnPuy/7F5GkVdQA/lfbhi
let mut dummy_sig = [];
assert_eq!(pk.sign(digest, data, &mut dummy_sig, &mut crate::test_support::rand::test_rng()).unwrap_err(), Error::PkBadInputData);
assert_eq!(pk.sign(digest, &[], &mut signature, &mut crate::test_support::rand::test_rng()).unwrap_err(), Error::PkBadInputData);

assert_eq!(pk.sign_deterministic(digest, data, &mut dummy_sig, &mut crate::test_support::rand::test_rng()).unwrap_err(), Error::PkBadInputData);
assert_eq!(pk.sign_deterministic(digest, &[], &mut signature, &mut crate::test_support::rand::test_rng()).unwrap_err(), Error::PkBadInputData);

Expand Down
143 changes: 143 additions & 0 deletions mbedtls/src/ssl/async_utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/* Copyright (c) Fortanix, Inc.
*
* Licensed under the GNU General Public License, version 2 <LICENSE-GPL or
* https://www.gnu.org/licenses/gpl-2.0.html> or the Apache License, Version
* 2.0 <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0>, at your
* option. This file may not be copied, modified, or distributed except
* according to those terms. */

#![cfg(all(feature = "std", feature = "async"))]

use std::cell::Cell;
use std::ptr::null_mut;
use std::rc::Rc;
use std::task::{Context as TaskContext, Poll};


#[cfg(feature = "std")]
use std::io::{Error as IoError, Result as IoResult, ErrorKind as IoErrorKind};


#[derive(Clone)]
pub struct ErasedContext(Rc<Cell<*mut ()>>);

unsafe impl Send for ErasedContext {}

impl ErasedContext {
pub fn new() -> Self {
Self(Rc::new(Cell::new(null_mut())))
}

pub unsafe fn get(&self) -> Option<&mut TaskContext<'_>> {
let ptr = self.0.get();
if ptr.is_null() {
None
} else {
Some(&mut *(ptr as *mut _))
}
}

pub fn set(&self, cx: &mut TaskContext<'_>) {
self.0.set(cx as *mut _ as *mut ());
}

pub fn clear(&self) {
self.0.set(null_mut());
}
}

// mbedtls_ssl_write() has some weird semantics w.r.t non-blocking I/O:
//
// > When this function returns MBEDTLS_ERR_SSL_WANT_WRITE/READ, it must be
// > called later **with the same arguments**, until it returns a value greater
// > than or equal to 0. When the function returns MBEDTLS_ERR_SSL_WANT_WRITE
// > there may be some partial data in the output buffer, however this is not
// > yet sent.
//
// WriteTracker is used to ensure we pass the same data in that scenario.
//
// Reference:
// https://tls.mbed.org/api/ssl_8h.html#a5bbda87d484de82df730758b475f32e5
pub struct WriteTracker {
pending: Option<Box<DigestAndLen>>,
}

struct DigestAndLen {
#[cfg(debug_assertions)]
digest: [u8; 20], // SHA-1
len: usize,
}

impl WriteTracker {
fn new() -> Self {
WriteTracker {
pending: None,
}
}

#[cfg(debug_assertions)]
fn digest(buf: &[u8]) -> [u8; 20] {
use crate::hash::{Md, Type};
let mut out = [0u8; 20];
let res = Md::hash(Type::Sha1, buf, &mut out[..]);
assert_eq!(res, Ok(out.len()));
out
}

pub fn adjust_buf<'a>(&self, buf: &'a [u8]) -> IoResult<&'a [u8]> {
match self.pending.as_ref() {
None => Ok(buf),
Some(pending) => {
if pending.len <= buf.len() {
let buf = &buf[..pending.len];

// We only do this check in debug mode since it's an expensive check.
#[cfg(debug_assertions)]
if Self::digest(buf) == pending.digest {
return Ok(buf);
}

#[cfg(not(debug_assertions))]
return Ok(buf);
}
Err(IoError::new(
IoErrorKind::Other,
"mbedtls expects the same data if the previous call to poll_write() returned Poll::Pending"
))
},
}
}

pub fn post_write(&mut self, buf: &[u8], res: &Poll<IoResult<usize>>) {
match res {
&Poll::Pending => {
if self.pending.is_none() {
self.pending = Some(Box::new(DigestAndLen {
#[cfg(debug_assertions)]
digest: Self::digest(buf),
len: buf.len(),
}));
}
},
_ => {
self.pending = None;
}
}
}
}

pub struct IoAdapter<S> {
pub inner: S,
pub ecx: ErasedContext,
pub write_tracker: WriteTracker,
}

impl<S> IoAdapter<S> {
pub fn new(stream: S) -> Self {
Self {
inner: stream,
ecx: ErasedContext::new(),
write_tracker: WriteTracker::new(),
}
}
}
Loading