1
1
use super :: super :: socket:: Socket as InnerSocket ;
2
- use crate :: engineio :: transport:: Transport ;
2
+ use crate :: transport:: Transport ;
3
3
4
- use super :: super :: transports:: { PollingTransport , WebsocketSecureTransport , WebsocketTransport } ;
5
- use crate :: engineio:: packet:: { HandshakePacket , Packet , PacketId , Payload } ;
6
4
use crate :: error:: { Error , Result } ;
5
+ use crate :: header:: HeaderMap ;
6
+ use crate :: packet:: { HandshakePacket , Packet , PacketId , Payload } ;
7
+ use crate :: transports:: { PollingTransport , WebsocketSecureTransport , WebsocketTransport } ;
7
8
use bytes:: Bytes ;
8
9
use native_tls:: TlsConnector ;
9
- use reqwest:: header:: HeaderMap ;
10
10
use std:: convert:: TryFrom ;
11
11
use std:: convert:: TryInto ;
12
12
use std:: fmt:: Debug ;
13
13
use url:: Url ;
14
- use websocket:: header:: Headers ;
15
14
16
15
#[ derive( Clone , Debug ) ]
17
16
pub struct Socket {
@@ -61,8 +60,11 @@ impl SocketBuilder {
61
60
}
62
61
63
62
// Start with polling transport
64
- let transport =
65
- PollingTransport :: new ( url. clone ( ) , self . tls_config . clone ( ) , self . headers . clone ( ) ) ;
63
+ let transport = PollingTransport :: new (
64
+ url. clone ( ) ,
65
+ self . tls_config . clone ( ) ,
66
+ self . headers . clone ( ) . map ( |v| v. try_into ( ) . unwrap ( ) ) ,
67
+ ) ;
66
68
67
69
let handshake: HandshakePacket = Packet :: try_from ( transport. poll ( ) ?) ?. try_into ( ) ?;
68
70
@@ -94,7 +96,11 @@ impl SocketBuilder {
94
96
self . handshake ( ) ?;
95
97
96
98
// Make a polling transport with new sid
97
- let transport = PollingTransport :: new ( self . url , self . tls_config , self . headers ) ;
99
+ let transport = PollingTransport :: new (
100
+ self . url ,
101
+ self . tls_config ,
102
+ self . headers . map ( |v| v. try_into ( ) . unwrap ( ) ) ,
103
+ ) ;
98
104
99
105
// SAFETY: handshake function called previously.
100
106
Ok ( Socket {
@@ -111,7 +117,10 @@ impl SocketBuilder {
111
117
112
118
if self . websocket_upgrade ( ) ? {
113
119
if url. scheme ( ) == "http" {
114
- let transport = WebsocketTransport :: new ( url, self . get_ws_headers ( ) ?) ;
120
+ let transport = WebsocketTransport :: new (
121
+ url,
122
+ self . headers . map ( |headers| headers. try_into ( ) . unwrap ( ) ) ,
123
+ ) ;
115
124
transport. upgrade ( ) ?;
116
125
// SAFETY: handshake function called previously.
117
126
Ok ( Socket {
@@ -137,7 +146,7 @@ impl SocketBuilder {
137
146
let transport = WebsocketSecureTransport :: new (
138
147
url,
139
148
self . tls_config . clone ( ) ,
140
- self . get_ws_headers ( ) ? ,
149
+ self . headers . map ( |v| v . try_into ( ) . unwrap ( ) ) ,
141
150
) ;
142
151
transport. upgrade ( ) ?;
143
152
// SAFETY: handshake function called previously.
@@ -176,20 +185,6 @@ impl SocketBuilder {
176
185
. iter ( )
177
186
. any ( |upgrade| upgrade. to_lowercase ( ) == * "websocket" ) )
178
187
}
179
-
180
- /// Converts Reqwest headers to Websocket headers
181
- fn get_ws_headers ( & self ) -> Result < Option < Headers > > {
182
- let mut headers = Headers :: new ( ) ;
183
- if self . headers . is_some ( ) {
184
- let opening_headers = self . headers . clone ( ) ;
185
- for ( key, val) in opening_headers. unwrap ( ) {
186
- headers. append_raw ( key. unwrap ( ) . to_string ( ) , val. as_bytes ( ) . to_owned ( ) ) ;
187
- }
188
- Ok ( Some ( headers) )
189
- } else {
190
- Ok ( None )
191
- }
192
- }
193
188
}
194
189
195
190
impl Socket {
@@ -298,8 +293,8 @@ impl Socket {
298
293
Ok ( Some ( payload) )
299
294
}
300
295
301
- // Check if the underlying transport client is connected.
302
- pub ( crate ) fn is_connected ( & self ) -> Result < bool > {
296
+ /// Check if the underlying transport client is connected.
297
+ pub fn is_connected ( & self ) -> Result < bool > {
303
298
self . socket . is_connected ( )
304
299
}
305
300
@@ -314,7 +309,7 @@ impl Socket {
314
309
#[ derive( Clone ) ]
315
310
pub struct Iter < ' a > {
316
311
socket : & ' a Socket ,
317
- iter : Option < crate :: engineio :: packet:: IntoIter > ,
312
+ iter : Option < crate :: packet:: IntoIter > ,
318
313
}
319
314
320
315
impl < ' a > Iterator for Iter < ' a > {
@@ -343,13 +338,13 @@ impl<'a> Iterator for Iter<'a> {
343
338
#[ cfg( test) ]
344
339
mod test {
345
340
346
- use crate :: engineio :: packet:: PacketId ;
341
+ use crate :: packet:: PacketId ;
347
342
348
343
use super :: * ;
349
344
350
345
#[ test]
351
346
fn test_illegal_actions ( ) -> Result < ( ) > {
352
- let url = crate :: engineio :: test:: engine_io_server ( ) ?;
347
+ let url = crate :: test:: engine_io_server ( ) ?;
353
348
let mut sut = SocketBuilder :: new ( url. clone ( ) ) . build ( ) ?;
354
349
355
350
assert ! ( sut
@@ -379,7 +374,7 @@ mod test {
379
374
}
380
375
use reqwest:: header:: HOST ;
381
376
382
- use crate :: engineio :: packet:: Packet ;
377
+ use crate :: packet:: Packet ;
383
378
384
379
fn test_connection ( socket : Socket ) -> Result < ( ) > {
385
380
let mut socket = socket;
@@ -428,14 +423,14 @@ mod test {
428
423
429
424
#[ test]
430
425
fn test_connection_dynamic ( ) -> Result < ( ) > {
431
- let url = crate :: engineio :: test:: engine_io_server ( ) ?;
426
+ let url = crate :: test:: engine_io_server ( ) ?;
432
427
let socket = SocketBuilder :: new ( url) . build ( ) ?;
433
428
test_connection ( socket)
434
429
}
435
430
436
431
#[ test]
437
432
fn test_connection_dynamic_secure ( ) -> Result < ( ) > {
438
- let url = crate :: engineio :: test:: engine_io_server_secure ( ) ?;
433
+ let url = crate :: test:: engine_io_server_secure ( ) ?;
439
434
let mut builder = SocketBuilder :: new ( url) ;
440
435
builder = builder. tls_config ( crate :: test:: tls_connector ( ) ?) ;
441
436
let socket = builder. build ( ) ?;
@@ -444,7 +439,7 @@ mod test {
444
439
445
440
#[ test]
446
441
fn test_connection_polling ( ) -> Result < ( ) > {
447
- let url = crate :: engineio :: test:: engine_io_server ( ) ?;
442
+ let url = crate :: test:: engine_io_server ( ) ?;
448
443
let socket = SocketBuilder :: new ( url) . build_polling ( ) ?;
449
444
test_connection ( socket)
450
445
}
@@ -453,10 +448,10 @@ mod test {
453
448
fn test_connection_wss ( ) -> Result < ( ) > {
454
449
let host =
455
450
std:: env:: var ( "ENGINE_IO_SECURE_HOST" ) . unwrap_or_else ( |_| "localhost" . to_owned ( ) ) ;
456
- let url = crate :: engineio :: test:: engine_io_server_secure ( ) ?;
451
+ let url = crate :: test:: engine_io_server_secure ( ) ?;
457
452
458
453
let mut headers = HeaderMap :: new ( ) ;
459
- headers. insert ( HOST , host. parse ( ) . unwrap ( ) ) ;
454
+ headers. insert ( HOST , host) ;
460
455
let mut builder = SocketBuilder :: new ( url) ;
461
456
462
457
builder = builder. tls_config ( crate :: test:: tls_connector ( ) ?) ;
@@ -468,7 +463,7 @@ mod test {
468
463
469
464
#[ test]
470
465
fn test_connection_ws ( ) -> Result < ( ) > {
471
- let url = crate :: engineio :: test:: engine_io_server ( ) ?;
466
+ let url = crate :: test:: engine_io_server ( ) ?;
472
467
473
468
let builder = SocketBuilder :: new ( url) ;
474
469
let socket = builder. build_websocket ( ) ?;
@@ -478,7 +473,7 @@ mod test {
478
473
479
474
#[ test]
480
475
fn test_open_invariants ( ) -> Result < ( ) > {
481
- let url = crate :: engineio :: test:: engine_io_server ( ) ?;
476
+ let url = crate :: test:: engine_io_server ( ) ?;
482
477
let illegal_url = "this is illegal" ;
483
478
484
479
assert ! ( Url :: parse( & illegal_url) . is_err( ) ) ;
@@ -498,7 +493,7 @@ mod test {
498
493
let mut headers = HeaderMap :: new ( ) ;
499
494
let host =
500
495
std:: env:: var ( "ENGINE_IO_SECURE_HOST" ) . unwrap_or_else ( |_| "localhost" . to_owned ( ) ) ;
501
- headers. insert ( HOST , host. parse ( ) . unwrap ( ) ) ;
496
+ headers. insert ( HOST , host) ;
502
497
503
498
let _ = SocketBuilder :: new ( url. clone ( ) )
504
499
. tls_config (
0 commit comments