|
| 1 | +// SPDX-License-Identifier: GPL-2.0 |
| 2 | + |
| 3 | +use proc_macro::{Delimiter, Group, TokenStream, TokenTree}; |
| 4 | +use std::collections::HashSet; |
| 5 | +use std::fmt::Write; |
| 6 | + |
| 7 | +pub(crate) fn vtable(_attr: TokenStream, ts: TokenStream) -> TokenStream { |
| 8 | + let mut tokens: Vec<_> = ts.into_iter().collect(); |
| 9 | + |
| 10 | + // Scan for the `trait` or `impl` keyword |
| 11 | + let is_trait = tokens.iter().find_map(|token| { |
| 12 | + match token { |
| 13 | + TokenTree::Ident(ident) => match ident.to_string().as_str() { |
| 14 | + "trait" => Some(true), |
| 15 | + "impl" => Some(false), |
| 16 | + _ => None, |
| 17 | + }, |
| 18 | + _ => None, |
| 19 | + } |
| 20 | + }).expect("#[vtable] attribute should only be applied to trait or impl block"); |
| 21 | + |
| 22 | + // Retrieve the main body. The main body should be the last token tree. |
| 23 | + let body = match tokens.pop() { |
| 24 | + Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group, |
| 25 | + _ => panic!("cannot locate main body of trait or impl block"), |
| 26 | + }; |
| 27 | + |
| 28 | + let mut body_it = body.stream().into_iter(); |
| 29 | + let mut functions = Vec::new(); |
| 30 | + let mut consts = HashSet::new(); |
| 31 | + while let Some(token) = body_it.next() { |
| 32 | + match token { |
| 33 | + TokenTree::Ident(ident) if ident.to_string() == "fn" => { |
| 34 | + let fn_name = match body_it.next() { |
| 35 | + Some(TokenTree::Ident(ident)) => ident.to_string(), |
| 36 | + // Possibly we've encountered a fn pointer type instead. |
| 37 | + _ => continue, |
| 38 | + }; |
| 39 | + functions.push(fn_name); |
| 40 | + } |
| 41 | + TokenTree::Ident(ident) if ident.to_string() == "const" => { |
| 42 | + let const_name = match body_it.next() { |
| 43 | + Some(TokenTree::Ident(ident)) => ident.to_string(), |
| 44 | + // Possibly we've encountered an inline const block instead. |
| 45 | + _ => continue, |
| 46 | + }; |
| 47 | + consts.insert(const_name); |
| 48 | + } |
| 49 | + _ => (), |
| 50 | + } |
| 51 | + } |
| 52 | + |
| 53 | + let mut const_items; |
| 54 | + if is_trait { |
| 55 | + const_items = "/// A marker to prevent implementors from forgetting to use [`#[vtable]`](vtable) attribute when implementing this trait. |
| 56 | + const USE_VTABLE_ATTR: ();".to_owned(); |
| 57 | + |
| 58 | + for f in functions { |
| 59 | + let gen_const_name = format!("HAS_{}", f.to_uppercase()); |
| 60 | + // Skip if it's declared already -- this allows user override. |
| 61 | + if consts.contains(&gen_const_name) { |
| 62 | + continue; |
| 63 | + } |
| 64 | + // We don't know on the implementation-site whether a method is required or provided |
| 65 | + // so we have to generate a const for all methods. |
| 66 | + write!( |
| 67 | + const_items, |
| 68 | + "/// Indicates if the `{f}` method is overriden by the implementor. |
| 69 | + const {gen_const_name}: bool = false;", |
| 70 | + ) |
| 71 | + .unwrap(); |
| 72 | + } |
| 73 | + } else { |
| 74 | + const_items = "const USE_VTABLE_ATTR: () = ();".to_owned(); |
| 75 | + |
| 76 | + for f in functions { |
| 77 | + let gen_const_name = format!("HAS_{}", f.to_uppercase()); |
| 78 | + if consts.contains(&gen_const_name) { |
| 79 | + continue; |
| 80 | + } |
| 81 | + write!(const_items, "const {gen_const_name}: bool = true;").unwrap(); |
| 82 | + } |
| 83 | + } |
| 84 | + |
| 85 | + let new_body = vec![const_items.parse().unwrap(), body.stream()] |
| 86 | + .into_iter() |
| 87 | + .collect(); |
| 88 | + tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, new_body))); |
| 89 | + tokens.into_iter().collect() |
| 90 | +} |
0 commit comments