|
| 1 | +/* Copyright (c) Fortanix, Inc. |
| 2 | + * |
| 3 | + * Licensed under the GNU General Public License, version 2 <LICENSE-GPL or |
| 4 | + * https://www.gnu.org/licenses/gpl-2.0.html> or the Apache License, Version |
| 5 | + * 2.0 <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0>, at your |
| 6 | + * option. This file may not be copied, modified, or distributed except |
| 7 | + * according to those terms. */ |
| 8 | + |
| 9 | +#![cfg(all(feature = "std", feature = "async"))] |
| 10 | + |
| 11 | +use std::cell::Cell; |
| 12 | +use std::ptr::null_mut; |
| 13 | +use std::rc::Rc; |
| 14 | +use std::task::{Context as TaskContext, Poll}; |
| 15 | + |
| 16 | + |
| 17 | +#[cfg(not(feature = "std"))] |
| 18 | +use core_io::{Error as IoError, Result as IoResult, ErrorKind as IoErrorKind}; |
| 19 | +#[cfg(feature = "std")] |
| 20 | +use std::io::{Error as IoError, Result as IoResult, ErrorKind as IoErrorKind}; |
| 21 | + |
| 22 | + |
| 23 | +#[derive(Clone)] |
| 24 | +pub struct ErasedContext(Rc<Cell<*mut ()>>); |
| 25 | + |
| 26 | +unsafe impl Send for ErasedContext {} |
| 27 | + |
| 28 | +impl ErasedContext { |
| 29 | + pub fn new() -> Self { |
| 30 | + Self(Rc::new(Cell::new(null_mut()))) |
| 31 | + } |
| 32 | + |
| 33 | + pub unsafe fn get(&self) -> Option<&mut TaskContext<'_>> { |
| 34 | + let ptr = self.0.get(); |
| 35 | + if ptr.is_null() { |
| 36 | + None |
| 37 | + } else { |
| 38 | + Some(&mut *(ptr as *mut _)) |
| 39 | + } |
| 40 | + } |
| 41 | + |
| 42 | + pub fn set(&self, cx: &mut TaskContext<'_>) { |
| 43 | + self.0.set(cx as *mut _ as *mut ()); |
| 44 | + } |
| 45 | + |
| 46 | + pub fn clear(&self) { |
| 47 | + self.0.set(null_mut()); |
| 48 | + } |
| 49 | +} |
| 50 | + |
| 51 | +// mbedtls_ssl_write() has some weird semantics w.r.t non-blocking I/O: |
| 52 | +// |
| 53 | +// > When this function returns MBEDTLS_ERR_SSL_WANT_WRITE/READ, it must be |
| 54 | +// > called later **with the same arguments**, until it returns a value greater |
| 55 | +// > than or equal to 0. When the function returns MBEDTLS_ERR_SSL_WANT_WRITE |
| 56 | +// > there may be some partial data in the output buffer, however this is not |
| 57 | +// > yet sent. |
| 58 | +// |
| 59 | +// WriteTracker is used to ensure we pass the same data in that scenario. |
| 60 | +// |
| 61 | +// Reference: |
| 62 | +// https://tls.mbed.org/api/ssl_8h.html#a5bbda87d484de82df730758b475f32e5 |
| 63 | +pub struct WriteTracker { |
| 64 | + pending: Option<Box<DigestAndLen>>, |
| 65 | +} |
| 66 | + |
| 67 | +struct DigestAndLen { |
| 68 | + #[cfg(debug_assertions)] |
| 69 | + digest: [u8; 20], // SHA-1 |
| 70 | + len: usize, |
| 71 | +} |
| 72 | + |
| 73 | +impl WriteTracker { |
| 74 | + fn new() -> Self { |
| 75 | + WriteTracker { |
| 76 | + pending: None, |
| 77 | + } |
| 78 | + } |
| 79 | + |
| 80 | + #[cfg(debug_assertions)] |
| 81 | + fn digest(buf: &[u8]) -> [u8; 20] { |
| 82 | + use crate::hash::{Md, Type}; |
| 83 | + let mut out = [0u8; 20]; |
| 84 | + let res = Md::hash(Type::Sha1, buf, &mut out[..]); |
| 85 | + assert_eq!(res, Ok(out.len())); |
| 86 | + out |
| 87 | + } |
| 88 | + |
| 89 | + pub fn adjust_buf<'a>(&self, buf: &'a [u8]) -> IoResult<&'a [u8]> { |
| 90 | + match self.pending.as_ref() { |
| 91 | + None => Ok(buf), |
| 92 | + Some(pending) => { |
| 93 | + if pending.len <= buf.len() { |
| 94 | + let buf = &buf[..pending.len]; |
| 95 | + |
| 96 | + // We only do this check in debug mode since it's an expensive check. |
| 97 | + #[cfg(debug_assertions)] |
| 98 | + if Self::digest(buf) == pending.digest { |
| 99 | + return Ok(buf); |
| 100 | + } |
| 101 | + |
| 102 | + #[cfg(not(debug_assertions))] |
| 103 | + return Ok(buf); |
| 104 | + } |
| 105 | + Err(IoError::new( |
| 106 | + IoErrorKind::Other, |
| 107 | + "mbedtls expects the same data if the previous call to poll_write() returned Poll::Pending" |
| 108 | + )) |
| 109 | + }, |
| 110 | + } |
| 111 | + } |
| 112 | + |
| 113 | + pub fn post_write(&mut self, buf: &[u8], res: &Poll<IoResult<usize>>) { |
| 114 | + match res { |
| 115 | + &Poll::Pending => { |
| 116 | + if self.pending.is_none() { |
| 117 | + self.pending = Some(Box::new(DigestAndLen { |
| 118 | + #[cfg(debug_assertions)] |
| 119 | + digest: Self::digest(buf), |
| 120 | + len: buf.len(), |
| 121 | + })); |
| 122 | + } |
| 123 | + }, |
| 124 | + _ => { |
| 125 | + self.pending = None; |
| 126 | + } |
| 127 | + } |
| 128 | + } |
| 129 | +} |
| 130 | + |
| 131 | +pub struct IoAdapter<S> { |
| 132 | + pub inner: S, |
| 133 | + pub ecx: ErasedContext, |
| 134 | + pub write_tracker: WriteTracker, |
| 135 | +} |
| 136 | + |
| 137 | +impl<S> IoAdapter<S> { |
| 138 | + pub fn new(stream: S) -> Self { |
| 139 | + Self { |
| 140 | + inner: stream, |
| 141 | + ecx: ErasedContext::new(), |
| 142 | + write_tracker: WriteTracker::new(), |
| 143 | + } |
| 144 | + } |
| 145 | +} |
0 commit comments