Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vmbus_client: use existing mesh channels to send revoke #637

Merged
merged 1 commit into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -8025,6 +8025,7 @@ version = "0.0.0"
dependencies = [
"anyhow",
"futures",
"futures-concurrency",
"guid",
"inspect",
"mesh",
Expand Down
59 changes: 24 additions & 35 deletions vm/devices/vmbus/vmbus_client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ impl VmbusClient {
/// Creates a new instance with a receiver for incoming synic messages.
pub fn new(
synic: Arc<dyn SynicClient>,
notify_send: mesh::Sender<ClientNotification>,
offer_send: mesh::Sender<OfferInfo>,
msg_source: impl VmbusMessageSource + 'static,
spawner: &impl Spawn,
) -> Self {
Expand All @@ -113,7 +113,7 @@ impl VmbusClient {
inner,
task_recv,
running: false,
notify_send,
offer_send,
msg_source,
client_request_recv,
state: ClientState::Disconnected,
Expand Down Expand Up @@ -287,12 +287,6 @@ pub struct OfferInfo {
pub response_recv: mesh::Receiver<ChannelResponse>,
}

#[derive(Debug)]
pub enum ClientNotification {
Offer(OfferInfo),
Revoke(ChannelId),
}

#[derive(Debug)]
enum ClientRequest {
InitiateContact(Rpc<InitiateContactRequest, Result<VersionInfo, ConnectError>>),
Expand Down Expand Up @@ -495,7 +489,7 @@ struct ClientTask<T: VmbusMessageSource> {
running: bool,
modify_request: Option<Rpc<ModifyConnectionRequest, ConnectionState>>,
msg_source: T,
notify_send: mesh::Sender<ClientNotification>,
offer_send: mesh::Sender<OfferInfo>,
task_recv: mesh::Receiver<TaskRequest>,
client_request_recv: mesh::Receiver<ClientRequest>,
}
Expand Down Expand Up @@ -707,7 +701,7 @@ impl<T: VmbusMessageSource> ClientTask<T> {
if let ClientState::RequestingOffers(_, send) = &self.state {
send.send(offer_info);
} else {
self.notify_send.send(ClientNotification::Offer(offer_info));
self.offer_send.send(offer_info);
}
}
}
Expand All @@ -733,6 +727,8 @@ impl<T: VmbusMessageSource> ClientTask<T> {
.response_send
.send(ChannelResponse::TeardownGpadl(gpadl_id));
} else {
// TODO: is this really necessary? The host should have
// already unmapped all GPADLs. Remove if possible.
send_message(
self.inner.synic.as_ref(),
&protocol::GpadlTeardown {
Expand All @@ -748,16 +744,20 @@ impl<T: VmbusMessageSource> ClientTask<T> {
false
});

// Drop the channel, which will close the response channel, which will
// cause the client to know the channel has been revoked.
//
// TODO: this is wrong--client requests can still come in after this,
// and they will fail to find the channel by channel ID and panic (or
// worse, the channel ID will get reused). Either find and drop the
// associated incoming request channel here, or keep this channel object
// around until the client is done with it.
self.inner.channels.remove(&rescind.channel_id);

// Tell the host we're not referencing the client ID anymore.
self.inner.send(&protocol::RelIdReleased {
channel_id: rescind.channel_id,
});

// At this point the offer can be revoked from the relay.
self.notify_send
.send(ClientNotification::Revoke(rescind.channel_id));
}

fn handle_offers_delivered(&mut self) {
Expand Down Expand Up @@ -1549,23 +1549,19 @@ mod tests {

impl VmbusMessageSource for TestMessageSource {}

fn test_init() -> (
Arc<TestServer>,
VmbusClient,
mesh::Receiver<ClientNotification>,
) {
fn test_init() -> (Arc<TestServer>, VmbusClient, mesh::Receiver<OfferInfo>) {
let pool = DefaultPool::new();
let driver = pool.driver();
let (msg_send, msg_recv) = mesh::channel();
let server = Arc::new(TestServer {
messages: Mutex::new(Vec::new()),
send: msg_send,
});
let (notify_send, notify_recv) = mesh::channel();
let (offer_send, offer_recv) = mesh::channel();

let mut client = VmbusClient::new(
Arc::new(server.clone()),
notify_send,
offer_send,
TestMessageSource { msg_recv },
&driver,
);
Expand All @@ -1574,7 +1570,7 @@ mod tests {
.spawn(move || pool.run())
.unwrap();

(server, client, notify_recv)
(server, client, offer_recv)
}

#[async_test]
Expand Down Expand Up @@ -1997,7 +1993,7 @@ mod tests {

#[async_test]
async fn test_hot_add_remove() {
let (server, mut client, mut notify_recv) = test_init();
let (server, mut client, mut offer_recv) = test_init();

server.connect(&mut client).await;
let offer = protocol::OfferChannel {
Expand All @@ -2017,9 +2013,7 @@ mod tests {
};

server.send(in_msg(MessageType::OFFER_CHANNEL, offer));
let ClientNotification::Offer(info) = notify_recv.next().await.unwrap() else {
panic!("invalid request")
};
let mut info = offer_recv.next().await.unwrap();

assert_eq!(offer, info.offer);

Expand All @@ -2037,8 +2031,7 @@ mod tests {
})
);

let request = notify_recv.next().await.unwrap();
assert!(matches!(request, ClientNotification::Revoke(ChannelId(5))));
assert!(info.response_recv.next().await.is_none());
}

#[async_test]
Expand Down Expand Up @@ -2144,7 +2137,7 @@ mod tests {

#[async_test]
async fn test_gpadl_with_revoke() {
let (server, mut client, mut notify_recv) = test_init();
let (server, mut client, _offer_recv) = test_init();
let mut channel = server.get_channel(&mut client).await;
let channel_id = ChannelId(0);
let gpadl_id = GpadlId(1);
Expand Down Expand Up @@ -2208,11 +2201,7 @@ mod tests {
OutgoingMessage::new(&protocol::RelIdReleased { channel_id })
);

let ClientNotification::Revoke(id) = notify_recv.next().await.unwrap() else {
panic!("invalid request")
};

assert_eq!(id, channel_id);
assert!(channel.response_recv.next().await.is_none());
}

#[async_test]
Expand Down Expand Up @@ -2250,7 +2239,7 @@ mod tests {

#[async_test]
async fn test_hvsock() {
let (server, mut client, _notify_recv) = test_init();
let (server, mut client, _offer_recv) = test_init();
server.connect(&mut client).await;
let request = HvsockConnectRequest {
service_id: Guid::new_random(),
Expand Down
Loading
Loading