Skip to content

Commit 8ff1bd6

Browse files
committed
feat: add support for splitting TcpStream
1 parent 79b8a76 commit 8ff1bd6

File tree

4 files changed

+214
-38
lines changed

4 files changed

+214
-38
lines changed

madsim/src/sim/net/tcp/listener.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ impl Socket for TcpListenerSocket {
8989
write_buf: Default::default(),
9090
read_buf: Default::default(),
9191
tx,
92-
rx,
92+
rx: Mutex::new(rx),
9393
};
9494
let _ = self.tx.try_send(stream);
9595
}

madsim/src/sim/net/tcp/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,12 @@ pub type Payload = Box<dyn Any + Send + Sync>;
4848

4949
mod config;
5050
mod listener;
51+
mod split;
5152
mod stream;
5253

5354
pub use self::config::*;
5455
pub use self::listener::*;
56+
pub use self::split::*;
5557
pub use self::stream::*;
5658

5759
#[cfg(test)]

madsim/src/sim/net/tcp/split.rs

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
use crate::net::TcpStream;
2+
use bytes::BufMut;
3+
use std::{
4+
io,
5+
pin::Pin,
6+
sync::Arc,
7+
task::{Context, Poll},
8+
};
9+
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10+
11+
/// Borrowed read half of a [`TcpStream`].
12+
#[derive(Debug)]
13+
pub struct ReadHalf<'a>(&'a TcpStream);
14+
15+
/// Borrowed write half of a [`TcpStream`].
16+
#[derive(Debug)]
17+
pub struct WriteHalf<'a>(&'a TcpStream);
18+
19+
pub(crate) fn split(stream: &mut TcpStream) -> (ReadHalf<'_>, WriteHalf<'_>) {
20+
(ReadHalf(&*stream), WriteHalf(&*stream))
21+
}
22+
23+
impl ReadHalf<'_> {
24+
/// Tries to read data from the stream into the provided buffer, advancing
25+
/// the buffer's internal cursor, returning how many bytes were read.
26+
///
27+
/// Receives any pending data from the socket but does not wait for new data
28+
/// to arrive. On success, returns the number of bytes read. Because
29+
/// `try_read_buf()` is non-blocking, the buffer does not have to be stored
30+
/// by the async task and can exist entirely on the stack.
31+
pub fn try_read_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> {
32+
self.0.try_read_buf(buf)
33+
}
34+
}
35+
36+
impl AsyncRead for ReadHalf<'_> {
37+
fn poll_read(
38+
self: Pin<&mut Self>,
39+
cx: &mut Context<'_>,
40+
buf: &mut ReadBuf<'_>,
41+
) -> Poll<io::Result<()>> {
42+
self.0.poll_read_priv(cx, buf)
43+
}
44+
}
45+
46+
impl AsyncWrite for WriteHalf<'_> {
47+
fn poll_write(
48+
self: Pin<&mut Self>,
49+
cx: &mut Context<'_>,
50+
buf: &[u8],
51+
) -> Poll<Result<usize, io::Error>> {
52+
self.0.poll_write_priv(cx, buf)
53+
}
54+
55+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
56+
self.0.poll_flush_priv(cx)
57+
}
58+
59+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
60+
self.0.poll_shutdown_priv(cx)
61+
}
62+
}
63+
64+
impl AsRef<TcpStream> for ReadHalf<'_> {
65+
fn as_ref(&self) -> &TcpStream {
66+
self.0
67+
}
68+
}
69+
70+
impl AsRef<TcpStream> for WriteHalf<'_> {
71+
fn as_ref(&self) -> &TcpStream {
72+
self.0
73+
}
74+
}
75+
76+
/// Owned read half of a [`TcpStream`].
77+
#[derive(Debug)]
78+
pub struct OwnedReadHalf(Arc<TcpStream>);
79+
80+
/// Owned write half of a [`TcpStream`].
81+
#[derive(Debug)]
82+
pub struct OwnedWriteHalf(Arc<TcpStream>);
83+
84+
pub(crate) fn split_owned(stream: TcpStream) -> (OwnedReadHalf, OwnedWriteHalf) {
85+
let arc = Arc::new(stream);
86+
let read = OwnedReadHalf(Arc::clone(&arc));
87+
let write = OwnedWriteHalf(arc);
88+
(read, write)
89+
}
90+
91+
impl OwnedReadHalf {
92+
/// Tries to read data from the stream into the provided buffer, advancing
93+
/// the buffer's internal cursor, returning how many bytes were read.
94+
///
95+
/// Receives any pending data from the socket but does not wait for new data
96+
/// to arrive. On success, returns the number of bytes read. Because
97+
/// `try_read_buf()` is non-blocking, the buffer does not have to be stored
98+
/// by the async task and can exist entirely on the stack.
99+
pub fn try_read_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> {
100+
self.0.try_read_buf(buf)
101+
}
102+
}
103+
104+
impl AsyncRead for OwnedReadHalf {
105+
fn poll_read(
106+
self: Pin<&mut Self>,
107+
cx: &mut Context<'_>,
108+
buf: &mut ReadBuf<'_>,
109+
) -> Poll<io::Result<()>> {
110+
self.0.poll_read_priv(cx, buf)
111+
}
112+
}
113+
114+
impl AsyncWrite for OwnedWriteHalf {
115+
fn poll_write(
116+
self: Pin<&mut Self>,
117+
cx: &mut Context<'_>,
118+
buf: &[u8],
119+
) -> Poll<Result<usize, io::Error>> {
120+
self.0.poll_write_priv(cx, buf)
121+
}
122+
123+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
124+
self.0.poll_flush_priv(cx)
125+
}
126+
127+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
128+
self.0.poll_shutdown_priv(cx)
129+
}
130+
}
131+
132+
impl AsRef<TcpStream> for OwnedReadHalf {
133+
fn as_ref(&self) -> &TcpStream {
134+
&self.0
135+
}
136+
}
137+
138+
impl AsRef<TcpStream> for OwnedWriteHalf {
139+
fn as_ref(&self) -> &TcpStream {
140+
&self.0
141+
}
142+
}

