Skip to content

Add ability to resolve (and cache) multicast groups for a given family name #12

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion src/handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ use crate::{
resolver::Resolver,
};
use futures::{lock::Mutex, Stream, StreamExt};
use log::trace;
use netlink_packet_core::{NetlinkMessage, NetlinkPayload};
use netlink_packet_generic::{GenlFamily, GenlHeader, GenlMessage};
use netlink_packet_utils::{DecodeError, Emitable, ParseableParametrized};
use netlink_proto::{sys::SocketAddr, ConnectionHandle};
use std::{fmt::Debug, sync::Arc};
use std::{collections::HashMap, fmt::Debug, sync::Arc};

/// The generic netlink connection handle
///
Expand Down Expand Up @@ -67,6 +68,19 @@ impl GenetlinkHandle {
.await
}

/// Resolve the multicast groups of the given [`GenlFamily`].
pub async fn resolve_mcast_groups<F>(&self) -> Result<HashMap<String, u32>, GenetlinkError>
where
F: GenlFamily,
{
trace!("Requesting Groups from Resolver: {:?}", F::family_name());
self.resolver
.lock()
.await
.query_family_multicast_groups(self, F::family_name())
.await
}

/// Clear the resolver's fanily id cache
pub async fn clear_family_id_cache(&self) {
self.resolver.lock().await.clear_cache();
Expand Down
139 changes: 138 additions & 1 deletion src/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,36 @@

use crate::{error::GenetlinkError, GenetlinkHandle};
use futures::{future::Either, StreamExt};
use log::{error, trace, warn};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think adding

#[macro_use]
extern crate log;

in lib.rs is acceptable.
Since it is widely used in many Rust project today.

Copy link
Author

@Ragnt Ragnt Nov 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hadn't realized I added the logging, obviously it was for debugging so it doesn't matter to me either way.

use netlink_packet_core::{NetlinkMessage, NetlinkPayload, NLM_F_REQUEST};
use netlink_packet_generic::{
ctrl::{nlas::GenlCtrlAttrs, GenlCtrl, GenlCtrlCmd},
ctrl::{nlas::{GenlCtrlAttrs, McastGrpAttrs}, GenlCtrl, GenlCtrlCmd},
GenlMessage,
};
use std::{collections::HashMap, future::Future};

#[derive(Clone, Debug, Default)]
pub struct Resolver {
cache: HashMap<&'static str, u16>,
groups_cache: HashMap<&'static str, HashMap<String, u32>>
}

impl Resolver {
pub fn new() -> Self {
Self {
cache: HashMap::new(),
groups_cache: HashMap::new(),
}
}

pub fn get_cache_by_name(&self, family_name: &str) -> Option<u16> {
self.cache.get(family_name).copied()
}

pub fn get_groups_cache_by_name(&self, family_name: &str) -> Option<HashMap<String, u32>> {
self.groups_cache.get(family_name).cloned()
}

pub fn query_family_id(
&mut self,
handle: &GenetlinkHandle,
Expand Down Expand Up @@ -85,9 +92,112 @@ impl Resolver {
}
}

pub fn query_family_multicast_groups(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This functions always sends out a netlink request compared to query_family_id. I don't have a preference, but wanted to point out this inconsistency.

&mut self,
handle: &GenetlinkHandle,
family_name: &'static str,
) -> impl Future<Output = Result<HashMap<String, u32>, GenetlinkError>> + '_ {
let mut handle = handle.clone();
async move {
trace!("Starting query_family_multicast_groups for family_name: '{}'", family_name);

// First, get the family ID (this uses your existing method)
trace!("Calling query_family_id for family_name: '{}'", family_name);
let family_id = self.query_family_id(&handle, family_name).await?;
trace!("Received family_id: {}", family_id);

// Create the request message to get family details
trace!("Creating GenlMessage for CTRL_CMD_GETFAMILY");
let mut genlmsg: GenlMessage<GenlCtrl> = GenlMessage::from_payload(GenlCtrl {
cmd: GenlCtrlCmd::GetFamily,
nlas: vec![GenlCtrlAttrs::FamilyId(family_id)],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should also work with GenlCtrlAttrs::FamilyName so you don't have to resolve the family id here.

});
genlmsg.finalize();
let mut nlmsg = NetlinkMessage::from(genlmsg);
nlmsg.header.flags = NLM_F_REQUEST;
nlmsg.finalize();
trace!("NetlinkMessage created: {:?}", nlmsg);

// Send the request
trace!("Sending NetlinkMessage to netlink socket");
let mut res = handle.send_request(nlmsg)?;
trace!("Request sent, awaiting response");

// Prepare to collect multicast groups
let mut mc_groups = HashMap::new();

// Process the response
trace!("Processing responses");
while let Some(result) = res.next().await {
trace!("Received a response");
let rx_packet = result?;
trace!("Received NetlinkMessage: {:?}", rx_packet);
match rx_packet.payload {
NetlinkPayload::InnerMessage(genlmsg) => {
trace!("Processing InnerMessage: {:?}", genlmsg);
for nla in genlmsg.payload.nlas {
trace!("Processing NLA: {:?}", nla);
if let GenlCtrlAttrs::McastGroups(groups) = nla {
trace!("Found McastGroups: {:?}", groups);
for group in groups {
// 'group' is a Vec<McastGrpAttrs>
let mut group_name = None;
let mut group_id = None;

for group_attr in group {
trace!("Processing group_attr: {:?}", group_attr);
match group_attr {
McastGrpAttrs::Name(ref name) => {
group_name = Some(name.clone());
trace!("Found group name: '{}'", name);
}
McastGrpAttrs::Id(id) => {
group_id = Some(id);
trace!("Found group id: {}", id);
}
}
}

if let (Some(name), Some(id)) = (group_name, group_id) {
mc_groups.insert(name.clone(), id);
trace!(
"Inserted group '{}' with id {} into mc_groups",
name,
id
);
}
}
} else {
trace!("Unhandled NLA: {:?}", nla);
}
}
Comment on lines +137 to +173
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could probably be written a little more concise:

Suggested change
trace!("Processing InnerMessage: {:?}", genlmsg);
for nla in genlmsg.payload.nlas {
trace!("Processing NLA: {:?}", nla);
if let GenlCtrlAttrs::McastGroups(groups) = nla {
trace!("Found McastGroups: {:?}", groups);
for group in groups {
// 'group' is a Vec<McastGrpAttrs>
let mut group_name = None;
let mut group_id = None;
for group_attr in group {
trace!("Processing group_attr: {:?}", group_attr);
match group_attr {
McastGrpAttrs::Name(ref name) => {
group_name = Some(name.clone());
trace!("Found group name: '{}'", name);
}
McastGrpAttrs::Id(id) => {
group_id = Some(id);
trace!("Found group id: {}", id);
}
}
}
if let (Some(name), Some(id)) = (group_name, group_id) {
mc_groups.insert(name.clone(), id);
trace!(
"Inserted group '{}' with id {} into mc_groups",
name,
id
);
}
}
} else {
trace!("Unhandled NLA: {:?}", nla);
}
}
// One specific family id was requested, it can be assumed, that the mcast
// groups are part of that family.
let Some(mcast_groups) = genlmsg
.payload
.nlas
.into_iter()
.filter_map(|attr| match attr {
GenlCtrlAttrs::McastGroups(groups) => {
Some(groups)
}
_ => None,
})
.next()
else {
continue;
};
for group in mcast_groups.into_iter().filter_map(|attrs| {
match attrs.as_slice() {
[McastGrpAttrs::Name(name), McastGrpAttrs::Id(i)] |
[McastGrpAttrs::Id(i), McastGrpAttrs::Name(name)] => Some((name.clone(), *i)),
_ => None
}
}) {
mc_groups.insert(group.0, group.1);
}

}
NetlinkPayload::Error(e) => {
error!("Received NetlinkPayload::Error: {:?}", e);
return Err(e.into());
}
other => {
warn!("Received unexpected NetlinkPayload: {:?}", other);
}
}
}
trace!("Finished processing responses");

// Update the cache
self.groups_cache.insert(family_name, mc_groups.clone());
trace!("Updated groups_cache for family_name: '{}'", family_name);

trace!("Returning mc_groups: {:?}", mc_groups);
Ok(mc_groups)
}
}


pub fn clear_cache(&mut self) {
self.cache.clear();
self.groups_cache.clear();
}

}

#[cfg(all(test, tokio_socket))]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests are actually never run and always disabled ...

Suggested change
#[cfg(all(test, tokio_socket))]
#[cfg(all(test, feature = "tokio_socket"))]

Expand Down Expand Up @@ -152,6 +262,33 @@ mod test {

let cache = resolver.get_cache_by_name(name).unwrap();
assert_eq!(id, cache);

let mcast_groups = resolver
.query_family_multicast_groups(&handle, name)
.await
.or_else(|e| {
if let GenetlinkError::NetlinkError(io_err) = &e {
if io_err.kind() == ErrorKind::NotFound {
// Ignore non exist entries
Ok(0)
} else {
Err(e)
}
} else {
Err(e)
}
})
.unwrap();
if mcast_groups.is_empty() {
log::warn!(
"Generic family \"{name}\" not exist or not loaded \
in this environment. Ignored."
);
continue;
}

let cache = resolver.get_groups_cache_by_name(name).unwrap();
assert_eq!(mcast_groups, cache);
log::warn!("{:?}", (name, cache));
}
}
Expand Down