Skip to content

Commit d772c0f

Browse files
RiverPhillipsroypat
authored andcommitted
Use u32 for Vsock related buffer sizes
Use u32 for all Vsock related buffers instead of usize, the Virtio specification states that lengths of virtio-buffers fit into a u32. Signed-off-by: River Phillips <[email protected]>
1 parent 946cf20 commit d772c0f

File tree

4 files changed

+47
-42
lines changed

4 files changed

+47
-42
lines changed

src/vmm/src/devices/virtio/vsock/csm/connection.rs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,7 @@ where
237237
// length of the read data.
238238
// Safe to unwrap because read_cnt is no more than max_len, which is bounded
239239
// by self.peer_avail_credit(), a u32 internally.
240-
pkt.set_op(uapi::VSOCK_OP_RW)
241-
.set_len(u32::try_from(read_cnt).unwrap());
240+
pkt.set_op(uapi::VSOCK_OP_RW).set_len(read_cnt);
242241
METRICS.rx_bytes_count.add(read_cnt as u64);
243242
}
244243
self.rx_cnt += Wrapping(pkt.len());
@@ -605,7 +604,7 @@ where
605604
/// Raw data can either be sent straight to the host stream, or to our TX buffer, if the
606605
/// former fails.
607606
fn send_bytes(&mut self, pkt: &VsockPacket) -> Result<(), VsockError> {
608-
let len = pkt.len() as usize;
607+
let len = pkt.len();
609608

610609
// If there is data in the TX buffer, that means we're already registered for EPOLLOUT
611610
// events on the underlying stream. Therefore, there's no point in attempting a write
@@ -635,7 +634,7 @@ where
635634
};
636635
// Move the "forwarded bytes" counter ahead by how much we were able to send out.
637636
// Safe to unwrap because the maximum value is pkt.len(), which is a u32.
638-
self.fwd_cnt += wrap_usize_to_u32(written);
637+
self.fwd_cnt += written;
639638
METRICS.tx_bytes_count.add(written as u64);
640639

641640
// If we couldn't write the whole slice, we'll need to push the remaining data to our
@@ -662,8 +661,8 @@ where
662661

