@@ -23,35 +23,91 @@ pub trait BytesCursor {
23
23
24
24
impl BytesCursor for io:: Cursor < Bytes > {
25
25
fn remaining ( & self ) -> usize {
26
- self . get_ref ( ) . len ( ) - self . position ( ) as usize
26
+ // We have to use a saturating sub here because the position can be
27
+ // greater than the length of the bytes.
28
+ self . get_ref ( ) . len ( ) . saturating_sub ( self . position ( ) as usize )
27
29
}
28
30
29
31
fn extract_remaining ( & mut self ) -> Bytes {
30
- self . extract_bytes ( self . remaining ( ) )
31
- . expect ( "somehow we read past the end of the file" )
32
+ // We don't really care if we fail here since the desired behavior is
33
+ // to return all bytes remaining in the cursor. If we fail its because
34
+ // there are not enough bytes left in the cursor to read.
35
+ self . extract_bytes ( self . remaining ( ) ) . unwrap_or_default ( )
32
36
}
33
37
34
38
fn extract_bytes ( & mut self , size : usize ) -> io:: Result < Bytes > {
35
- let position = self . position ( ) as usize ;
36
- if position + size > self . get_ref ( ) . len ( ) {
39
+ // If the size is zero we can just return an empty bytes slice.
40
+ if size == 0 {
41
+ return Ok ( Bytes :: new ( ) ) ;
42
+ }
43
+
44
+ // If the size is greater than the remaining bytes we can just return an
45
+ // error.
46
+ if size > self . remaining ( ) {
37
47
return Err ( io:: Error :: new ( io:: ErrorKind :: UnexpectedEof , "not enough bytes" ) ) ;
38
48
}
39
49
50
+ let position = self . position ( ) as usize ;
51
+
52
+ // We slice bytes here which is a O(1) operation as it only modifies a few
53
+ // reference counters and does not copy the memory.
40
54
let slice = self . get_ref ( ) . slice ( position..position + size) ;
55
+
56
+ // We advance the cursor because we have now "read" the bytes.
41
57
self . set_position ( ( position + size) as u64 ) ;
42
58
43
59
Ok ( slice)
44
60
}
45
61
}
46
62
47
63
#[ cfg( test) ]
64
+ #[ cfg_attr( all( test, coverage_nightly) , coverage( off) ) ]
48
65
mod tests {
49
66
use super :: * ;
50
67
51
68
#[ test]
52
- fn test_bytes_cursor ( ) {
69
+ fn test_bytes_cursor_extract_remaining ( ) {
53
70
let mut cursor = io:: Cursor :: new ( Bytes :: from_static ( & [ 1 , 2 , 3 , 4 , 5 ] ) ) ;
54
71
let remaining = cursor. extract_remaining ( ) ;
55
72
assert_eq ! ( remaining, Bytes :: from_static( & [ 1 , 2 , 3 , 4 , 5 ] ) ) ;
56
73
}
74
+
75
+ #[ test]
76
+ fn test_bytes_cursor_extract_bytes ( ) {
77
+ let mut cursor = io:: Cursor :: new ( Bytes :: from_static ( & [ 1 , 2 , 3 , 4 , 5 ] ) ) ;
78
+ let bytes = cursor. extract_bytes ( 3 ) . unwrap ( ) ;
79
+ assert_eq ! ( bytes, Bytes :: from_static( & [ 1 , 2 , 3 ] ) ) ;
80
+ assert_eq ! ( cursor. remaining( ) , 2 ) ;
81
+
82
+ let bytes = cursor. extract_bytes ( 2 ) . unwrap ( ) ;
83
+ assert_eq ! ( bytes, Bytes :: from_static( & [ 4 , 5 ] ) ) ;
84
+ assert_eq ! ( cursor. remaining( ) , 0 ) ;
85
+
86
+ let bytes = cursor. extract_bytes ( 1 ) . unwrap_err ( ) ;
87
+ assert_eq ! ( bytes. kind( ) , io:: ErrorKind :: UnexpectedEof ) ;
88
+
89
+ let bytes = cursor. extract_bytes ( 0 ) . unwrap ( ) ;
90
+ assert_eq ! ( bytes, Bytes :: from_static( & [ ] ) ) ;
91
+ assert_eq ! ( cursor. remaining( ) , 0 ) ;
92
+
93
+ let bytes = cursor. extract_remaining ( ) ;
94
+ assert_eq ! ( bytes, Bytes :: from_static( & [ ] ) ) ;
95
+ assert_eq ! ( cursor. remaining( ) , 0 ) ;
96
+ }
97
+
98
+ #[ test]
99
+ fn seek_out_of_bounds ( ) {
100
+ let mut cursor = io:: Cursor :: new ( Bytes :: from_static ( & [ 1 , 2 , 3 , 4 , 5 ] ) ) ;
101
+ cursor. set_position ( 10 ) ;
102
+ assert_eq ! ( cursor. remaining( ) , 0 ) ;
103
+
104
+ let bytes = cursor. extract_remaining ( ) ;
105
+ assert_eq ! ( bytes, Bytes :: from_static( & [ ] ) ) ;
106
+
107
+ let bytes = cursor. extract_bytes ( 1 ) ;
108
+ assert_eq ! ( bytes. unwrap_err( ) . kind( ) , io:: ErrorKind :: UnexpectedEof ) ;
109
+
110
+ let bytes = cursor. extract_bytes ( 0 ) ;
111
+ assert_eq ! ( bytes. unwrap( ) , Bytes :: from_static( & [ ] ) ) ;
112
+ }
57
113
}
0 commit comments