Skip to content

Commit 45d4e5c

Browse files
authored
fiat-constify: add support for leveraging const_mut_refs (#1177)
The generated code is *almost* `const fn` friendly with `const_mut_refs` but needs `const fn` accessors for the newtypes it defines: mit-plv/fiat-crypto#1955 This updates `fiat-constify` to solve only adding such accessors and annotating the functions as `const fn`, greatly reducing its scope and making the resulting postprocessed code much closer to the original generated code. Additionally it removes the extra methods added to the code, namely `as_inner` and `into_inner`, in order to ensure that the output is as close to the original generated code as possible.
1 parent b89a4d9 commit 45d4e5c

File tree

3 files changed

+22
-284
lines changed

3 files changed

+22
-284
lines changed

fiat-constify/src/main.rs

+20-205
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,12 @@
55
66
#![allow(clippy::single_match, clippy::new_without_default)]
77

8-
mod outputs;
98
mod type_registry;
109

11-
use outputs::Outputs;
1210
use proc_macro2::{Punct, Spacing, Span};
13-
use quote::{TokenStreamExt, quote};
14-
use std::{collections::BTreeMap as Map, env, fs, ops::Deref};
15-
use syn::{
16-
Expr, ExprCall, ExprPath, ExprReference, Fields, FnArg, Ident, Item, ItemFn, Local, LocalInit,
17-
Meta, Pat, PatIdent, PatTuple, Path, Stmt, TypeReference, parse_quote,
18-
punctuated::Punctuated,
19-
token::{Const, Eq, Let, Paren, Semi},
20-
};
11+
use quote::TokenStreamExt;
12+
use std::{env, fs, ops::Deref};
13+
use syn::{FnArg, Ident, Item, ItemFn, Meta, Pat, Stmt, TypeReference, parse_quote, token::Const};
2114
use type_registry::TypeRegistry;
2215

2316
fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -32,19 +25,15 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
3225
ast.attrs.push(parse_quote! {
3326
#![allow(
3427
clippy::identity_op,
28+
clippy::too_many_arguments,
3529
clippy::unnecessary_cast,
36-
dead_code,
37-
rustdoc::broken_intra_doc_links,
38-
unused_assignments,
39-
unused_mut,
40-
unused_variables
30+
dead_code
4131
)]
4232
});
4333

4434
let mut type_registry = TypeRegistry::new();
4535

4636
// Iterate over functions, transforming them into `const fn`
47-
let mut const_deref = Vec::new();
4837
for item in &mut ast.items {
4938
match item {
5039
Item::Fn(func) => rewrite_fn_as_const(func, &type_registry),
@@ -67,32 +56,11 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
6756
});
6857
}
6958

70-
let ident = &ty.ident;
71-
if let Fields::Unnamed(unnamed) = &ty.fields {
72-
if let Some(unit) = unnamed.unnamed.first() {
73-
let unit_ty = &unit.ty;
74-
const_deref.push(parse_quote! {
75-
impl #ident {
76-
#[inline]
77-
pub const fn as_inner(&self) -> &#unit_ty {
78-
&self.0
79-
}
80-
81-
#[inline]
82-
pub const fn into_inner(self) -> #unit_ty {
83-
self.0
84-
}
85-
}
86-
});
87-
}
88-
}
89-
90-
type_registry.add_new_type(ty)
59+
type_registry.add_newtype(ty)
9160
}
9261
_ => (),
9362
}
9463
}
95-
ast.items.extend_from_slice(&const_deref);
9664

