Skip to content

Commit 2627e19

Browse files
committed
sound: add QueueIdx enum for virtio queue indices
Add type safe enum to use instead of raw u16 values, which we have to validate every time we use them. Signed-off-by: Manos Pitsidianakis <[email protected]>
1 parent 87aac39 commit 2627e19

File tree

2 files changed

+83
-27
lines changed

2 files changed

+83
-27
lines changed

staging/vhost-device-sound/src/device.rs

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,15 @@ use crate::{
3030
audio_backends::{alloc_audio_backend, AudioBackend},
3131
stream::{Buffer, Error as StreamError, Stream},
3232
virtio_sound::*,
33-
ControlMessageKind, Direction, Error, IOMessage, Result, SoundConfig,
33+
ControlMessageKind, Direction, Error, IOMessage, QueueIdx, Result, SoundConfig,
3434
};
3535

3636
pub struct VhostUserSoundThread {
3737
mem: Option<GuestMemoryAtomic<GuestMemoryMmap>>,
3838
event_idx: bool,
3939
chmaps: Arc<RwLock<Vec<VirtioSoundChmapInfo>>>,
4040
jacks: Arc<RwLock<Vec<VirtioSoundJackInfo>>>,
41-
queue_indexes: Vec<u16>,
41+
queue_indexes: Vec<QueueIdx>,
4242
streams: Arc<RwLock<Vec<Stream>>>,
4343
streams_no: usize,
4444
}
@@ -49,11 +49,11 @@ impl VhostUserSoundThread {
4949
pub fn new(
5050
chmaps: Arc<RwLock<Vec<VirtioSoundChmapInfo>>>,
5151
jacks: Arc<RwLock<Vec<VirtioSoundJackInfo>>>,
52-
mut queue_indexes: Vec<u16>,
52+
mut queue_indexes: Vec<QueueIdx>,
5353
streams: Arc<RwLock<Vec<Stream>>>,
5454
streams_no: usize,
5555
) -> Result<Self> {
56-
queue_indexes.sort();
56+
queue_indexes.sort_by_key(|idx| *idx as u16);
5757

5858
Ok(Self {
5959
event_idx: false,
@@ -70,7 +70,7 @@ impl VhostUserSoundThread {
7070
let mut queues_per_thread = 0u64;
7171

7272
for idx in self.queue_indexes.iter() {
73-
queues_per_thread |= 1u64 << idx
73+
queues_per_thread |= 1u64 << *idx as u16
7474
}
7575

7676
queues_per_thread
@@ -94,7 +94,10 @@ impl VhostUserSoundThread {
9494
let vring = &vrings
9595
.get(device_event as usize)
9696
.ok_or_else(|| Error::HandleUnknownEvent(device_event))?;
97-
let queue_idx = self.queue_indexes[device_event as usize];
97+
let queue_idx = self
98+
.queue_indexes
99+
.get(device_event as usize)
100+
.ok_or_else(|| Error::HandleUnknownEvent(device_event))?;
98101
if self.event_idx {
99102
// vm-virtio's Queue implementation only checks avail_index
100103
// once, so to properly support EVENT_IDX we need to keep
@@ -103,11 +106,10 @@ impl VhostUserSoundThread {
103106
loop {
104107
vring.disable_notification().unwrap();
105108
match queue_idx {
106-
CONTROL_QUEUE_IDX => self.process_control(vring, audio_backend),
107-
EVENT_QUEUE_IDX => self.process_event(vring),
108-
TX_QUEUE_IDX => self.process_io(vring, audio_backend, Direction::Output),
109-
RX_QUEUE_IDX => self.process_io(vring, audio_backend, Direction::Input),
110-
_ => Err(Error::HandleUnknownEvent(queue_idx).into()),
109+
QueueIdx::Control => self.process_control(vring, audio_backend),
110+
QueueIdx::Event => self.process_event(vring),
111+
QueueIdx::Tx => self.process_io(vring, audio_backend, Direction::Output),
112+
QueueIdx::Rx => self.process_io(vring, audio_backend, Direction::Input),
111113
}?;
112114
if !vring.enable_notification().unwrap() {
113115
break;
@@ -116,11 +118,10 @@ impl VhostUserSoundThread {
116118
} else {
117119
// Without EVENT_IDX, a single call is enough.
118120
match queue_idx {
119-
CONTROL_QUEUE_IDX => self.process_control(vring, audio_backend),
120-
EVENT_QUEUE_IDX => self.process_event(vring),
121-
TX_QUEUE_IDX => self.process_io(vring, audio_backend, Direction::Output),
122-
RX_QUEUE_IDX => self.process_io(vring, audio_backend, Direction::Input),
123-
_ => Err(Error::HandleUnknownEvent(queue_idx).into()),
121+
QueueIdx::Control => self.process_control(vring, audio_backend),
122+
QueueIdx::Event => self.process_event(vring),
123+
QueueIdx::Tx => self.process_io(vring, audio_backend, Direction::Output),
124+
QueueIdx::Rx => self.process_io(vring, audio_backend, Direction::Input),
124125
}?;
125126
}
126127
Ok(())
@@ -635,21 +636,21 @@ impl VhostUserSoundBackend {
635636
RwLock::new(VhostUserSoundThread::new(
636637
chmaps.clone(),
637638
jacks.clone(),
638-
vec![CONTROL_QUEUE_IDX, EVENT_QUEUE_IDX],
639+
vec![QueueIdx::Control, QueueIdx::Event],
639640
streams.clone(),
640641
streams_no,
641642
)?),
642643
RwLock::new(VhostUserSoundThread::new(
643644
chmaps.clone(),
644645
jacks.clone(),
645-
vec![TX_QUEUE_IDX],
646+
vec![QueueIdx::Tx],
646647
streams.clone(),
647648
streams_no,
648649
)?),
649650
RwLock::new(VhostUserSoundThread::new(
650651
chmaps,
651652
jacks,
652-
vec![RX_QUEUE_IDX],
653+
vec![QueueIdx::Rx],
653654
streams.clone(),
654655
streams_no,
655656
)?),
@@ -659,10 +660,10 @@ impl VhostUserSoundBackend {
659660
chmaps,
660661
jacks,
661662
vec![
662-
CONTROL_QUEUE_IDX,
663-
EVENT_QUEUE_IDX,
664-
TX_QUEUE_IDX,
665-
RX_QUEUE_IDX,
663+
QueueIdx::Control,
664+
QueueIdx::Event,
665+
QueueIdx::Tx,
666+
QueueIdx::Rx,
666667
],
667668
streams.clone(),
668669
streams_no,
@@ -832,7 +833,7 @@ mod tests {
832833

833834
let chmaps = Arc::new(RwLock::new(vec![]));
834835
let jacks = Arc::new(RwLock::new(vec![]));
835-
let queue_indexes = vec![1, 2, 3];
836+
let queue_indexes = vec![QueueIdx::Event, QueueIdx::Tx, QueueIdx::Rx];
836837
let streams = vec![Stream::default()];
837838
let streams_no = streams.len();
838839
let streams = Arc::new(RwLock::new(streams));
@@ -927,7 +928,7 @@ mod tests {
927928

928929
let chmaps = Arc::new(RwLock::new(vec![]));
929930
let jacks = Arc::new(RwLock::new(vec![]));
930-
let queue_indexes = vec![1, 2, 3];
931+
let queue_indexes = vec![QueueIdx::Event, QueueIdx::Tx, QueueIdx::Rx];
931932
let streams = Arc::new(RwLock::new(vec![]));
932933
let streams_no = 0;
933934
let thread =

staging/vhost-device-sound/src/lib.rs

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,43 @@ impl TryFrom<u8> for Direction {
9494
Ok(match val {
9595
virtio_sound::VIRTIO_SND_D_OUTPUT => Self::Output,
9696
virtio_sound::VIRTIO_SND_D_INPUT => Self::Input,
97-
other => return Err(Error::InvalidMessageValue(stringify!(Direction), other)),
97+
other => {
98+
return Err(Error::InvalidMessageValue(
99+
stringify!(Direction),
100+
other.into(),
101+
))
102+
}
103+
})
104+
}
105+
}
106+
107+
/// Queue index.
108+
///
109+
/// Type safe enum for CONTROL_QUEUE_IDX, EVENT_QUEUE_IDX, TX_QUEUE_IDX,
110+
/// RX_QUEUE_IDX.
111+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
112+
#[repr(u16)]
113+
pub enum QueueIdx {
114+
#[doc(alias = "CONTROL_QUEUE_IDX")]
115+
Control = virtio_sound::CONTROL_QUEUE_IDX,
116+
#[doc(alias = "EVENT_QUEUE_IDX")]
117+
Event = virtio_sound::EVENT_QUEUE_IDX,
118+
#[doc(alias = "TX_QUEUE_IDX")]
119+
Tx = virtio_sound::TX_QUEUE_IDX,
120+
#[doc(alias = "RX_QUEUE_IDX")]
121+
Rx = virtio_sound::RX_QUEUE_IDX,
122+
}
123+
124+
impl TryFrom<u16> for QueueIdx {
125+
type Error = Error;
126+
127+
fn try_from(val: u16) -> std::result::Result<Self, Self::Error> {
128+
Ok(match val {
129+
virtio_sound::CONTROL_QUEUE_IDX => Self::Control,
130+
virtio_sound::EVENT_QUEUE_IDX => Self::Event,
131+
virtio_sound::TX_QUEUE_IDX => Self::Tx,
132+
virtio_sound::RX_QUEUE_IDX => Self::Rx,
133+
other => return Err(Error::InvalidMessageValue(stringify!(QueueIdx), other)),
98134
})
99135
}
100136
}
@@ -117,7 +153,7 @@ pub enum Error {
117153
#[error("Invalid control message code {0}")]
118154
InvalidControlMessage(u32),
119155
#[error("Invalid value in {0}: {1}")]
120-
InvalidMessageValue(&'static str, u8),
156+
InvalidMessageValue(&'static str, u16),
121157
#[error("Failed to create a new EventFd")]
122158
EventFdCreate(IoError),
123159
#[error("Request missing data buffer")]
@@ -389,6 +425,25 @@ mod tests {
389425

390426
let val = 42;
391427
Direction::try_from(val).unwrap_err();
428+
429+
assert_eq!(
430+
QueueIdx::try_from(virtio_sound::CONTROL_QUEUE_IDX).unwrap(),
431+
QueueIdx::Control
432+
);
433+
assert_eq!(
434+
QueueIdx::try_from(virtio_sound::EVENT_QUEUE_IDX).unwrap(),
435+
QueueIdx::Event
436+
);
437+
assert_eq!(
438+
QueueIdx::try_from(virtio_sound::TX_QUEUE_IDX).unwrap(),
439+
QueueIdx::Tx
440+
);
441+
assert_eq!(
442+
QueueIdx::try_from(virtio_sound::RX_QUEUE_IDX).unwrap(),
443+
QueueIdx::Rx
444+
);
445+
let val = virtio_sound::NUM_QUEUES;
446+
QueueIdx::try_from(val).unwrap_err();
392447
}
393448

394449
#[test]

0 commit comments

Comments
 (0)