Skip to content

Commit ba5b4eb

Browse files
derive(SmartPointer): rewrite bounds in where and generic bounds
1 parent 366e558 commit ba5b4eb

File tree

2 files changed

+286
-11
lines changed

2 files changed

+286
-11
lines changed

compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs

+208-11
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,29 @@
11
use std::mem::swap;
22

33
use ast::HasAttrs;
4+
use rustc_ast::mut_visit::MutVisitor;
45
use rustc_ast::{
56
self as ast, GenericArg, GenericBound, GenericParamKind, ItemKind, MetaItem,
67
TraitBoundModifiers, VariantData,
78
};
89
use rustc_attr as attr;
10+
use rustc_data_structures::flat_map_in_place::FlatMapInPlace;
911
use rustc_expand::base::{Annotatable, ExtCtxt};
1012
use rustc_span::symbol::{sym, Ident};
11-
use rustc_span::Span;
13+
use rustc_span::{Span, Symbol};
1214
use smallvec::{smallvec, SmallVec};
1315
use thin_vec::{thin_vec, ThinVec};
1416

17+
type AstTy = ast::ptr::P<ast::Ty>;
18+
1519
macro_rules! path {
1620
($span:expr, $($part:ident)::*) => { vec![$(Ident::new(sym::$part, $span),)*] }
1721
}
1822

