Skip to content

Commit

Permalink
use safe match pattern
Browse files Browse the repository at this point in the history
  • Loading branch information
romancardenas committed Jun 5, 2024
1 parent ff22137 commit f9c94b5
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 115 deletions.
115 changes: 38 additions & 77 deletions riscv-pac/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,89 +6,66 @@ extern crate syn;
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use std::{collections::HashMap, ops::Range, str::FromStr};
use std::str::FromStr;
use syn::{parse_macro_input, Data, DeriveInput, Ident};

struct PacNumberEnum {
name: Ident,
valid_ranges: Vec<Range<usize>>,
numbers: Vec<(Ident, usize)>,
}

impl PacNumberEnum {
fn new(input: &DeriveInput) -> Self {
let name = input.ident.clone();

let variants = match &input.data {
Data::Enum(data) => &data.variants,
_ => panic!("Input is not an enum"),
};

// Collect the variants and their associated number discriminants
let mut var_map = HashMap::new();
let mut numbers = Vec::new();
for variant in variants {
let ident = &variant.ident;
let value = match &variant.discriminant {
Some(d) => match &d.1 {
syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
syn::Lit::Int(lit_int) => match lit_int.base10_parse::<usize>() {
Ok(num) => num,
Err(_) => panic!("All variant discriminants must be unsigned integers"),
let numbers = variants
.iter()
.map(|variant| {
let ident = &variant.ident;
let value = match &variant.discriminant {
Some(d) => match &d.1 {
syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
syn::Lit::Int(lit_int) => match lit_int.base10_parse::<usize>() {
Ok(num) => num,
Err(_) => {
panic!("All variant discriminants must be unsigned integers")
}
},
_ => panic!("All variant discriminants must be unsigned integers"),
},
_ => panic!("All variant discriminants must be unsigned integers"),
},
_ => panic!("All variant discriminants must be unsigned integers"),
},
_ => panic!("Variant must have a discriminant"),
};
var_map.insert(value, ident);
numbers.push(value);
}
_ => panic!("Variant must have a discriminant"),
};
(ident.clone(), value)
})
.collect();

// sort the number discriminants and generate a list of valid ranges
numbers.sort_unstable();
let mut valid_ranges = Vec::new();
let mut start = numbers[0];
let mut end = start;
for &number in &numbers[1..] {
if number == end + 1 {
end = number;
} else {
valid_ranges.push(start..end + 1);
start = number;
end = start;
}
}
valid_ranges.push(start..end + 1);

Self {
name: input.ident.clone(),
valid_ranges,
}
}

fn valid_condition(&self) -> TokenStream2 {
let mut arms = Vec::new();
for range in &self.valid_ranges {
let (start, end) = (range.start, range.end);
if end - start == 1 {
arms.push(TokenStream2::from_str(&format!("number == {start}")).unwrap());
} else {
arms.push(
TokenStream2::from_str(&format!("({start}..{end}).contains(&number)")).unwrap(),
);
}
}
quote! { #(#arms) || * }
Self { name, numbers }
}

fn max_discriminant(&self) -> TokenStream2 {
let max_discriminant = self.valid_ranges.last().expect("invalid range").end - 1;
let max_discriminant = self.numbers.iter().map(|(_, num)| num).max().unwrap();
TokenStream2::from_str(&format!("{max_discriminant}")).unwrap()
}

fn valid_matches(&self) -> Vec<TokenStream2> {
self.numbers
.iter()
.map(|(ident, num)| {
TokenStream2::from_str(&format!("{num} => Ok(Self::{ident})")).unwrap()
})
.collect()
}

fn quote(&self, trait_name: &str, num_type: &str, const_name: &str) -> TokenStream2 {
let name = &self.name;
let max_discriminant = self.max_discriminant();
let valid_condition = self.valid_condition();
let valid_matches = self.valid_matches();

let trait_name = TokenStream2::from_str(trait_name).unwrap();
let num_type = TokenStream2::from_str(num_type).unwrap();
Expand All @@ -105,11 +82,9 @@ impl PacNumberEnum {

#[inline]
fn from_number(number: #num_type) -> Result<Self, #num_type> {
if #valid_condition {
// SAFETY: The number is valid for this enum
Ok(unsafe { core::mem::transmute::<#num_type, Self>(number) })
} else {
Err(number)
match number {
#(#valid_matches,)*
_ => Err(number),
}
}
}
Expand All @@ -125,20 +100,6 @@ impl PacNumberEnum {
/// The trait name must be one of `ExceptionNumber`, `InterruptNumber`, `PriorityNumber`, or `HartIdNumber`.
/// Marker traits `CoreInterruptNumber` and `ExternalInterruptNumber` cannot be implemented using this macro.
///
/// # Note
///
/// To implement number-to-enum operation, the macro works with ranges of valid discriminant numbers.
/// If the number is within any of the valid ranges, the number is transmuted to the enum variant.
/// In this way, the macro achieves better performance for enums with a large number of consecutive variants.
/// Thus, the enum must comply with the following requirements:
///
/// - All the enum variants must have a valid discriminant number (i.e., a number that is within the valid range of the enum).
/// - For the `ExceptionNumber`, `InterruptNumber`, and `HartIdNumber` traits, the enum must be annotated as `#[repr(u16)]`
/// - For the `PriorityNumber` trait, the enum must be annotated as `#[repr(u8)]`
///
/// If the enum does not meet these requirements, you will have to implement the traits manually (e.g., `riscv::mcause::Interrupt`).
/// For enums with a small number of consecutive variants, it might be better to implement the traits manually.
///
/// # Safety
///
/// The struct to be implemented must comply with the requirements of the specified trait.
Expand Down
9 changes: 0 additions & 9 deletions riscv-pac/tests/ui/fail_wrong_repr.rs

This file was deleted.

9 changes: 0 additions & 9 deletions riscv-pac/tests/ui/fail_wrong_repr.stderr

This file was deleted.

34 changes: 24 additions & 10 deletions riscv/src/register/mcause.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@ unsafe impl InterruptNumber for Interrupt {

#[inline]
fn from_number(value: usize) -> Result<Self, usize> {
if value > 11 || value % 2 == 0 {
Err(value)
} else {
// SAFETY: valid interrupt number
unsafe { Ok(core::mem::transmute::<usize, Self>(value)) }
match value {
1 => Ok(Self::SupervisorSoft),
3 => Ok(Self::MachineSoft),
5 => Ok(Self::SupervisorTimer),
7 => Ok(Self::MachineTimer),
9 => Ok(Self::SupervisorExternal),
11 => Ok(Self::MachineExternal),
_ => Err(value),
}
}
}
Expand Down Expand Up @@ -69,11 +72,22 @@ unsafe impl ExceptionNumber for Exception {

#[inline]
fn from_number(value: usize) -> Result<Self, usize> {
if value == 10 || value == 14 || value > 15 {
Err(value)
} else {
// SAFETY: valid exception number
unsafe { Ok(core::mem::transmute::<usize, Self>(value)) }
match value {
0 => Ok(Self::InstructionMisaligned),
1 => Ok(Self::InstructionFault),
2 => Ok(Self::IllegalInstruction),
3 => Ok(Self::Breakpoint),
4 => Ok(Self::LoadMisaligned),
5 => Ok(Self::LoadFault),
6 => Ok(Self::StoreMisaligned),
7 => Ok(Self::StoreFault),
8 => Ok(Self::UserEnvCall),
9 => Ok(Self::SupervisorEnvCall),
11 => Ok(Self::MachineEnvCall),
12 => Ok(Self::InstructionPageFault),
13 => Ok(Self::LoadPageFault),
15 => Ok(Self::StorePageFault),
_ => Err(value),
}
}
}
Expand Down
30 changes: 20 additions & 10 deletions riscv/src/register/scause.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ unsafe impl InterruptNumber for Interrupt {

#[inline]
fn from_number(value: usize) -> Result<Self, usize> {
if value == 1 || value == 5 || value == 9 {
// SAFETY: valid interrupt number
Ok(unsafe { core::mem::transmute::<usize, Self>(value) })
} else {
Err(value)
match value {
1 => Ok(Self::SupervisorSoft),
5 => Ok(Self::SupervisorTimer),
9 => Ok(Self::SupervisorExternal),
_ => Err(value),
}
}
}
Expand Down Expand Up @@ -65,11 +65,21 @@ unsafe impl ExceptionNumber for Exception {

#[inline]
fn from_number(value: usize) -> Result<Self, usize> {
if value == 10 || value == 11 || value == 14 || value > 15 {
Err(value)
} else {
// SAFETY: valid exception number
unsafe { Ok(core::mem::transmute::<usize, Self>(value)) }
match value {
0 => Ok(Self::InstructionMisaligned),
1 => Ok(Self::InstructionFault),
2 => Ok(Self::IllegalInstruction),
3 => Ok(Self::Breakpoint),
4 => Ok(Self::LoadMisaligned),
5 => Ok(Self::LoadFault),
6 => Ok(Self::StoreMisaligned),
7 => Ok(Self::StoreFault),
8 => Ok(Self::UserEnvCall),
9 => Ok(Self::SupervisorEnvCall),
12 => Ok(Self::InstructionPageFault),
13 => Ok(Self::LoadPageFault),
15 => Ok(Self::StorePageFault),
_ => Err(value),
}
}
}
Expand Down

0 comments on commit f9c94b5

Please sign in to comment.