madsim/src/sim/net/tcp/stream.rs

Lines changed: 69 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
use crate::net::{IpProtocol::Tcp, *};
1+
use crate::net::{tcp::split, IpProtocol::Tcp, *};
22
use bytes::{Buf, BufMut, BytesMut};
3+
use spin::Mutex;
34
#[cfg(unix)]
45
use std::os::unix::io::{AsRawFd, RawFd};
56
use std::{
@@ -17,10 +18,10 @@ pub struct TcpStream {
1718
pub(super) addr: SocketAddr,
1819
pub(super) peer: SocketAddr,
1920
/// Buffer write data to be flushed.
20-
pub(super) write_buf: BytesMut,
21-
pub(super) read_buf: Bytes,
21+
pub(super) write_buf: Mutex<BytesMut>,
22+
pub(super) read_buf: Mutex<Bytes>,
2223
pub(super) tx: PayloadSender,
23-
pub(super) rx: PayloadReceiver,
24+
pub(super) rx: Mutex<PayloadReceiver>,
2425
}
2526

2627
impl fmt::Debug for TcpStream {
@@ -80,7 +81,7 @@ impl TcpStream {
8081
write_buf: Default::default(),
8182
read_buf: Default::default(),
8283
tx,
83-
rx,
84+
rx: Mutex::new(rx),
8485
};
8586
Ok(stream)
8687
}
@@ -108,83 +109,114 @@ impl TcpStream {
108109
/// to arrive. On success, returns the number of bytes read. Because
109110
/// `try_read_buf()` is non-blocking, the buffer does not have to be stored
110111
/// by the async task and can exist entirely on the stack.
111-
pub fn try_read_buf<B: BufMut>(&mut self, buf: &mut B) -> io::Result<usize> {
112+
pub fn try_read_buf<B: BufMut>(&self, buf: &mut B) -> io::Result<usize> {
112113
// read the buffer if not empty
113-
if !self.read_buf.is_empty() {
114-
let len = self.read_buf.len().min(buf.remaining_mut());
115-
buf.put_slice(&self.read_buf[..len]);
116-
self.read_buf.advance(len);
114+
let mut read_buf = self.read_buf.lock();
115+
if !read_buf.is_empty() {
116+
let len = read_buf.len().min(buf.remaining_mut());
117+
buf.put_slice(&read_buf[..len]);
118+
read_buf.advance(len);
117119
return Ok(len);
118120
}
119121
Err(io::Error::new(
120122
io::ErrorKind::WouldBlock,
121123
"read buffer is empty",
122124
))
123125
}
124-
}
125126

126-
#[cfg(unix)]
127-
impl AsRawFd for TcpStream {
128-
fn as_raw_fd(&self) -> RawFd {
129-
todo!("TcpStream::as_raw_fd");
127+
/// Splits a `TcpStream` into a read half and a write half, which can be used
128+
/// to read and write the stream concurrently.
129+
pub fn split(&mut self) -> (split::ReadHalf<'_>, split::WriteHalf<'_>) {
130+
split::split(self)
130131
}
131-
}
132132

133-
impl AsyncRead for TcpStream {
134-
fn poll_read(
135-
mut self: Pin<&mut Self>,
133+
/// Splits a `TcpStream` into a read half and a write half, which can be
134+
/// used to read and write the stream concurrently.
135+
pub fn into_split(self) -> (split::OwnedReadHalf, split::OwnedWriteHalf) {
136+
split::split_owned(self)
137+
}
138+
139+
/// `poll_read` that takes `&self`.
140+
pub(super) fn poll_read_priv(
141+
&self,
136142
cx: &mut Context<'_>,
137143
buf: &mut ReadBuf<'_>,
138144
) -> Poll<Result<()>> {
139145
// read the buffer if not empty
140-
if !self.read_buf.is_empty() {
141-
let len = self.read_buf.len().min(buf.remaining());
142-
buf.put_slice(&self.read_buf[..len]);
143-
self.read_buf.advance(len);
146+
if self.try_read_buf(buf).is_ok() {
144147
return Poll::Ready(Ok(()));
145148
}
149+
146150
// otherwise wait on channel
147-
let poll_res = { self.rx.poll_next_unpin(cx) };
151+
let poll_res = { self.rx.lock().poll_next_unpin(cx) };
148152
match poll_res {
149153
Poll::Pending => Poll::Pending,
150154
Poll::Ready(Some(data)) => {
151-
self.read_buf = *data.downcast::<Bytes>().unwrap();
152-
self.poll_read(cx, buf)
155+
*self.read_buf.lock() = *data.downcast::<Bytes>().unwrap();
156+
self.poll_read_priv(cx, buf)
153157
}
154158
// ref: https://man7.org/linux/man-pages/man2/recv.2.html
155159
// > When a stream socket peer has performed an orderly shutdown, the
156160
// > return value will be 0 (the traditional "end-of-file" return).
157161
Poll::Ready(None) => Poll::Ready(Ok(())),
158162
}
159163
}
160-
}
161164

162-
impl AsyncWrite for TcpStream {
163-
fn poll_write(
164-
mut self: Pin<&mut Self>,
165-
_cx: &mut Context<'_>,
166-
buf: &[u8],
167-
) -> Poll<Result<usize>> {
168-
self.write_buf.extend_from_slice(buf);
165+
/// `poll_write` that takes `&self`.
166+
pub(super) fn poll_write_priv(&self, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
167+
self.write_buf.lock().extend_from_slice(buf);
169168
// TODO: simulate buffer full, partial write
170169
Poll::Ready(Ok(buf.len()))
171170
}
172171

173-
fn poll_flush(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
172+
/// `poll_flush` that takes `&self`.
173+
pub(super) fn poll_flush_priv(&self, _cx: &mut Context<'_>) -> Poll<Result<()>> {
174174
// send data
175-
let data = self.write_buf.split().freeze();
175+
let data = self.write_buf.lock().split().freeze();
176176
self.tx
177177
.send(Box::new(data))
178178
.ok_or_else(|| io::Error::new(io::ErrorKind::ConnectionReset, "connection reset"))?;
179179
Poll::Ready(Ok(()))
180180
}
181181

182-
fn poll_shutdown(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<()>> {
182+
/// `poll_shutdown` that takes `&self`.
183+
pub(super) fn poll_shutdown_priv(&self, _: &mut Context<'_>) -> Poll<Result<()>> {
183184
// TODO: simulate shutdown
184185
Poll::Ready(Ok(()))
185186
}
186187
}
187188

189+
#[cfg(unix)]
190+
impl AsRawFd for TcpStream {
191+
fn as_raw_fd(&self) -> RawFd {
192+
todo!("TcpStream::as_raw_fd");
193+
}
194+
}
195+
196+
impl AsyncRead for TcpStream {
197+
fn poll_read(
198+
self: Pin<&mut Self>,
199+
cx: &mut Context<'_>,
200+
buf: &mut ReadBuf<'_>,
201+
) -> Poll<Result<()>> {
202+
self.poll_read_priv(cx, buf)
203+
}
204+
}
205+
206+
impl AsyncWrite for TcpStream {
207+
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
208+
self.poll_write_priv(cx, buf)
209+
}
210+
211+
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
212+
self.poll_flush_priv(cx)
213+
}
214+
215+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<()>> {
216+
self.poll_shutdown_priv(cx)
217+
}
218+
}
219+
188220
/// Socket registered in the [`Network`].
189221
struct TcpStreamSocket;
190222

0 commit comments

Comments
 (0)