23+
macro_rules! symbols {
24+
($($part:ident)::*) => { [$(sym::$part),*] }
25+
}
26+
1927
pub fn expand_deriving_smart_ptr(
2028
cx: &ExtCtxt<'_>,
2129
span: Span,
@@ -143,31 +151,220 @@ pub fn expand_deriving_smart_ptr(
143151

144152
// Find the `#[pointee]` parameter and add an `Unsize<__S>` bound to it.
145153
let mut impl_generics = generics.clone();
154+
let pointee_ty_ident = generics.params[pointee_param_idx].ident;
155+
let mut self_bounds;
146156
{
147157
let p = &mut impl_generics.params[pointee_param_idx];
158+
self_bounds = p.bounds.clone();
148159
let arg = GenericArg::Type(s_ty.clone());
149160
let unsize = cx.path_all(span, true, path!(span, core::marker::Unsize), vec![arg]);
150161
p.bounds.push(cx.trait_bound(unsize, false));
151162
let mut attrs = thin_vec![];
152163
swap(&mut p.attrs, &mut attrs);
153164
p.attrs = attrs.into_iter().filter(|attr| !attr.has_name(sym::pointee)).collect();
154165
}
166+
// We should not set default values to constant generic parameters
167+
// and commute bounds that indirectly involves `#[pointee]`.
168+
for (params, orig_params) in impl_generics.params[pointee_param_idx + 1..]
169+
.iter_mut()
170+
.zip(&generics.params[pointee_param_idx + 1..])
171+
{
172+
if let ast::GenericParamKind::Const { default, .. } = &mut params.kind {
173+
*default = None;
174+
}
175+
for bound in &orig_params.bounds {
176+
let mut bound = bound.clone();
177+
let mut substitution = TypeSubstitution {
178+
from_name: pointee_ty_ident.name,
179+
to_ty: &s_ty,
180+
rewritten: false,
181+
};
182+
substitution.visit_param_bound(&mut bound);
183+
if substitution.rewritten {
184+
params.bounds.push(bound);
185+
}
186+
}
187+
}
155188

156189
// Add the `__S: ?Sized` extra parameter to the impl block.
190+
// We should also commute the bounds from `#[pointee]` to `__S` as required by `Unsize<__S>`.
157191
let sized = cx.path_global(span, path!(span, core::marker::Sized));
158-
let bound = GenericBound::Trait(
159-
cx.poly_trait_ref(span, sized),
160-
TraitBoundModifiers {
161-
polarity: ast::BoundPolarity::Maybe(span),
162-
constness: ast::BoundConstness::Never,
163-
asyncness: ast::BoundAsyncness::Normal,
164-
},
165-
);
166-
let extra_param = cx.typaram(span, Ident::new(sym::__S, span), vec![bound], None);
167-
impl_generics.params.push(extra_param);
192+
if self_bounds.iter().all(|bound| {
193+
if let GenericBound::Trait(
194+
trait_ref,
195+
TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. },
196+
) = bound
197+
{
198+
!is_sized_marker(&trait_ref.trait_ref.path)
199+
} else {
200+
false
201+
}
202+
}) {
203+
self_bounds.push(GenericBound::Trait(
204+
cx.poly_trait_ref(span, sized),
205+
TraitBoundModifiers {
206+
polarity: ast::BoundPolarity::Maybe(span),
207+
constness: ast::BoundConstness::Never,
208+
asyncness: ast::BoundAsyncness::Normal,
209+
},
210+
));
211+
}
212+
{
213+
let mut substitution =
214+
TypeSubstitution { from_name: pointee_ty_ident.name, to_ty: &s_ty, rewritten: false };
215+
for bound in &mut self_bounds {
216+
substitution.visit_param_bound(bound);
217+
}
218+
}
219+
220+
// We should also commute the where bounds from `#[pointee]` to `__S`
221+
// as well as any bound that indirectly involves the `#[pointee]` type.
222+
for bound in &generics.where_clause.predicates {
223+
if let ast::WherePredicate::BoundPredicate(bound) = bound {
224+
let bound_on_pointee = bound
225+
.bounded_ty
226+
.kind
227+
.is_simple_path()
228+
.map_or(false, |name| name == pointee_ty_ident.name);
229+
230+
let bounds: Vec<_> = bound
231+
.bounds
232+
.iter()
233+
.filter(|bound| {
234+
if let GenericBound::Trait(
235+
trait_ref,
236+
TraitBoundModifiers { polarity: ast::BoundPolarity::Maybe(_), .. },
237+
) = bound
238+
{
239+
!bound_on_pointee || !is_sized_marker(&trait_ref.trait_ref.path)
240+
} else {
241+
true
242+
}
243+
})
244+
.cloned()
245+
.collect();
246+
let mut substitution = TypeSubstitution {
247+
from_name: pointee_ty_ident.name,
248+
to_ty: &s_ty,
249+
rewritten: bounds.len() != bound.bounds.len(),
250+
};
251+
let mut predicate = ast::WherePredicate::BoundPredicate(ast::WhereBoundPredicate {
252+
span: bound.span,
253+
bound_generic_params: bound.bound_generic_params.clone(),
254+
bounded_ty: bound.bounded_ty.clone(),
255+
bounds,
256+
});
257+
substitution.visit_where_predicate(&mut predicate);
258+
if substitution.rewritten {
259+
impl_generics.where_clause.predicates.push(predicate);
260+
}
261+
}
262+
}
263+
264+
let extra_param = cx.typaram(span, Ident::new(sym::__S, span), self_bounds, None);
265+
impl_generics.params.insert(pointee_param_idx + 1, extra_param);
168266

169267
// Add the impl blocks for `DispatchFromDyn` and `CoerceUnsized`.
170268
let gen_args = vec![GenericArg::Type(alt_self_type.clone())];
171269
add_impl_block(impl_generics.clone(), sym::DispatchFromDyn, gen_args.clone());
172270
add_impl_block(impl_generics.clone(), sym::CoerceUnsized, gen_args.clone());
173271
}
272+
273+
fn is_sized_marker(path: &ast::Path) -> bool {
274+
const CORE_UNSIZE: [Symbol; 3] = symbols!(core::marker::Sized);
275+
const STD_UNSIZE: [Symbol; 3] = symbols!(std::marker::Sized);
276+
if path.segments.len() == 3 {
277+
path.segments.iter().zip(CORE_UNSIZE).all(|(segment, symbol)| segment.ident.name == symbol)
278+
|| path
279+
.segments
280+
.iter()
281+
.zip(STD_UNSIZE)
282+
.all(|(segment, symbol)| segment.ident.name == symbol)
283+
} else {
284+
*path == sym::Sized
285+
}
286+
}
287+
288+
struct TypeSubstitution<'a> {
289+
from_name: Symbol,
290+
to_ty: &'a AstTy,
291+
rewritten: bool,
292+
}
293+
294+
impl<'a> ast::mut_visit::MutVisitor for TypeSubstitution<'a> {
295+
fn visit_ty(&mut self, ty: &mut AstTy) {
296+
if let Some(name) = ty.kind.is_simple_path()
297+
&& name == self.from_name
298+
{
299+
*ty = self.to_ty.clone();
300+
self.rewritten = true;
301+
return;
302+
}
303+
match &mut ty.kind {
304+
ast::TyKind::Slice(_)
305+
| ast::TyKind::Array(_, _)
306+
| ast::TyKind::Ptr(_)
307+
| ast::TyKind::Ref(_, _)
308+
| ast::TyKind::BareFn(_)
309+
| ast::TyKind::Never
310+
| ast::TyKind::Tup(_)
311+
| ast::TyKind::AnonStruct(_, _)
312+
| ast::TyKind::AnonUnion(_, _)
313+
| ast::TyKind::Path(_, _)
314+
| ast::TyKind::TraitObject(_, _)
315+
| ast::TyKind::ImplTrait(_, _)
316+
| ast::TyKind::Paren(_)
317+
| ast::TyKind::Typeof(_)
318+
| ast::TyKind::Infer
319+
| ast::TyKind::MacCall(_)
320+
| ast::TyKind::Pat(_, _) => ast::mut_visit::noop_visit_ty(ty, self),
321+
ast::TyKind::ImplicitSelf
322+
| ast::TyKind::CVarArgs
323+
| ast::TyKind::Dummy
324+
| ast::TyKind::Err(_) => {}
325+
}
326+
}
327+
328+
fn visit_param_bound(&mut self, bound: &mut GenericBound) {
329+
match bound {
330+
GenericBound::Trait(trait_ref, _) => {
331+
if trait_ref
332+
.bound_generic_params
333+
.iter()
334+
.any(|param| param.ident.name == self.from_name)
335+
{
336+
return;
337+
}
338+
self.visit_poly_trait_ref(trait_ref);
339+
}
340+
341+
GenericBound::Use(args, _span) => {
342+
for arg in args {
343+
self.visit_precise_capturing_arg(arg);
344+
}
345+
}
346+
GenericBound::Outlives(_) => {}
347+
}
348+
}
349+
350+
fn visit_where_predicate(&mut self, where_predicate: &mut ast::WherePredicate) {
351+
match where_predicate {
352+
rustc_ast::WherePredicate::BoundPredicate(bound) => {
353+
if bound.bound_generic_params.iter().any(|param| param.ident.name == self.from_name)
354+
{
355+
// Name is shadowed so we must skip the rest
356+
return;
357+
}
358+
bound
359+
.bound_generic_params
360+
.flat_map_in_place(|param| self.flat_map_generic_param(param));
361+
self.visit_ty(&mut bound.bounded_ty);
362+
for bound in &mut bound.bounds {
363+
self.visit_param_bound(bound)
364+
}
365+
}
366+
rustc_ast::WherePredicate::RegionPredicate(_)
367+
| rustc_ast::WherePredicate::EqPredicate(_) => {}
368+
}
369+
}
370+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//@ check-pass
2+
3+
#![feature(derive_smart_pointer)]
4+
5+
#[derive(core::marker::SmartPointer)]
6+
#[repr(transparent)]
7+
pub struct Ptr<'a, #[pointee] T: OnDrop + ?Sized, X> {
8+
data: &'a mut T,
9+
x: core::marker::PhantomData<X>,
10+
}
11+
12+
pub trait OnDrop {
13+
fn on_drop(&mut self);
14+
}
15+
16+
#[derive(core::marker::SmartPointer)]
17+
#[repr(transparent)]
18+
pub struct Ptr2<'a, #[pointee] T: ?Sized, X>
19+
where
20+
T: OnDrop,
21+
{
22+
data: &'a mut T,
23+
x: core::marker::PhantomData<X>,
24+
}
25+
26+
pub trait MyTrait<T: ?Sized> {}
27+
28+
#[derive(core::marker::SmartPointer)]
29+
#[repr(transparent)]
30+
pub struct Ptr3<'a, #[pointee] T: ?Sized, X>
31+
where
32+
T: MyTrait<T>,
33+
{
34+
data: &'a mut T,
35+
x: core::marker::PhantomData<X>,
36+
}
37+
38+
#[derive(core::marker::SmartPointer)]
39+
#[repr(transparent)]
40+
pub struct Ptr4<'a, #[pointee] T: MyTrait<T> + ?Sized, X> {
41+
data: &'a mut T,
42+
x: core::marker::PhantomData<X>,
43+
}
44+
45+
#[derive(core::marker::SmartPointer)]
46+
#[repr(transparent)]
47+
pub struct Ptr5<'a, #[pointee] T: ?Sized, X>
48+
where
49+
Ptr5Companion<T>: MyTrait<T>,
50+
Ptr5Companion2: MyTrait<T>,
51+
{
52+
data: &'a mut T,
53+
x: core::marker::PhantomData<X>,
54+
}
55+
56+
pub struct Ptr5Companion<T: ?Sized>(core::marker::PhantomData<T>);
57+
pub struct Ptr5Companion2;
58+
59+
#[derive(core::marker::SmartPointer)]
60+
#[repr(transparent)]
61+
pub struct Ptr6<'a, #[pointee] T: ?Sized, X: MyTrait<T>> {
62+
data: &'a mut T,
63+
x: core::marker::PhantomData<X>,
64+
}
65+
66+
// a reduced example from https://lore.kernel.org/all/[email protected]/
67+
#[repr(transparent)]
68+
#[derive(core::marker::SmartPointer)]
69+
pub struct ListArc<#[pointee] T, const ID: u64 = 0>
70+
where
71+
T: ListArcSafe<ID> + ?Sized,
72+
{
73+
arc: *const T,
74+
}
75+
76+
pub trait ListArcSafe<const ID: u64> {}
77+
78+
fn main() {}

0 commit comments

Comments
 (0)