@@ -6,7 +6,7 @@ use crate::{
66} ;
77use axum:: extract:: { FromRequest , Request } ;
88use axum_core:: response:: IntoResponse ;
9- use bytes:: Bytes ;
9+ use bytes:: { Buf as _ , Bytes } ;
1010use dioxus_fullstack_core:: { HttpError , RequestError } ;
1111use futures:: { Stream , StreamExt } ;
1212#[ cfg( feature = "server" ) ]
@@ -276,18 +276,8 @@ impl<T: DeserializeOwned + Serialize + 'static + Send, E: Encoding> FromResponse
276276{
277277 fn from_response ( res : ClientResponse ) -> impl Future < Output = Result < Self , ServerFnError > > {
278278 SendWrapper :: new ( async move {
279- let client_stream = Box :: pin ( SendWrapper :: new ( res. bytes_stream ( ) . map (
280- |byte| match byte {
281- Ok ( bytes) => match decode_stream_frame :: < T , E > ( bytes) {
282- Some ( res) => Ok ( res) ,
283- None => Err ( StreamingError :: Decoding ) ,
284- } ,
285- Err ( _) => Err ( StreamingError :: Failed ) ,
286- } ,
287- ) ) ) ;
288-
289279 Ok ( Self {
290- stream : client_stream ,
280+ stream : byte_stream_to_client_stream :: < E , _ , _ , _ > ( res . bytes_stream ( ) ) ,
291281 encoding : PhantomData ,
292282 } )
293283 } )
@@ -385,13 +375,7 @@ impl<T: DeserializeOwned + Serialize + 'static + Send, E: Encoding, S> FromReque
385375 let stream = body. into_data_stream ( ) ;
386376
387377 Ok ( Self {
388- stream : Box :: pin ( stream. map ( |byte| match byte {
389- Ok ( bytes) => match decode_stream_frame :: < T , E > ( bytes) {
390- Some ( res) => Ok ( res) ,
391- None => Err ( StreamingError :: Decoding ) ,
392- } ,
393- Err ( _) => Err ( StreamingError :: Failed ) ,
394- } ) ) ,
378+ stream : byte_stream_to_client_stream :: < E , _ , _ , _ > ( stream) ,
395379 encoding : PhantomData ,
396380 } )
397381 }
@@ -504,22 +488,98 @@ pub fn encode_stream_frame<T: Serialize, E: Encoding>(data: T) -> Option<Bytes>
504488 Some ( Bytes :: from ( bytes) . slice ( offset..) )
505489}
506490
491+ fn byte_stream_to_client_stream < E , T , S , E1 > (
492+ stream : S ,
493+ ) -> Pin < Box < dyn Stream < Item = Result < T , StreamingError > > + Send > >
494+ where
495+ S : Stream < Item = Result < Bytes , E1 > > + ' static + Send ,
496+ E : Encoding ,
497+ T : DeserializeOwned + ' static ,
498+ {
499+ Box :: pin ( stream. flat_map ( |bytes| {
500+ enum DecodeIteratorState {
501+ Empty ,
502+ Failed ,
503+ Checked ( Bytes ) ,
504+ UnChecked ( Bytes ) ,
505+ }
506+
507+ let mut state = match bytes {
508+ Ok ( bytes) => DecodeIteratorState :: UnChecked ( bytes) ,
509+ Err ( _) => DecodeIteratorState :: Failed ,
510+ } ;
511+
512+ futures:: stream:: iter ( std:: iter:: from_fn ( move || {
513+ match std:: mem:: replace ( & mut state, DecodeIteratorState :: Empty ) {
514+ DecodeIteratorState :: Empty => None ,
515+ DecodeIteratorState :: Failed => Some ( Err ( StreamingError :: Failed ) ) ,
516+ DecodeIteratorState :: Checked ( mut bytes) => {
517+ let r = decode_stream_frame_multi :: < T , E > ( & mut bytes) ;
518+ if r. is_some ( ) {
519+ state = DecodeIteratorState :: Checked ( bytes)
520+ }
521+ r
522+ }
523+ DecodeIteratorState :: UnChecked ( mut bytes) => {
524+ let r = decode_stream_frame_multi :: < T , E > ( & mut bytes) ;
525+ if r. is_some ( ) {
526+ state = DecodeIteratorState :: Checked ( bytes) ;
527+ r
528+ } else {
529+ Some ( Err ( StreamingError :: Decoding ) )
530+ }
531+ }
532+ }
533+ } ) )
534+ } ) )
535+ }
536+
507537/// Decode a websocket-framed streaming payload produced by [`encode_stream_frame`].
508538///
509539/// This function returns `None` if the frame is invalid or cannot be decoded.
510540///
511541/// It cannot handle masked frames, as those are not produced by our encoding function.
512- pub fn decode_stream_frame < T , E > ( frame : Bytes ) -> Option < T >
542+ pub fn decode_stream_frame < T , E > ( mut frame : Bytes ) -> Option < T >
513543where
514544 E : Encoding ,
515545 T : DeserializeOwned ,
516546{
547+ decode_stream_frame_multi :: < T , E > ( & mut frame) . and_then ( |r| r. ok ( ) )
548+ }
549+
550+ /// Decode one value and advance the bytes pointer
551+ ///
552+ /// If the frame is empty return None.
553+ ///
554+ /// Otherwise, if the initial opcode is not the one expected for binary stream
555+ /// or the frame is not large enough return error StreamingError::Decoding
556+ fn decode_stream_frame_multi < T , E > ( frame : & mut Bytes ) -> Option < Result < T , StreamingError > >
557+ where
558+ E : Encoding ,
559+ T : DeserializeOwned ,
560+ {
561+ let ( offset, payload_len) = match offset_payload_len ( frame) ? {
562+ Ok ( r) => r,
563+ Err ( e) => return Some ( Err ( e) ) ,
564+ } ;
565+
566+ let r = E :: decode ( frame. slice ( offset..offset + payload_len) ) ;
567+ frame. advance ( offset + payload_len) ;
568+ r. map ( |r| Ok ( r) )
569+ }
570+
571+ /// Compute (offset,len) for decoding data
572+ fn offset_payload_len ( frame : & Bytes ) -> Option < Result < ( usize , usize ) , StreamingError > > {
517573 let data = frame. as_ref ( ) ;
518574
519- if data. len ( ) < 2 {
575+ if data. is_empty ( ) {
520576 return None ;
521577 }
522578
579+ if data. len ( ) < 2 {
580+ return Some ( Err ( StreamingError :: Decoding ) ) ;
581+ }
582+
523583 let first = data[ 0 ] ;
524584 let second = data[ 1 ] ;
525585
@@ -528,44 +588,43 @@ where
528588 let opcode = first & 0x0F ;
529589 let rsv = first & 0x70 ;
530590 if !fin || opcode != 0x02 || rsv != 0 {
531- return None ;
591+ return Some ( Err ( StreamingError :: Decoding ) ) ;
532592 }
533593
534594 // Mask bit must be zero for our framing
535595 if second & 0x80 != 0 {
536- return None ;
596+ return Some ( Err ( StreamingError :: Decoding ) ) ;
537597 }
538598
539599 let mut offset = 2usize ;
540600 let mut payload_len = ( second & 0x7F ) as usize ;
541601
542602 if payload_len == 126 {
543603 if data. len ( ) < offset + 2 {
544- return None ;
604+ return Some ( Err ( StreamingError :: Decoding ) ) ;
545605 }
546606
547607 payload_len = u16:: from_be_bytes ( [ data[ offset] , data[ offset + 1 ] ] ) as usize ;
548608 offset += 2 ;
549609 } else if payload_len == 127 {
550610 if data. len ( ) < offset + 8 {
551- return None ;
611+ return Some ( Err ( StreamingError :: Decoding ) ) ;
552612 }
553613
554614 let mut len_bytes = [ 0u8 ; 8 ] ;
555615 len_bytes. copy_from_slice ( & data[ offset..offset + 8 ] ) ;
556616 let len_u64 = u64:: from_be_bytes ( len_bytes) ;
557617
558618 if len_u64 > usize:: MAX as u64 {
559- return None ;
619+ return Some ( Err ( StreamingError :: Decoding ) ) ;
560620 }
561621
562622 payload_len = len_u64 as usize ;
563623 offset += 8 ;
564624 }
565625
566626 if data. len ( ) < offset + payload_len {
567- return None ;
627+ return Some ( Err ( StreamingError :: Decoding ) ) ;
568628 }
569-
570- E :: decode ( frame. slice ( offset..offset + payload_len) )
629+ Some ( Ok ( ( offset, payload_len) ) )
571630}
0 commit comments