@@ -420,7 +420,7 @@ pub mod subscriptions {
420
420
} ,
421
421
} ;
422
422
423
- use futures:: { channel:: mpsc, Future , StreamExt as _} ;
423
+ use futures:: { channel:: mpsc, Future , StreamExt as _, TryStreamExt as _ } ;
424
424
use juniper:: { http:: GraphQLRequest , InputValue , ScalarValue , SubscriptionCoordinator as _} ;
425
425
use juniper_subscriptions:: Coordinator ;
426
426
use serde:: { Deserialize , Serialize } ;
@@ -458,71 +458,71 @@ pub mod subscriptions {
458
458
let context = Arc :: new ( context) ;
459
459
let running = Arc :: new ( AtomicBool :: new ( false ) ) ;
460
460
let got_close_signal = Arc :: new ( AtomicBool :: new ( false ) ) ;
461
+ let got_close_signal2 = got_close_signal. clone ( ) ;
461
462
462
- sink_rx. fold ( Ok ( ( ) ) , move |_, msg| {
463
- let coordinator = coordinator. clone ( ) ;
464
- let context = context. clone ( ) ;
465
- let running = running. clone ( ) ;
466
- let got_close_signal = got_close_signal. clone ( ) ;
467
- let ws_tx = ws_tx. clone ( ) ;
468
-
469
- async move {
470
- let msg = match msg {
471
- Ok ( m) => m,
472
- Err ( e) => {
473
- got_close_signal. store ( true , Ordering :: Relaxed ) ;
474
- return Err ( failure:: format_err!( "Websocket error: {}" , e) ) ;
463
+ sink_rx
464
+ . map_err ( move |e| {
465
+ got_close_signal2. store ( true , Ordering :: Relaxed ) ;
466
+ failure:: format_err!( "Websocket error: {}" , e)
467
+ } )
468
+ . try_fold ( ( ) , move |_, msg| {
469
+ let coordinator = coordinator. clone ( ) ;
470
+ let context = context. clone ( ) ;
471
+ let running = running. clone ( ) ;
472
+ let got_close_signal = got_close_signal. clone ( ) ;
473
+ let ws_tx = ws_tx. clone ( ) ;
474
+
475
+ async move {
476
+ if msg. is_close ( ) {
477
+ return Ok ( ( ) ) ;
475
478
}
476
- } ;
477
-
478
- if msg. is_close ( ) {
479
- return Ok ( ( ) ) ;
480
- }
481
479
482
- let msg = msg
483
- . to_str ( )
484
- . map_err ( |_| failure:: format_err!( "Non-text messages are not accepted" ) ) ?;
485
- let request: WsPayload < S > = serde_json:: from_str ( msg)
486
- . map_err ( |e| failure:: format_err!( "Invalid WsPayload: {}" , e) ) ?;
487
-
488
- match request. type_name . as_str ( ) {
489
- "connection_init" => { }
490
- "start" => {
491
- {
492
- let closed = got_close_signal. load ( Ordering :: Relaxed ) ;
493
- if closed {
494
- return Ok ( ( ) ) ;
480
+ let msg = msg
481
+ . to_str ( )
482
+ . map_err ( |_| failure:: format_err!( "Non-text messages are not accepted" ) ) ?;
483
+ let request: WsPayload < S > = serde_json:: from_str ( msg)
484
+ . map_err ( |e| failure:: format_err!( "Invalid WsPayload: {}" , e) ) ?;
485
+
486
+ match request. type_name . as_str ( ) {
487
+ "connection_init" => { }
488
+ "start" => {
489
+ {
490
+ let closed = got_close_signal. load ( Ordering :: Relaxed ) ;
491
+ if closed {
492
+ return Ok ( ( ) ) ;
493
+ }
494
+
495
+ if running. load ( Ordering :: Relaxed ) {
496
+ return Ok ( ( ) ) ;
497
+ }
498
+ running. store ( true , Ordering :: Relaxed ) ;
495
499
}
496
500
497
- if running. load ( Ordering :: Relaxed ) {
498
- return Ok ( ( ) ) ;
499
- }
500
- running. store ( true , Ordering :: Relaxed ) ;
501
- }
502
-
503
- let ws_tx = ws_tx. clone ( ) ;
501
+ let ws_tx = ws_tx. clone ( ) ;
504
502
505
- if let Some ( ref payload) = request. payload {
506
- if payload. query . is_none ( ) {
507
- return Err ( failure:: format_err!( "Query not found" ) ) ;
503
+ if let Some ( ref payload) = request. payload {
504
+ if payload. query . is_none ( ) {
505
+ return Err ( failure:: format_err!( "Query not found" ) ) ;
506
+ }
507
+ } else {
508
+ return Err ( failure:: format_err!( "Payload not found" ) ) ;
508
509
}
509
- } else {
510
- return Err ( failure:: format_err!( "Payload not found" ) ) ;
511
- }
512
510
513
- tokio:: task:: spawn ( async move {
514
- let payload = request. payload . unwrap ( ) ;
511
+ tokio:: task:: spawn ( async move {
512
+ let payload = request. payload . unwrap ( ) ;
515
513
516
- let request_id = request. id . unwrap_or ( "1" . to_owned ( ) ) ;
514
+ let request_id = request. id . unwrap_or ( "1" . to_owned ( ) ) ;
517
515
518
- let graphql_request = GraphQLRequest :: < S > :: new (
519
- payload. query . unwrap ( ) ,
520
- None ,
521
- payload. variables ,
522
- ) ;
516
+ let graphql_request = GraphQLRequest :: < S > :: new (
517
+ payload. query . unwrap ( ) ,
518
+ None ,
519
+ payload. variables ,
520
+ ) ;
523
521
524
- let values_stream =
525
- match coordinator. subscribe ( & graphql_request, & context) . await {
522
+ let values_stream = match coordinator
523
+ . subscribe ( & graphql_request, & context)
524
+ . await
525
+ {
526
526
Ok ( s) => s,
527
527
Err ( err) => {
528
528
let _ =
@@ -546,48 +546,51 @@ pub mod subscriptions {
546
546
}
547
547
} ;
548
548
549
- values_stream
550
- . take_while ( move |response| {
551
- let request_id = request_id. clone ( ) ;
552
- let closed = got_close_signal. load ( Ordering :: Relaxed ) ;
553
- if !closed {
554
- let mut response_text = serde_json:: to_string ( & response)
549
+ values_stream
550
+ . take_while ( move |response| {
551
+ let request_id = request_id. clone ( ) ;
552
+ let closed = got_close_signal. load ( Ordering :: Relaxed ) ;
553
+ if !closed {
554
+ let mut response_text = serde_json:: to_string (
555
+ & response,
556
+ )
555
557
. unwrap_or ( "Error deserializing response" . to_owned ( ) ) ;
556
558
557
- response_text = format ! (
558
- r#"{{"type":"data","id":"{}","payload":{} }}"# ,
559
- request_id, response_text
560
- ) ;
559
+ response_text = format ! (
560
+ r#"{{"type":"data","id":"{}","payload":{} }}"# ,
561
+ request_id, response_text
562
+ ) ;
563
+
564
+ let _ = ws_tx. unbounded_send ( Some ( Ok ( Message :: text (
565
+ response_text,
566
+ ) ) ) ) ;
567
+ }
568
+
569
+ async move { !closed }
570
+ } )
571
+ . for_each ( |_| async { } )
572
+ . await ;
573
+ } ) ;
574
+ }
575
+ "stop" => {
576
+ got_close_signal. store ( true , Ordering :: Relaxed ) ;
561
577
562
- let _ = ws_tx
563
- . unbounded_send ( Some ( Ok ( Message :: text ( response_text) ) ) ) ;
564
- }
578
+ let request_id = request. id . unwrap_or ( "1" . to_owned ( ) ) ;
579
+ let close_message = format ! (
580
+ r#"{{"type":"complete","id":"{}","payload":null}}"# ,
581
+ request_id
582
+ ) ;
583
+ let _ = ws_tx. unbounded_send ( Some ( Ok ( Message :: text ( close_message) ) ) ) ;
565
584
566
- async move { !closed }
567
- } )
568
- . for_each ( |_| async { } )
569
- . await ;
570
- } ) ;
571
- }
572
- "stop" => {
573
- got_close_signal. store ( true , Ordering :: Relaxed ) ;
574
-
575
- let request_id = request. id . unwrap_or ( "1" . to_owned ( ) ) ;
576
- let close_message = format ! (
577
- r#"{{"type":"complete","id":"{}","payload":null}}"# ,
578
- request_id
579
- ) ;
580
- let _ = ws_tx. unbounded_send ( Some ( Ok ( Message :: text ( close_message) ) ) ) ;
581
-
582
- // close channel
583
- let _ = ws_tx. unbounded_send ( None ) ;
585
+ // close channel
586
+ let _ = ws_tx. unbounded_send ( None ) ;
587
+ }
588
+ _ => { }
584
589
}
585
- _ => { }
586
- }
587
590
588
- Ok ( ( ) )
589
- }
590
- } )
591
+ Ok ( ( ) )
592
+ }
593
+ } )
591
594
}
592
595
593
596
#[ derive( Deserialize ) ]
0 commit comments