663662
/// Get the maximum number of bytes that we can send to our peer, without overflowing its
664663
/// buffer.
665-
fn peer_avail_credit(&self) -> usize {
666-
(Wrapping(self.peer_buf_alloc) - (self.rx_cnt - self.peer_fwd_cnt)).0 as usize
664+
fn peer_avail_credit(&self) -> u32 {
665+
(Wrapping(self.peer_buf_alloc) - (self.rx_cnt - self.peer_fwd_cnt)).0
667666
}
668667

669668
/// Prepare a packet header for transmission to our peer.
@@ -916,7 +915,7 @@ mod tests {
916915
assert!(credit < self.conn.peer_buf_alloc);
917916
self.conn.peer_fwd_cnt = Wrapping(0);
918917
self.conn.rx_cnt = Wrapping(self.conn.peer_buf_alloc - credit);
919-
assert_eq!(self.conn.peer_avail_credit(), credit as usize);
918+
assert_eq!(self.conn.peer_avail_credit(), credit);
920919
}
921920

922921
fn send(&mut self) {
@@ -941,11 +940,13 @@ mod tests {
941940
}
942941

943942
fn init_data_tx_pkt(&mut self, mut data: &[u8]) -> &VsockPacket {
944-
assert!(data.len() <= self.tx_pkt.buf_size());
943+
assert!(data.len() <= self.tx_pkt.buf_size() as usize);
945944
self.init_tx_pkt(uapi::VSOCK_OP_RW, u32::try_from(data.len()).unwrap());
946945

947946
let len = data.len();
948-
self.rx_pkt.read_at_offset_from(&mut data, 0, len).unwrap();
947+
self.rx_pkt
948+
.read_at_offset_from(&mut data, 0, len.try_into().unwrap())
949+
.unwrap();
949950
&self.tx_pkt
950951
}
951952
}
@@ -1282,7 +1283,7 @@ mod tests {
12821283
ctx.set_stream(stream);
12831284

12841285
// Fill up the TX buffer.
1285-
let data = vec![0u8; ctx.tx_pkt.buf_size()];
1286+
let data = vec![0u8; ctx.tx_pkt.buf_size() as usize];
12861287
ctx.init_data_tx_pkt(data.as_slice());
12871288
for _i in 0..(csm_defs::CONN_TX_BUF_SIZE as usize / data.len()) {
12881289
ctx.send();

src/vmm/src/devices/virtio/vsock/packet.rs

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -219,57 +219,61 @@ impl VsockPacket {
219219
///
220220
/// Return value will equal the total length of the underlying descriptor chain's buffers,
221221
/// minus the length of the vsock header.
222-
pub fn buf_size(&self) -> usize {
222+
pub fn buf_size(&self) -> u32 {
223223
let chain_length = match self.buffer {
224224
VsockPacketBuffer::Tx(ref iovec_buf) => iovec_buf.len(),
225225
VsockPacketBuffer::Rx(ref iovec_buf) => iovec_buf.len(),
226226
};
227-
(chain_length - VSOCK_PKT_HDR_SIZE) as usize
227+
chain_length - VSOCK_PKT_HDR_SIZE
228228
}
229229

230230
pub fn read_at_offset_from<T: ReadVolatile + Debug>(
231231
&mut self,
232232
src: &mut T,
233-
offset: usize,
234-
count: usize,
235-
) -> Result<usize, VsockError> {
233+
offset: u32,
234+
count: u32,
235+
) -> Result<u32, VsockError> {
236236
match self.buffer {
237237
VsockPacketBuffer::Tx(_) => Err(VsockError::UnwritableDescriptor),
238238
VsockPacketBuffer::Rx(ref mut buffer) => {
239239
if count
240-
> (buffer.len() as usize)
241-
.saturating_sub(VSOCK_PKT_HDR_SIZE as usize)
240+
> buffer
241+
.len()
242+
.saturating_sub(VSOCK_PKT_HDR_SIZE)
242243
.saturating_sub(offset)
243244
{
244245
return Err(VsockError::GuestMemoryBounds);
245246
}
246247

247248
buffer
248-
.write_volatile_at(src, offset + VSOCK_PKT_HDR_SIZE as usize, count)
249+
.write_volatile_at(src, (offset + VSOCK_PKT_HDR_SIZE) as usize, count as usize)
249250
.map_err(|err| VsockError::GuestMemoryMmap(GuestMemoryError::from(err)))
251+
.and_then(|read| read.try_into().map_err(|_| VsockError::DescChainOverflow))
250252
}
251253
}
252254
}
253255

254256
pub fn write_from_offset_to<T: WriteVolatile + Debug>(
255257
&self,
256258
dst: &mut T,
257-
offset: usize,
258-
count: usize,
259-
) -> Result<usize, VsockError> {
259+
offset: u32,
260+
count: u32,
261+
) -> Result<u32, VsockError> {
260262
match self.buffer {
261263
VsockPacketBuffer::Tx(ref buffer) => {
262264
if count
263-
> (buffer.len() as usize)
264-
.saturating_sub(VSOCK_PKT_HDR_SIZE as usize)
265+
> buffer
266+
.len()
267+
.saturating_sub(VSOCK_PKT_HDR_SIZE)
265268
.saturating_sub(offset)
266269
{
267270
return Err(VsockError::GuestMemoryBounds);
268271
}
269272

270273
buffer
271-
.read_volatile_at(dst, offset + VSOCK_PKT_HDR_SIZE as usize, count)
274+
.read_volatile_at(dst, (offset + VSOCK_PKT_HDR_SIZE) as usize, count as usize)
272275
.map_err(|err| VsockError::GuestMemoryMmap(GuestMemoryError::from(err)))
276+
.and_then(|read| read.try_into().map_err(|_| VsockError::DescChainOverflow))
273277
}
274278
VsockPacketBuffer::Rx(_) => Err(VsockError::UnreadableDescriptor),
275279
}
@@ -537,10 +541,7 @@ mod tests {
537541
handler_ctx.device.queues[RXQ_INDEX].pop().unwrap(),
538542
)
539543
.unwrap();
540-
assert_eq!(
541-
pkt.buf_size(),
542-
handler_ctx.guest_rxvq.dtable[1].len.get() as usize
543-
);
544+
assert_eq!(pkt.buf_size(), handler_ctx.guest_rxvq.dtable[1].len.get());
544545
}
545546

546547
// Test case: read-only RX packet header.
@@ -646,35 +647,38 @@ mod tests {
646647
.unwrap();
647648

648649
let buf_desc = &mut handler_ctx.guest_rxvq.dtable[1];
649-
assert_eq!(pkt.buf_size(), buf_desc.len.get() as usize);
650-
let zeros = vec![0_u8; pkt.buf_size()];
650+
assert_eq!(pkt.buf_size(), buf_desc.len.get());
651+
let zeros = vec![0_u8; pkt.buf_size() as usize];
651652
let data: Vec<u8> = (0..pkt.buf_size())
652653
.map(|i| ((i as u64) & 0xff) as u8)
653654
.collect();
654655
for offset in 0..pkt.buf_size() {
655656
buf_desc.set_data(&zeros);
656657

657-
let mut expected_data = zeros[..offset].to_vec();
658-
expected_data.extend_from_slice(&data[..pkt.buf_size() - offset]);
658+
let mut expected_data = zeros[..offset as usize].to_vec();
659+
expected_data.extend_from_slice(&data[..(pkt.buf_size() - offset) as usize]);
659660

660661
pkt.read_at_offset_from(&mut data.as_slice(), offset, pkt.buf_size() - offset)
661662
.unwrap();
662663

663664
buf_desc.check_data(&expected_data);
664665

665-
let mut buf = vec![0; pkt.buf_size()];
666+
let mut buf = vec![0; pkt.buf_size() as usize];
666667
pkt2.write_from_offset_to(&mut buf.as_mut_slice(), offset, pkt.buf_size() - offset)
667668
.unwrap();
668-
assert_eq!(&buf[..pkt.buf_size() - offset], &expected_data[offset..]);
669+
assert_eq!(
670+
&buf[..(pkt.buf_size() - offset) as usize],
671+
&expected_data[offset as usize..]
672+
);
669673
}
670674

671675
let oob_cases = vec![
672676
(1, pkt.buf_size()),
673677
(pkt.buf_size(), 1),
674-
(usize::MAX, 1),
675-
(1, usize::MAX),
678+
(u32::MAX, 1),
679+
(1, u32::MAX),
676680
];
677-
let mut buf = vec![0; pkt.buf_size()];
681+
let mut buf = vec![0; pkt.buf_size() as usize];
678682
for (offset, count) in oob_cases {
679683
let res = pkt.read_at_offset_from(&mut data.as_slice(), offset, count);
680684
assert!(matches!(res, Err(VsockError::GuestMemoryBounds)));

src/vmm/src/devices/virtio/vsock/test_utils.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ impl VsockChannel for TestBackend {
6969
let buf_size = pkt.buf_size();
7070
if buf_size > 0 {
7171
let buf: Vec<u8> = (0..buf_size)
72-
.map(|i| cool_buf[i % cool_buf.len()])
72+
.map(|i| cool_buf[i as usize % cool_buf.len()])
7373
.collect();
7474
pkt.read_at_offset_from(&mut buf.as_slice(), 0, buf_size)
7575
.unwrap();
@@ -206,8 +206,8 @@ impl<'a> EventHandlerContext<'a> {
206206
}
207207

208208
#[cfg(test)]
209-
pub fn read_packet_data(pkt: &VsockPacket, how_much: usize) -> Vec<u8> {
210-
let mut buf = vec![0; how_much];
209+
pub fn read_packet_data(pkt: &VsockPacket, how_much: u32) -> Vec<u8> {
210+
let mut buf = vec![0; how_much as usize];
211211
pkt.write_from_offset_to(&mut buf.as_mut_slice(), 0, how_much)
212212
.unwrap();
213213
buf

src/vmm/src/devices/virtio/vsock/unix/muxer.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -869,11 +869,11 @@ mod tests {
869869
peer_port: u32,
870870
mut data: &[u8],
871871
) -> &mut VsockPacket {
872-
assert!(data.len() <= self.tx_pkt.buf_size());
872+
assert!(data.len() <= self.tx_pkt.buf_size() as usize);
873873
self.init_tx_pkt(local_port, peer_port, uapi::VSOCK_OP_RW)
874874
.set_len(u32::try_from(data.len()).unwrap());
875875

876-
let data_len = data.len(); // store in tmp var to make borrow checker happy.
876+
let data_len = data.len().try_into().unwrap(); // store in tmp var to make borrow checker happy.
877877
self.rx_pkt
878878
.read_at_offset_from(&mut data, 0, data_len)
879879
.unwrap();

0 commit comments

Comments
 (0)