1
1
use super :: * ;
2
- use http_body:: Body as _ ;
2
+ use http_body:: Body ;
3
3
use tonic:: codec:: CompressionEncoding ;
4
4
5
- #[ tokio:: test( flavor = "multi_thread" ) ]
6
- async fn client_enabled_server_enabled ( ) {
5
+ util:: parametrized_tests! {
6
+ client_enabled_server_enabled,
7
+ zstd: CompressionEncoding :: Zstd ,
8
+ gzip: CompressionEncoding :: Gzip ,
9
+ }
10
+
11
+ #[ allow( dead_code) ]
12
+ async fn client_enabled_server_enabled ( encoding : CompressionEncoding ) {
7
13
let ( client, server) = tokio:: io:: duplex ( UNCOMPRESSED_MIN_BODY_SIZE * 10 ) ;
8
14
9
- let svc =
10
- test_server:: TestServer :: new ( Svc :: default ( ) ) . accept_compressed ( CompressionEncoding :: Gzip ) ;
15
+ let svc = test_server:: TestServer :: new ( Svc :: default ( ) ) . accept_compressed ( encoding) ;
11
16
12
17
let request_bytes_counter = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
13
18
14
- fn assert_right_encoding < B > ( req : http:: Request < B > ) -> http:: Request < B > {
15
- assert_eq ! ( req. headers( ) . get( "grpc-encoding" ) . unwrap( ) , "gzip" ) ;
16
- req
19
+ #[ derive( Clone ) ]
20
+ pub struct AssertRightEncoding {
21
+ encoding : CompressionEncoding ,
22
+ }
23
+
24
+ #[ allow( dead_code) ]
25
+ impl AssertRightEncoding {
26
+ pub fn new ( encoding : CompressionEncoding ) -> Self {
27
+ Self { encoding }
28
+ }
29
+
30
+ pub fn call < B : Body > ( self , req : http:: Request < B > ) -> http:: Request < B > {
31
+ let expected = match self . encoding {
32
+ CompressionEncoding :: Gzip => "gzip" ,
33
+ CompressionEncoding :: Zstd => "zstd" ,
34
+ _ => panic ! ( "unexpected encoding {:?}" , self . encoding) ,
35
+ } ;
36
+ assert_eq ! ( req. headers( ) . get( "grpc-encoding" ) . unwrap( ) , expected) ;
37
+
38
+ req
39
+ }
17
40
}
18
41
19
42
tokio:: spawn ( {
@@ -22,7 +45,9 @@ async fn client_enabled_server_enabled() {
22
45
Server :: builder ( )
23
46
. layer (
24
47
ServiceBuilder :: new ( )
25
- . map_request ( assert_right_encoding)
48
+ . map_request ( move |req| {
49
+ AssertRightEncoding :: new ( encoding) . clone ( ) . call ( req)
50
+ } )
26
51
. layer ( measure_request_body_size_layer (
27
52
request_bytes_counter. clone ( ) ,
28
53
) )
@@ -35,8 +60,8 @@ async fn client_enabled_server_enabled() {
35
60
}
36
61
} ) ;
37
62
38
- let mut client = test_client :: TestClient :: new ( mock_io_channel ( client ) . await )
39
- . send_compressed ( CompressionEncoding :: Gzip ) ;
63
+ let mut client =
64
+ test_client :: TestClient :: new ( mock_io_channel ( client ) . await ) . send_compressed ( encoding ) ;
40
65
41
66
let data = [ 0_u8 ; UNCOMPRESSED_MIN_BODY_SIZE ] . to_vec ( ) ;
42
67
let stream = tokio_stream:: iter ( vec ! [ SomeData { data: data. clone( ) } , SomeData { data } ] ) ;
@@ -48,12 +73,17 @@ async fn client_enabled_server_enabled() {
48
73
assert ! ( bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE ) ;
49
74
}
50
75
51
- #[ tokio:: test( flavor = "multi_thread" ) ]
52
- async fn client_disabled_server_enabled ( ) {
76
+ util:: parametrized_tests! {
77
+ client_disabled_server_enabled,
78
+ zstd: CompressionEncoding :: Zstd ,
79
+ gzip: CompressionEncoding :: Gzip ,
80
+ }
81
+
82
+ #[ allow( dead_code) ]
83
+ async fn client_disabled_server_enabled ( encoding : CompressionEncoding ) {
53
84
let ( client, server) = tokio:: io:: duplex ( UNCOMPRESSED_MIN_BODY_SIZE * 10 ) ;
54
85
55
- let svc =
56
- test_server:: TestServer :: new ( Svc :: default ( ) ) . accept_compressed ( CompressionEncoding :: Gzip ) ;
86
+ let svc = test_server:: TestServer :: new ( Svc :: default ( ) ) . accept_compressed ( encoding) ;
57
87
58
88
let request_bytes_counter = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
59
89
@@ -93,8 +123,14 @@ async fn client_disabled_server_enabled() {
93
123
assert ! ( bytes_sent > UNCOMPRESSED_MIN_BODY_SIZE ) ;
94
124
}
95
125
96
- #[ tokio:: test( flavor = "multi_thread" ) ]
97
- async fn client_enabled_server_disabled ( ) {
126
+ util:: parametrized_tests! {
127
+ client_enabled_server_disabled,
128
+ zstd: CompressionEncoding :: Zstd ,
129
+ gzip: CompressionEncoding :: Gzip ,
130
+ }
131
+
132
+ #[ allow( dead_code) ]
133
+ async fn client_enabled_server_disabled ( encoding : CompressionEncoding ) {
98
134
let ( client, server) = tokio:: io:: duplex ( UNCOMPRESSED_MIN_BODY_SIZE * 10 ) ;
99
135
100
136
let svc = test_server:: TestServer :: new ( Svc :: default ( ) ) ;
@@ -107,8 +143,8 @@ async fn client_enabled_server_disabled() {
107
143
. unwrap ( ) ;
108
144
} ) ;
109
145
110
- let mut client = test_client :: TestClient :: new ( mock_io_channel ( client ) . await )
111
- . send_compressed ( CompressionEncoding :: Gzip ) ;
146
+ let mut client =
147
+ test_client :: TestClient :: new ( mock_io_channel ( client ) . await ) . send_compressed ( encoding ) ;
112
148
113
149
let data = [ 0_u8 ; UNCOMPRESSED_MIN_BODY_SIZE ] . to_vec ( ) ;
114
150
let stream = tokio_stream:: iter ( vec ! [ SomeData { data: data. clone( ) } , SomeData { data } ] ) ;
@@ -117,18 +153,31 @@ async fn client_enabled_server_disabled() {
117
153
let status = client. compress_input_client_stream ( req) . await . unwrap_err ( ) ;
118
154
119
155
assert_eq ! ( status. code( ) , tonic:: Code :: Unimplemented ) ;
156
+ let expected = match encoding {
157
+ CompressionEncoding :: Gzip => "gzip" ,
158
+ CompressionEncoding :: Zstd => "zstd" ,
159
+ _ => panic ! ( "unexpected encoding {:?}" , encoding) ,
160
+ } ;
120
161
assert_eq ! (
121
162
status. message( ) ,
122
- "Content is compressed with `gzip` which isn't supported"
163
+ format!(
164
+ "Content is compressed with `{}` which isn't supported" ,
165
+ expected
166
+ )
123
167
) ;
124
168
}
125
169
126
- #[ tokio:: test( flavor = "multi_thread" ) ]
127
- async fn compressing_response_from_client_stream ( ) {
170
+ util:: parametrized_tests! {
171
+ compressing_response_from_client_stream,
172
+ zstd: CompressionEncoding :: Zstd ,
173
+ gzip: CompressionEncoding :: Gzip ,
174
+ }
175
+
176
+ #[ allow( dead_code) ]
177
+ async fn compressing_response_from_client_stream ( encoding : CompressionEncoding ) {
128
178
let ( client, server) = tokio:: io:: duplex ( UNCOMPRESSED_MIN_BODY_SIZE * 10 ) ;
129
179
130
- let svc =
131
- test_server:: TestServer :: new ( Svc :: default ( ) ) . send_compressed ( CompressionEncoding :: Gzip ) ;
180
+ let svc = test_server:: TestServer :: new ( Svc :: default ( ) ) . send_compressed ( encoding) ;
132
181
133
182
let response_bytes_counter = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
134
183
@@ -153,13 +202,18 @@ async fn compressing_response_from_client_stream() {
153
202
}
154
203
} ) ;
155
204
156
- let mut client = test_client :: TestClient :: new ( mock_io_channel ( client ) . await )
157
- . accept_compressed ( CompressionEncoding :: Gzip ) ;
205
+ let mut client =
206
+ test_client :: TestClient :: new ( mock_io_channel ( client ) . await ) . accept_compressed ( encoding ) ;
158
207
159
208
let req = Request :: new ( Box :: pin ( tokio_stream:: empty ( ) ) ) ;
160
209
161
210
let res = client. compress_output_client_stream ( req) . await . unwrap ( ) ;
162
- assert_eq ! ( res. metadata( ) . get( "grpc-encoding" ) . unwrap( ) , "gzip" ) ;
211
+ let expected = match encoding {
212
+ CompressionEncoding :: Gzip => "gzip" ,
213
+ CompressionEncoding :: Zstd => "zstd" ,
214
+ _ => panic ! ( "unexpected encoding {:?}" , encoding) ,
215
+ } ;
216
+ assert_eq ! ( res. metadata( ) . get( "grpc-encoding" ) . unwrap( ) , expected) ;
163
217
let bytes_sent = response_bytes_counter. load ( SeqCst ) ;
164
218
assert ! ( bytes_sent < UNCOMPRESSED_MIN_BODY_SIZE ) ;
165
219
}
0 commit comments