Skip to content

Commit e8cb48f

Browse files
Quentin PerezmartinabeledaQuentin PerezLucioFranco
authored
Add zstd compression support (#1532)
* Implement zstd compression * Parametrize compression tests * add tests for accepting multiple encodings * add some missing feature cfg for zstd * make as_str only crate public * make into_accept_encoding_header_value handle all combinations * make decompress implementation consistent * use zstd::stream::read::Encoder * use default compression level for zstd * fix rebase * fix CI issue --------- Co-authored-by: martinabeleda <[email protected]> Co-authored-by: Quentin Perez <[email protected]> Co-authored-by: Lucio Franco <[email protected]>
1 parent 53267a3 commit e8cb48f

File tree

11 files changed

+560
-125
lines changed

11 files changed

+560
-125
lines changed

tests/compression/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@ bytes = "1"
1111
http = "0.2"
1212
http-body = "0.4"
1313
hyper = "0.14.3"
14+
paste = "1.0.12"
1415
pin-project = "1.0"
1516
prost = "0.12"
1617
tokio = {version = "1.0", features = ["macros", "rt-multi-thread", "net"]}
1718
tokio-stream = "0.1"
18-
tonic = {path = "../../tonic", features = ["gzip"]}
19+
tonic = {path = "../../tonic", features = ["gzip", "zstd"]}
1920
tower = {version = "0.4", features = []}
2021
tower-http = {version = "0.4", features = ["map-response-body", "map-request-body"]}
2122

tests/compression/src/bidirectional_stream.rs

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,45 @@
11
use super::*;
2+
use http_body::Body;
23
use tonic::codec::CompressionEncoding;
34

4-
#[tokio::test(flavor = "multi_thread")]
5-
async fn client_enabled_server_enabled() {
5+
util::parametrized_tests! {
6+
client_enabled_server_enabled,
7+
zstd: CompressionEncoding::Zstd,
8+
gzip: CompressionEncoding::Gzip,
9+
}
10+
11+
#[allow(dead_code)]
12+
async fn client_enabled_server_enabled(encoding: CompressionEncoding) {
613
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
714

815
let svc = test_server::TestServer::new(Svc::default())
9-
.accept_compressed(CompressionEncoding::Gzip)
10-
.send_compressed(CompressionEncoding::Gzip);
16+
.accept_compressed(encoding)
17+
.send_compressed(encoding);
1118

1219
let request_bytes_counter = Arc::new(AtomicUsize::new(0));
1320
let response_bytes_counter = Arc::new(AtomicUsize::new(0));
1421

15-
fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> {
16-
assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip");
17-
req
22+
#[derive(Clone)]
23+
pub struct AssertRightEncoding {
24+
encoding: CompressionEncoding,
25+
}
26+
27+
#[allow(dead_code)]
28+
impl AssertRightEncoding {
29+
pub fn new(encoding: CompressionEncoding) -> Self {
30+
Self { encoding }
31+
}
32+
33+
pub fn call<B: Body>(self, req: http::Request<B>) -> http::Request<B> {
34+
let expected = match self.encoding {
35+
CompressionEncoding::Gzip => "gzip",
36+
CompressionEncoding::Zstd => "zstd",
37+
_ => panic!("unexpected encoding {:?}", self.encoding),
38+
};
39+
assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected);
40+
41+
req
42+
}
1843
}
1944

2045
tokio::spawn({
@@ -24,7 +49,9 @@ async fn client_enabled_server_enabled() {
2449
Server::builder()
2550
.layer(
2651
ServiceBuilder::new()
27-
.map_request(assert_right_encoding)
52+
.map_request(move |req| {
53+
AssertRightEncoding::new(encoding).clone().call(req)
54+
})
2855
.layer(measure_request_body_size_layer(
2956
request_bytes_counter.clone(),
3057
))
@@ -44,8 +71,8 @@ async fn client_enabled_server_enabled() {
4471
});
4572

4673
let mut client = test_client::TestClient::new(mock_io_channel(client).await)
47-
.send_compressed(CompressionEncoding::Gzip)
48-
.accept_compressed(CompressionEncoding::Gzip);
74+
.send_compressed(encoding)
75+
.accept_compressed(encoding);
4976

5077
let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
5178
let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
@@ -56,7 +83,12 @@ async fn client_enabled_server_enabled() {
5683
.await
5784
.unwrap();
5885

59-
assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip");
86+
let expected = match encoding {
87+
CompressionEncoding::Gzip => "gzip",
88+
CompressionEncoding::Zstd => "zstd",
89+
_ => panic!("unexpected encoding {:?}", encoding),
90+
};
91+
assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected);
6092

6193
let mut stream: Streaming<SomeData> = res.into_inner();
6294

tests/compression/src/client_stream.rs

Lines changed: 81 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,42 @@
11
use super::*;
2-
use http_body::Body as _;
2+
use http_body::Body;
33
use tonic::codec::CompressionEncoding;
44

5-
#[tokio::test(flavor = "multi_thread")]
6-
async fn client_enabled_server_enabled() {
5+
util::parametrized_tests! {
6+
client_enabled_server_enabled,
7+
zstd: CompressionEncoding::Zstd,
8+
gzip: CompressionEncoding::Gzip,
9+
}
10+
11+
#[allow(dead_code)]
12+
async fn client_enabled_server_enabled(encoding: CompressionEncoding) {
713
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
814

9-
let svc =
10-
test_server::TestServer::new(Svc::default()).accept_compressed(CompressionEncoding::Gzip);
15+
let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding);
1116

1217
let request_bytes_counter = Arc::new(AtomicUsize::new(0));
1318

14-
fn assert_right_encoding<B>(req: http::Request<B>) -> http::Request<B> {
15-
assert_eq!(req.headers().get("grpc-encoding").unwrap(), "gzip");
16-
req
19+
#[derive(Clone)]
20+
pub struct AssertRightEncoding {
21+
encoding: CompressionEncoding,
22+
}
23+
24+
#[allow(dead_code)]
25+
impl AssertRightEncoding {
26+
pub fn new(encoding: CompressionEncoding) -> Self {
27+
Self { encoding }
28+
}
29+
30+
pub fn call<B: Body>(self, req: http::Request<B>) -> http::Request<B> {
31+
let expected = match self.encoding {
32+
CompressionEncoding::Gzip => "gzip",
33+
CompressionEncoding::Zstd => "zstd",
34+
_ => panic!("unexpected encoding {:?}", self.encoding),
35+
};
36+
assert_eq!(req.headers().get("grpc-encoding").unwrap(), expected);
37+
38+
req
39+
}
1740
}
1841

1942
tokio::spawn({
@@ -22,7 +45,9 @@ async fn client_enabled_server_enabled() {
2245
Server::builder()
2346
.layer(
2447
ServiceBuilder::new()
25-
.map_request(assert_right_encoding)
48+
.map_request(move |req| {
49+
AssertRightEncoding::new(encoding).clone().call(req)
50+
})
2651
.layer(measure_request_body_size_layer(
2752
request_bytes_counter.clone(),
2853
))
@@ -35,8 +60,8 @@ async fn client_enabled_server_enabled() {
3560
}
3661
});
3762

38-
let mut client = test_client::TestClient::new(mock_io_channel(client).await)
39-
.send_compressed(CompressionEncoding::Gzip);
63+
let mut client =
64+
test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);
4065

4166
let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
4267
let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
@@ -48,12 +73,17 @@ async fn client_enabled_server_enabled() {
4873
assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
4974
}
5075

51-
#[tokio::test(flavor = "multi_thread")]
52-
async fn client_disabled_server_enabled() {
76+
util::parametrized_tests! {
77+
client_disabled_server_enabled,
78+
zstd: CompressionEncoding::Zstd,
79+
gzip: CompressionEncoding::Gzip,
80+
}
81+
82+
#[allow(dead_code)]
83+
async fn client_disabled_server_enabled(encoding: CompressionEncoding) {
5384
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
5485

55-
let svc =
56-
test_server::TestServer::new(Svc::default()).accept_compressed(CompressionEncoding::Gzip);
86+
let svc = test_server::TestServer::new(Svc::default()).accept_compressed(encoding);
5787

5888
let request_bytes_counter = Arc::new(AtomicUsize::new(0));
5989

@@ -93,8 +123,14 @@ async fn client_disabled_server_enabled() {
93123
assert!(bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE);
94124
}
95125

96-
#[tokio::test(flavor = "multi_thread")]
97-
async fn client_enabled_server_disabled() {
126+
util::parametrized_tests! {
127+
client_enabled_server_disabled,
128+
zstd: CompressionEncoding::Zstd,
129+
gzip: CompressionEncoding::Gzip,
130+
}
131+
132+
#[allow(dead_code)]
133+
async fn client_enabled_server_disabled(encoding: CompressionEncoding) {
98134
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
99135

100136
let svc = test_server::TestServer::new(Svc::default());
@@ -107,8 +143,8 @@ async fn client_enabled_server_disabled() {
107143
.unwrap();
108144
});
109145

110-
let mut client = test_client::TestClient::new(mock_io_channel(client).await)
111-
.send_compressed(CompressionEncoding::Gzip);
146+
let mut client =
147+
test_client::TestClient::new(mock_io_channel(client).await).send_compressed(encoding);
112148

113149
let data = [0_u8; UNCOMPRESSED_MIN_BODY_SIZE].to_vec();
114150
let stream = tokio_stream::iter(vec![SomeData { data: data.clone() }, SomeData { data }]);
@@ -117,18 +153,31 @@ async fn client_enabled_server_disabled() {
117153
let status = client.compress_input_client_stream(req).await.unwrap_err();
118154

119155
assert_eq!(status.code(), tonic::Code::Unimplemented);
156+
let expected = match encoding {
157+
CompressionEncoding::Gzip => "gzip",
158+
CompressionEncoding::Zstd => "zstd",
159+
_ => panic!("unexpected encoding {:?}", encoding),
160+
};
120161
assert_eq!(
121162
status.message(),
122-
"Content is compressed with `gzip` which isn't supported"
163+
format!(
164+
"Content is compressed with `{}` which isn't supported",
165+
expected
166+
)
123167
);
124168
}
125169

126-
#[tokio::test(flavor = "multi_thread")]
127-
async fn compressing_response_from_client_stream() {
170+
util::parametrized_tests! {
171+
compressing_response_from_client_stream,
172+
zstd: CompressionEncoding::Zstd,
173+
gzip: CompressionEncoding::Gzip,
174+
}
175+
176+
#[allow(dead_code)]
177+
async fn compressing_response_from_client_stream(encoding: CompressionEncoding) {
128178
let (client, server) = tokio::io::duplex(UNCOMPRESSED_MIN_BODY_SIZE * 10);
129179

130-
let svc =
131-
test_server::TestServer::new(Svc::default()).send_compressed(CompressionEncoding::Gzip);
180+
let svc = test_server::TestServer::new(Svc::default()).send_compressed(encoding);
132181

133182
let response_bytes_counter = Arc::new(AtomicUsize::new(0));
134183

@@ -153,13 +202,18 @@ async fn compressing_response_from_client_stream() {
153202
}
154203
});
155204

156-
let mut client = test_client::TestClient::new(mock_io_channel(client).await)
157-
.accept_compressed(CompressionEncoding::Gzip);
205+
let mut client =
206+
test_client::TestClient::new(mock_io_channel(client).await).accept_compressed(encoding);
158207

159208
let req = Request::new(Box::pin(tokio_stream::empty()));
160209

161210
let res = client.compress_output_client_stream(req).await.unwrap();
162-
assert_eq!(res.metadata().get("grpc-encoding").unwrap(), "gzip");
211+
let expected = match encoding {
212+
CompressionEncoding::Gzip => "gzip",
213+
CompressionEncoding::Zstd => "zstd",
214+
_ => panic!("unexpected encoding {:?}", encoding),
215+
};
216+
assert_eq!(res.metadata().get("grpc-encoding").unwrap(), expected);
163217
let bytes_sent = response_bytes_counter.load(SeqCst);
164218
assert!(bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE);
165219
}

0 commit comments

Comments
 (0)