9765
println!(
9866
"//! fiat-crypto output postprocessed by fiat-constify: <https://github.com/rustcrypto/utils>"
@@ -116,8 +84,6 @@ fn rewrite_fn_as_const(func: &mut ItemFn, type_registry: &TypeRegistry) {
11684
func.sig.constness = Some(Const::default());
11785

11886
// Transform mutable arguments into return values.
119-
let mut inputs = Punctuated::new();
120-
let mut outputs = Outputs::new(type_registry);
12187
let mut stmts = Vec::<Stmt>::new();
12288

12389
for arg in &func.sig.inputs {
@@ -129,8 +95,17 @@ fn rewrite_fn_as_const(func: &mut ItemFn, type_registry: &TypeRegistry) {
12995
elem,
13096
..
13197
}) => {
132-
outputs.add(get_ident_from_pat(&t.pat), elem.deref().clone());
133-
continue;
98+
if matches!(elem.deref(), syn::Type::Path(_)) {
99+
// Generation of reborrows, LLVM should optimize this out, and it definitely
100+
// will if `#[repr(transparent)]` is used.
101+
let ty = type_registry::type_to_ident(elem).unwrap();
102+
let ident = get_ident_from_pat(&t.pat);
103+
if type_registry.is_newtype(ty) {
104+
stmts.push(parse_quote! {
105+
let #ident = &mut #ident.0;
106+
});
107+
}
108+
}
134109
}
135110
syn::Type::Reference(TypeReference {
136111
mutability: None,
@@ -141,177 +116,17 @@ fn rewrite_fn_as_const(func: &mut ItemFn, type_registry: &TypeRegistry) {
141116
// will if `#[repr(transparent)]` is used.
142117
let ty = type_registry::type_to_ident(elem).unwrap();
143118
let ident = get_ident_from_pat(&t.pat);
144-
if outputs.type_registry().is_new_type(ty) {
119+
if type_registry.is_newtype(ty) {
145120
stmts.push(parse_quote! {
146-
let #ident = #ident.as_inner();
121+
let #ident = &#ident.0;
147122
});
148123
}
149124
}
150125
_ => (),
151126
}
152127
}
153-
154-
// If the argument wasn't a mutable reference, add it as an input.
155-
inputs.push(arg.clone());
156128
}
157129

158-
// Replace inputs with ones where the mutable references have been filtered out
159-
func.sig.inputs = inputs;
160-
func.sig.output = outputs.to_return_type();
161-
stmts.extend(rewrite_fn_body(&func.block.stmts, &outputs));
130+
stmts.extend(func.block.stmts.clone());
162131
func.block.stmts = stmts;
163132
}
164-
165-
/// Rewrite the function body, adding let bindings with `Default::default()`
166-
/// values for outputs, removing mutable references, and adding a return
167-
/// value/tuple.
168-
fn rewrite_fn_body(stmts: &[Stmt], outputs: &Outputs) -> Vec<Stmt> {
169-
let mut ident_assignments: Map<&Ident, Vec<&Expr>> = Map::new();
170-
let mut rewritten = Vec::new();
171-
172-
for stmt in stmts {
173-
if let Stmt::Expr(Expr::Assign(assignment), Some(_)) = stmt {
174-
let lhs_path = match assignment.left.as_ref() {
175-
Expr::Unary(lhs) => {
176-
if let Expr::Path(exprpath) = lhs.expr.as_ref() {
177-
Some(exprpath)
178-
} else {
179-
panic!("All unary exprpaths should have the LHS as the path");
180-
}
181-
}
182-
Expr::Index(lhs) => {
183-
if let Expr::Path(exprpath) = lhs.expr.as_ref() {
184-
Some(exprpath)
185-
} else {
186-
panic!("All unary exprpaths should have the LHS as the path");
187-
}
188-
}
189-
Expr::Call(expr) => {
190-
rewritten.push(Stmt::Local(rewrite_fn_call(expr.clone())));
191-
None
192-
}
193-
_ => None,
194-
};
195-
if let Some(lhs_path) = lhs_path {
196-
ident_assignments
197-
.entry(Path::get_ident(&lhs_path.path).unwrap())
198-
.or_default()
199-
.push(&assignment.right);
200-
}
201-
} else if let Stmt::Expr(Expr::Call(expr), Some(_)) = stmt {
202-
rewritten.push(Stmt::Local(rewrite_fn_call(expr.clone())));
203-
} else if let Stmt::Local(Local {
204-
pat: Pat::Type(pat),
205-
..
206-
}) = stmt
207-
{
208-
let unboxed = pat.pat.as_ref();
209-
if let Pat::Ident(PatIdent {
210-
mutability: Some(_),
211-
..
212-
}) = unboxed
213-
{
214-
// This is a mut var, in the case of fiat-crypto transformation dead code
215-
} else {
216-
rewritten.push(stmt.clone());
217-
}
218-
} else {
219-
rewritten.push(stmt.clone());
220-
}
221-
}
222-
223-
let mut asts = Vec::new();
224-
for (ident, ty) in outputs.ident_type_pairs() {
225-
let value = ident_assignments.get(ident).unwrap();
226-
let type_prefix = match type_registry::type_to_ident(ty) {
227-
Some(ident) if outputs.type_registry().is_new_type(ident) => Some(ty),
228-
_ => None,
229-
};
230-
231-
let ast = match (type_prefix, value.len()) {
232-
(None, 1) => {
233-
let first = value.first().unwrap();
234-
quote!(#first)
235-
}
236-
(Some(prefix), 1) => {
237-
let first = value.first().unwrap();
238-
quote!(#prefix(#first))
239-
}
240-
241-
(None, _) => {
242-
quote!([#(#value),*])
243-
}
244-
(Some(prefix), _) => {
245-
quote!(#prefix([#(#value),*]))
246-
}
247-
};
248-
asts.push(ast);
249-
}
250-
251-
let expr: Expr = parse_quote! {
252-
(#(#asts),*)
253-
};
254-
255-
rewritten.push(Stmt::Expr(expr, None));
256-
rewritten
257-
}
258-
259-
/// Rewrite a function call, removing the mutable reference arguments and
260-
/// let-binding return values for them instead.
261-
fn rewrite_fn_call(mut call: ExprCall) -> Local {
262-
let mut args = Punctuated::new();
263-
let mut output = Punctuated::new();
264-
265-
for arg in &call.args {
266-
if let Expr::Reference(ExprReference {
267-
mutability: Some(_),
268-
expr,
269-
..
270-
}) = arg
271-
{
272-
match expr.deref() {
273-
Expr::Path(ExprPath {
274-
path: Path { segments, .. },
275-
..
276-
}) => {
277-
assert_eq!(segments.len(), 1, "expected only one segment in fn arg");
278-
let ident = segments.first().unwrap().ident.clone();
279-
280-
output.push(Pat::Ident(PatIdent {
281-
attrs: Vec::new(),
282-
by_ref: None,
283-
mutability: None,
284-
ident,
285-
subpat: None,
286-
}));
287-
}
288-
other => panic!("unexpected expr in fn arg: {:?}", other),
289-
}
290-
291-
continue;
292-
}
293-
294-
args.push(arg.clone());
295-
}
296-
297-
// Overwrite call arguments with the ones that aren't mutable references
298-
call.args = args;
299-
300-
let pat = Pat::Tuple(PatTuple {
301-
attrs: Vec::new(),
302-
paren_token: Paren::default(),
303-
elems: output,
304-
});
305-
306-
Local {
307-
attrs: Vec::new(),
308-
let_token: Let::default(),
309-
pat,
310-
init: Some(LocalInit {
311-
eq_token: Eq::default(),
312-
expr: Box::new(Expr::Call(call)),
313-
diverge: None,
314-
}),
315-
semi_token: Semi::default(),
316-
}
317-
}

fiat-constify/src/outputs.rs

-77
This file was deleted.

fiat-constify/src/type_registry.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ impl TypeRegistry {
1919
}
2020

2121
/// Add a type which is a new type to the type registry.
22-
pub fn add_new_type(&mut self, item_struct: &ItemStruct) {
22+
pub fn add_newtype(&mut self, item_struct: &ItemStruct) {
2323
if self
2424
.0
2525
.insert(item_struct.ident.clone(), Type::NewType)
@@ -48,7 +48,7 @@ impl TypeRegistry {
4848
}
4949

5050
#[inline]
51-
pub fn is_new_type(&self, ident: &syn::Ident) -> bool {
51+
pub fn is_newtype(&self, ident: &syn::Ident) -> bool {
5252
matches!(self.get(ident), Some(Type::NewType))
5353
}
5454
}

0 commit comments

Comments
 (0)