Skip to content

Commit 2a94e1c

Browse files
Rollup merge of rust-lang#132911 - compiler-errors:async-fn-sugar, r=fmease
Pretty print async fn sugar in opaques and trait bounds sudo r? fmease
2 parents d2fb8b5 + 4c53ad5 commit 2a94e1c

File tree

8 files changed

+134
-138
lines changed

8 files changed

+134
-138
lines changed

compiler/rustc_middle/src/middle/lang_items.rs

+11
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,17 @@ impl<'tcx> TyCtxt<'tcx> {
6868
}
6969
}
7070

71+
/// Given a [`ty::ClosureKind`], get the [`DefId`] of its corresponding `Fn`-family
72+
/// trait, if it is defined.
73+
pub fn async_fn_trait_kind_to_def_id(self, kind: ty::ClosureKind) -> Option<DefId> {
74+
let items = self.lang_items();
75+
match kind {
76+
ty::ClosureKind::Fn => items.async_fn_trait(),
77+
ty::ClosureKind::FnMut => items.async_fn_mut_trait(),
78+
ty::ClosureKind::FnOnce => items.async_fn_once_trait(),
79+
}
80+
}
81+
7182
/// Returns `true` if `id` is a `DefId` of [`Fn`], [`FnMut`] or [`FnOnce`] traits.
7283
pub fn is_fn_trait(self, id: DefId) -> bool {
7384
self.fn_trait_kind_from_def_id(id).is_some()

compiler/rustc_middle/src/ty/print/pretty.rs

+85-118
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ use rustc_hir::definitions::{DefKey, DefPathDataName};
1616
use rustc_macros::{Lift, extension};
1717
use rustc_session::Limit;
1818
use rustc_session::cstore::{ExternCrate, ExternCrateSource};
19-
use rustc_span::FileNameDisplayPreference;
2019
use rustc_span::symbol::{Ident, Symbol, kw};
20+
use rustc_span::{FileNameDisplayPreference, sym};
2121
use rustc_type_ir::{Upcast as _, elaborate};
2222
use smallvec::SmallVec;
2323

@@ -26,8 +26,8 @@ use super::*;
2626
use crate::mir::interpret::{AllocRange, GlobalAlloc, Pointer, Provenance, Scalar};
2727
use crate::query::{IntoQueryParam, Providers};
2828
use crate::ty::{
29-
ConstInt, Expr, GenericArgKind, ParamConst, ScalarInt, Term, TermKind, TypeFoldable,
30-
TypeSuperFoldable, TypeSuperVisitable, TypeVisitable, TypeVisitableExt,
29+
ConstInt, Expr, GenericArgKind, ParamConst, ScalarInt, Term, TermKind, TraitPredicate,
30+
TypeFoldable, TypeSuperFoldable, TypeSuperVisitable, TypeVisitable, TypeVisitableExt,
3131
};
3232

3333
macro_rules! p {
@@ -993,10 +993,8 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
993993

994994
match bound_predicate.skip_binder() {
995995
ty::ClauseKind::Trait(pred) => {
996-
let trait_ref = bound_predicate.rebind(pred.trait_ref);
997-
998996
// Don't print `+ Sized`, but rather `+ ?Sized` if absent.
999-
if tcx.is_lang_item(trait_ref.def_id(), LangItem::Sized) {
997+
if tcx.is_lang_item(pred.def_id(), LangItem::Sized) {
1000998
match pred.polarity {
1001999
ty::PredicatePolarity::Positive => {
10021000
has_sized_bound = true;
@@ -1007,24 +1005,22 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
10071005
}
10081006

10091007
self.insert_trait_and_projection(
1010-
trait_ref,
1011-
pred.polarity,
1008+
bound_predicate.rebind(pred),
10121009
None,
10131010
&mut traits,
10141011
&mut fn_traits,
10151012
);
10161013
}
10171014
ty::ClauseKind::Projection(pred) => {
1018-
let proj_ref = bound_predicate.rebind(pred);
1019-
let trait_ref = proj_ref.required_poly_trait_ref(tcx);
1020-
1021-
// Projection type entry -- the def-id for naming, and the ty.
1022-
let proj_ty = (proj_ref.projection_def_id(), proj_ref.term());
1015+
let proj = bound_predicate.rebind(pred);
1016+
let trait_ref = proj.map_bound(|proj| TraitPredicate {
1017+
trait_ref: proj.projection_term.trait_ref(tcx),
1018+
polarity: ty::PredicatePolarity::Positive,
1019+
});
10231020

10241021
self.insert_trait_and_projection(
10251022
trait_ref,
1026-
ty::PredicatePolarity::Positive,
1027-
Some(proj_ty),
1023+
Some((proj.projection_def_id(), proj.term())),
10281024
&mut traits,
10291025
&mut fn_traits,
10301026
);
@@ -1042,88 +1038,66 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
10421038
// Insert parenthesis around (Fn(A, B) -> C) if the opaque ty has more than one other trait
10431039
let paren_needed = fn_traits.len() > 1 || traits.len() > 0 || !has_sized_bound;
10441040

1045-
for (fn_once_trait_ref, entry) in fn_traits {
1041+
for ((bound_args, is_async), entry) in fn_traits {
10461042
write!(self, "{}", if first { "" } else { " + " })?;
10471043
write!(self, "{}", if paren_needed { "(" } else { "" })?;
10481044

1049-
self.wrap_binder(&fn_once_trait_ref, |trait_ref, cx| {
1050-
define_scoped_cx!(cx);
1051-
// Get the (single) generic ty (the args) of this FnOnce trait ref.
1052-
let generics = tcx.generics_of(trait_ref.def_id);
1053-
let own_args = generics.own_args_no_defaults(tcx, trait_ref.args);
1054-
1055-
match (entry.return_ty, own_args[0].expect_ty()) {
1056-
// We can only print `impl Fn() -> ()` if we have a tuple of args and we recorded
1057-
// a return type.
1058-
(Some(return_ty), arg_tys) if matches!(arg_tys.kind(), ty::Tuple(_)) => {
1059-
let name = if entry.fn_trait_ref.is_some() {
1060-
"Fn"
1061-
} else if entry.fn_mut_trait_ref.is_some() {
1062-
"FnMut"
1063-
} else {
1064-
"FnOnce"
1065-
};
1066-
1067-
p!(write("{}(", name));
1045+
let trait_def_id = if is_async {
1046+
tcx.async_fn_trait_kind_to_def_id(entry.kind).expect("expected AsyncFn lang items")
1047+
} else {
1048+
tcx.fn_trait_kind_to_def_id(entry.kind).expect("expected Fn lang items")
1049+
};
10681050

1069-
for (idx, ty) in arg_tys.tuple_fields().iter().enumerate() {
1070-
if idx > 0 {
1071-
p!(", ");
1072-
}
1073-
p!(print(ty));
1074-
}
1051+
if let Some(return_ty) = entry.return_ty {
1052+
self.wrap_binder(&bound_args, |args, cx| {
1053+
define_scoped_cx!(cx);
1054+
p!(write("{}", tcx.item_name(trait_def_id)));
1055+
p!("(");
10751056

1076-
p!(")");
1077-
if let Some(ty) = return_ty.skip_binder().as_type() {
1078-
if !ty.is_unit() {
1079-
p!(" -> ", print(return_ty));
1080-
}
1057+
for (idx, ty) in args.iter().enumerate() {
1058+
if idx > 0 {
1059+
p!(", ");
10811060
}
1082-
p!(write("{}", if paren_needed { ")" } else { "" }));
1083-
1084-
first = false;
1061+
p!(print(ty));
10851062
}
1086-
// If we got here, we can't print as a `impl Fn(A, B) -> C`. Just record the
1087-
// trait_refs we collected in the OpaqueFnEntry as normal trait refs.
1088-
_ => {
1089-
if entry.has_fn_once {
1090-
traits
1091-
.entry((fn_once_trait_ref, ty::PredicatePolarity::Positive))
1092-
.or_default()
1093-
.extend(
1094-
// Group the return ty with its def id, if we had one.
1095-
entry.return_ty.map(|ty| {
1096-
(tcx.require_lang_item(LangItem::FnOnceOutput, None), ty)
1097-
}),
1098-
);
1099-
}
1100-
if let Some(trait_ref) = entry.fn_mut_trait_ref {
1101-
traits.entry((trait_ref, ty::PredicatePolarity::Positive)).or_default();
1102-
}
1103-
if let Some(trait_ref) = entry.fn_trait_ref {
1104-
traits.entry((trait_ref, ty::PredicatePolarity::Positive)).or_default();
1063+
1064+
p!(")");
1065+
if let Some(ty) = return_ty.skip_binder().as_type() {
1066+
if !ty.is_unit() {
1067+
p!(" -> ", print(return_ty));
11051068
}
11061069
}
1107-
}
1070+
p!(write("{}", if paren_needed { ")" } else { "" }));
11081071

1109-
Ok(())
1110-
})?;
1072+
first = false;
1073+
Ok(())
1074+
})?;
1075+
} else {
1076+
// Otherwise, render this like a regular trait.
1077+
traits.insert(
1078+
bound_args.map_bound(|args| ty::TraitPredicate {
1079+
polarity: ty::PredicatePolarity::Positive,
1080+
trait_ref: ty::TraitRef::new(tcx, trait_def_id, [Ty::new_tup(tcx, args)]),
1081+
}),
1082+
FxIndexMap::default(),
1083+
);
1084+
}
11111085
}
11121086

11131087
// Print the rest of the trait types (that aren't Fn* family of traits)
1114-
for ((trait_ref, polarity), assoc_items) in traits {
1088+
for (trait_pred, assoc_items) in traits {
11151089
write!(self, "{}", if first { "" } else { " + " })?;
11161090

1117-
self.wrap_binder(&trait_ref, |trait_ref, cx| {
1091+
self.wrap_binder(&trait_pred, |trait_pred, cx| {
11181092
define_scoped_cx!(cx);
11191093

1120-
if polarity == ty::PredicatePolarity::Negative {
1094+
if trait_pred.polarity == ty::PredicatePolarity::Negative {
11211095
p!("!");
11221096
}
1123-
p!(print(trait_ref.print_only_trait_name()));
1097+
p!(print(trait_pred.trait_ref.print_only_trait_name()));
11241098

1125-
let generics = tcx.generics_of(trait_ref.def_id);
1126-
let own_args = generics.own_args_no_defaults(tcx, trait_ref.args);
1099+
let generics = tcx.generics_of(trait_pred.def_id());
1100+
let own_args = generics.own_args_no_defaults(tcx, trait_pred.trait_ref.args);
11271101

11281102
if !own_args.is_empty() || !assoc_items.is_empty() {
11291103
let mut first = true;
@@ -1230,51 +1204,48 @@ pub trait PrettyPrinter<'tcx>: Printer<'tcx> + fmt::Write {
12301204
/// traits map or fn_traits map, depending on if the trait is in the Fn* family of traits.
12311205
fn insert_trait_and_projection(
12321206
&mut self,
1233-
trait_ref: ty::PolyTraitRef<'tcx>,
1234-
polarity: ty::PredicatePolarity,
1207+
trait_pred: ty::PolyTraitPredicate<'tcx>,
12351208
proj_ty: Option<(DefId, ty::Binder<'tcx, Term<'tcx>>)>,
12361209
traits: &mut FxIndexMap<
1237-
(ty::PolyTraitRef<'tcx>, ty::PredicatePolarity),
1210+
ty::PolyTraitPredicate<'tcx>,
12381211
FxIndexMap<DefId, ty::Binder<'tcx, Term<'tcx>>>,
12391212
>,
1240-
fn_traits: &mut FxIndexMap<ty::PolyTraitRef<'tcx>, OpaqueFnEntry<'tcx>>,
1213+
fn_traits: &mut FxIndexMap<
1214+
(ty::Binder<'tcx, &'tcx ty::List<Ty<'tcx>>>, bool),
1215+
OpaqueFnEntry<'tcx>,
1216+
>,
12411217
) {
1242-
let trait_def_id = trait_ref.def_id();
1243-
1244-
// If our trait_ref is FnOnce or any of its children, project it onto the parent FnOnce
1245-
// super-trait ref and record it there.
1246-
// We skip negative Fn* bounds since they can't use parenthetical notation anyway.
1247-
if polarity == ty::PredicatePolarity::Positive
1248-
&& let Some(fn_once_trait) = self.tcx().lang_items().fn_once_trait()
1249-
{
1250-
// If we have a FnOnce, then insert it into
1251-
if trait_def_id == fn_once_trait {
1252-
let entry = fn_traits.entry(trait_ref).or_default();
1253-
// Optionally insert the return_ty as well.
1254-
if let Some((_, ty)) = proj_ty {
1255-
entry.return_ty = Some(ty);
1256-
}
1257-
entry.has_fn_once = true;
1258-
return;
1259-
} else if self.tcx().is_lang_item(trait_def_id, LangItem::FnMut) {
1260-
let super_trait_ref = elaborate::supertraits(self.tcx(), trait_ref)
1261-
.find(|super_trait_ref| super_trait_ref.def_id() == fn_once_trait)
1262-
.unwrap();
1218+
let tcx = self.tcx();
1219+
let trait_def_id = trait_pred.def_id();
12631220

1264-
fn_traits.entry(super_trait_ref).or_default().fn_mut_trait_ref = Some(trait_ref);
1265-
return;
1266-
} else if self.tcx().is_lang_item(trait_def_id, LangItem::Fn) {
1267-
let super_trait_ref = elaborate::supertraits(self.tcx(), trait_ref)
1268-
.find(|super_trait_ref| super_trait_ref.def_id() == fn_once_trait)
1269-
.unwrap();
1221+
let fn_trait_and_async = if let Some(kind) = tcx.fn_trait_kind_from_def_id(trait_def_id) {
1222+
Some((kind, false))
1223+
} else if let Some(kind) = tcx.async_fn_trait_kind_from_def_id(trait_def_id) {
1224+
Some((kind, true))
1225+
} else {
1226+
None
1227+
};
12701228

1271-
fn_traits.entry(super_trait_ref).or_default().fn_trait_ref = Some(trait_ref);
1272-
return;
1229+
if trait_pred.polarity() == ty::PredicatePolarity::Positive
1230+
&& let Some((kind, is_async)) = fn_trait_and_async
1231+
&& let ty::Tuple(types) = *trait_pred.skip_binder().trait_ref.args.type_at(1).kind()
1232+
{
1233+
let entry = fn_traits
1234+
.entry((trait_pred.rebind(types), is_async))
1235+
.or_insert_with(|| OpaqueFnEntry { kind, return_ty: None });
1236+
if kind.extends(entry.kind) {
1237+
entry.kind = kind;
1238+
}
1239+
if let Some((proj_def_id, proj_ty)) = proj_ty
1240+
&& tcx.item_name(proj_def_id) == sym::Output
1241+
{
1242+
entry.return_ty = Some(proj_ty);
12731243
}
1244+
return;
12741245
}
12751246

12761247
// Otherwise, just group our traits and projection types.
1277-
traits.entry((trait_ref, polarity)).or_default().extend(proj_ty);
1248+
traits.entry(trait_pred).or_default().extend(proj_ty);
12781249
}
12791250

12801251
fn pretty_print_inherent_projection(
@@ -3189,10 +3160,10 @@ define_print_and_forward_display! {
31893160

31903161
TraitRefPrintSugared<'tcx> {
31913162
if !with_reduced_queries()
3192-
&& let Some(kind) = cx.tcx().fn_trait_kind_from_def_id(self.0.def_id)
3163+
&& cx.tcx().trait_def(self.0.def_id).paren_sugar
31933164
&& let ty::Tuple(args) = self.0.args.type_at(1).kind()
31943165
{
3195-
p!(write("{}", kind.as_str()), "(");
3166+
p!(write("{}", cx.tcx().item_name(self.0.def_id)), "(");
31963167
for (i, arg) in args.iter().enumerate() {
31973168
if i > 0 {
31983169
p!(", ");
@@ -3415,11 +3386,7 @@ pub fn provide(providers: &mut Providers) {
34153386
*providers = Providers { trimmed_def_paths, ..*providers };
34163387
}
34173388

3418-
#[derive(Default)]
34193389
pub struct OpaqueFnEntry<'tcx> {
3420-
// The trait ref is already stored as a key, so just track if we have it as a real predicate
3421-
has_fn_once: bool,
3422-
fn_mut_trait_ref: Option<ty::PolyTraitRef<'tcx>>,
3423-
fn_trait_ref: Option<ty::PolyTraitRef<'tcx>>,
3390+
kind: ty::ClosureKind,
34243391
return_ty: Option<ty::Binder<'tcx, Term<'tcx>>>,
34253392
}

compiler/rustc_type_ir/src/predicate.rs

-13
Original file line numberDiff line numberDiff line change
@@ -684,19 +684,6 @@ impl<I: Interner> ty::Binder<I, ProjectionPredicate<I>> {
684684
self.skip_binder().projection_term.trait_def_id(cx)
685685
}
686686

687-
/// Get the trait ref required for this projection to be well formed.
688-
/// Note that for generic associated types the predicates of the associated
689-
/// type also need to be checked.
690-
#[inline]
691-
pub fn required_poly_trait_ref(&self, cx: I) -> ty::Binder<I, TraitRef<I>> {
692-
// Note: unlike with `TraitRef::to_poly_trait_ref()`,
693-
// `self.0.trait_ref` is permitted to have escaping regions.
694-
// This is because here `self` has a `Binder` and so does our
695-
// return value, so we are preserving the number of binding
696-
// levels.
697-
self.map_bound(|predicate| predicate.projection_term.trait_ref(cx))
698-
}
699-
700687
pub fn term(&self) -> ty::Binder<I, I::Term> {
701688
self.map_bound(|predicate| predicate.term)
702689
}

tests/ui/async-await/async-closures/fn-exception-target-features.stderr

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
error[E0277]: the trait bound `fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {target_feature}: AsyncFn<()>` is not satisfied
1+
error[E0277]: the trait bound `fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {target_feature}: AsyncFn()` is not satisfied
22
--> $DIR/fn-exception-target-features.rs:16:10
33
|
44
LL | test(target_feature);
5-
| ---- ^^^^^^^^^^^^^^ the trait `AsyncFn<()>` is not implemented for fn item `fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {target_feature}`
5+
| ---- ^^^^^^^^^^^^^^ the trait `AsyncFn()` is not implemented for fn item `fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {target_feature}`
66
| |
77
| required by a bound introduced by this call
88
|

tests/ui/async-await/async-closures/fn-exception.stderr

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
error[E0277]: the trait bound `unsafe fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {unsafety}: AsyncFn<()>` is not satisfied
1+
error[E0277]: the trait bound `unsafe fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {unsafety}: AsyncFn()` is not satisfied
22
--> $DIR/fn-exception.rs:19:10
33
|
44
LL | test(unsafety);
5-
| ---- ^^^^^^^^ the trait `AsyncFn<()>` is not implemented for fn item `unsafe fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {unsafety}`
5+
| ---- ^^^^^^^^ the trait `AsyncFn()` is not implemented for fn item `unsafe fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {unsafety}`
66
| |
77
| required by a bound introduced by this call
88
|
@@ -12,11 +12,11 @@ note: required by a bound in `test`
1212
LL | fn test(f: impl async Fn()) {}
1313
| ^^^^^^^^^^ required by this bound in `test`
1414

15-
error[E0277]: the trait bound `extern "C" fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {abi}: AsyncFn<()>` is not satisfied
15+
error[E0277]: the trait bound `extern "C" fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {abi}: AsyncFn()` is not satisfied
1616
--> $DIR/fn-exception.rs:20:10
1717
|
1818
LL | test(abi);
19-
| ---- ^^^ the trait `AsyncFn<()>` is not implemented for fn item `extern "C" fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {abi}`
19+
| ---- ^^^ the trait `AsyncFn()` is not implemented for fn item `extern "C" fn() -> Pin<Box<(dyn Future<Output = ()> + 'static)>> {abi}`
2020
| |
2121
| required by a bound introduced by this call
2222
|
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
//@ edition: 2021
2+
3+
#![feature(async_closure)]
4+
5+
use std::ops::AsyncFnMut;
6+
7+
fn produce() -> impl AsyncFnMut() -> &'static str {
8+
async || ""
9+
}
10+
11+
fn main() {
12+
let x: i32 = produce();
13+
//~^ ERROR mismatched types
14+
}

0 commit comments

Comments
 (0)