Skip to content

Commit ce63aea

Browse files
committed
add a normalizes-to fast path
1 parent 75cb5c5 commit ce63aea

File tree

5 files changed

+176
-66
lines changed

5 files changed

+176
-66
lines changed

compiler/rustc_middle/src/ty/fast_reject.rs

+70-63
Original file line numberDiff line numberDiff line change
@@ -149,9 +149,9 @@ impl SimplifiedType {
149149
}
150150
}
151151

152-
/// Given generic arguments from an obligation and an impl,
153-
/// could these two be unified after replacing parameters in the
154-
/// the impl with inference variables.
152+
/// Given generic arguments from an obligation and a candidate,
153+
/// could these two be unified after replacing parameters and bound
154+
/// variables in the candidate with inference variables.
155155
///
156156
/// For obligations, parameters won't be replaced by inference
157157
/// variables and only unify with themselves. We treat them
@@ -170,28 +170,30 @@ impl DeepRejectCtxt {
170170
pub fn args_may_unify<'tcx>(
171171
self,
172172
obligation_args: GenericArgsRef<'tcx>,
173-
impl_args: GenericArgsRef<'tcx>,
173+
candidate_args: GenericArgsRef<'tcx>,
174174
) -> bool {
175-
iter::zip(obligation_args, impl_args).all(|(obl, imp)| {
175+
iter::zip(obligation_args, candidate_args).all(|(obl, imp)| {
176176
match (obl.unpack(), imp.unpack()) {
177177
// We don't fast reject based on regions.
178178
(GenericArgKind::Lifetime(_), GenericArgKind::Lifetime(_)) => true,
179-
(GenericArgKind::Type(obl), GenericArgKind::Type(imp)) => {
180-
self.types_may_unify(obl, imp)
179+
(GenericArgKind::Type(obl), GenericArgKind::Type(candidate)) => {
180+
self.types_may_unify(obl, candidate)
181181
}
182-
(GenericArgKind::Const(obl), GenericArgKind::Const(imp)) => {
183-
self.consts_may_unify(obl, imp)
182+
(GenericArgKind::Const(obl), GenericArgKind::Const(candidate)) => {
183+
self.consts_may_unify(obl, candidate)
184184
}
185185
_ => bug!("kind mismatch: {obl} {imp}"),
186186
}
187187
})
188188
}
189189

190-
pub fn types_may_unify<'tcx>(self, obligation_ty: Ty<'tcx>, impl_ty: Ty<'tcx>) -> bool {
191-
match impl_ty.kind() {
192-
// Start by checking whether the type in the impl may unify with
190+
pub fn types_may_unify<'tcx>(self, obligation_ty: Ty<'tcx>, candidate_ty: Ty<'tcx>) -> bool {
191+
match candidate_ty.kind() {
192+
// Start by checking whether the type in the candidate may unify with
193193
// pretty much everything. Just return `true` in that case.
194-
ty::Param(_) | ty::Error(_) | ty::Alias(..) => return true,
194+
ty::Param(_) | ty::Error(_) | ty::Alias(..) | ty::Infer(_) | ty::Placeholder(..) => {
195+
return true;
196+
}
195197
// These types only unify with inference variables or their own
196198
// variant.
197199
ty::Bool
@@ -210,18 +212,15 @@ impl DeepRejectCtxt {
210212
| ty::Never
211213
| ty::Tuple(..)
212214
| ty::FnPtr(..)
213-
| ty::Foreign(..) => debug_assert!(impl_ty.is_known_rigid()),
215+
| ty::Foreign(..) => debug_assert!(candidate_ty.is_known_rigid()),
214216
ty::FnDef(..)
215217
| ty::Closure(..)
216218
| ty::CoroutineClosure(..)
217219
| ty::Coroutine(..)
218220
| ty::CoroutineWitness(..)
219-
| ty::Placeholder(..)
220-
| ty::Bound(..)
221-
| ty::Infer(_) => bug!("unexpected impl_ty: {impl_ty}"),
221+
| ty::Bound(..) => bug!("unexpected candidate_ty: {candidate_ty}"),
222222
}
223223

224-
let k = impl_ty.kind();
225224
match *obligation_ty.kind() {
226225
// Purely rigid types, use structural equivalence.
227226
ty::Bool
@@ -231,65 +230,68 @@ impl DeepRejectCtxt {
231230
| ty::Float(_)
232231
| ty::Str
233232
| ty::Never
234-
| ty::Foreign(_) => obligation_ty == impl_ty,
235-
ty::Ref(_, obl_ty, obl_mutbl) => match k {
236-
&ty::Ref(_, impl_ty, impl_mutbl) => {
237-
obl_mutbl == impl_mutbl && self.types_may_unify(obl_ty, impl_ty)
233+
| ty::Foreign(_) => obligation_ty == candidate_ty,
234+
ty::Ref(_, obl_ty, obl_mutbl) => match candidate_ty.kind() {
235+
&ty::Ref(_, cand_ty, cand_mutbl) => {
236+
obl_mutbl == cand_mutbl && self.types_may_unify(obl_ty, cand_ty)
238237
}
239238
_ => false,
240239
},
241-
ty::Adt(obl_def, obl_args) => match k {
242-
&ty::Adt(impl_def, impl_args) => {
243-
obl_def == impl_def && self.args_may_unify(obl_args, impl_args)
240+
ty::Adt(obl_def, obl_args) => match candidate_ty.kind() {
241+
&ty::Adt(cand_def, cand_args) => {
242+
obl_def == cand_def && self.args_may_unify(obl_args, cand_args)
244243
}
245244
_ => false,
246245
},
247-
ty::Pat(obl_ty, _) => {
246+
ty::Pat(obl_ty, _) => match candidate_ty.kind() {
248247
// FIXME(pattern_types): take pattern into account
249-
matches!(k, &ty::Pat(impl_ty, _) if self.types_may_unify(obl_ty, impl_ty))
250-
}
251-
ty::Slice(obl_ty) => {
252-
matches!(k, &ty::Slice(impl_ty) if self.types_may_unify(obl_ty, impl_ty))
253-
}
254-
ty::Array(obl_ty, obl_len) => match k {
255-
&ty::Array(impl_ty, impl_len) => {
256-
self.types_may_unify(obl_ty, impl_ty)
257-
&& self.consts_may_unify(obl_len, impl_len)
248+
&ty::Pat(cand_ty, _) => self.types_may_unify(obl_ty, cand_ty),
249+
_ => false,
250+
},
251+
ty::Slice(obl_ty) => match candidate_ty.kind() {
252+
&ty::Slice(cand_ty) => self.types_may_unify(obl_ty, cand_ty),
253+
_ => false,
254+
},
255+
ty::Array(obl_ty, obl_len) => match candidate_ty.kind() {
256+
&ty::Array(cand_ty, cand_len) => {
257+
self.types_may_unify(obl_ty, cand_ty)
258+
&& self.consts_may_unify(obl_len, cand_len)
258259
}
259260
_ => false,
260261
},
261-
ty::Tuple(obl) => match k {
262-
&ty::Tuple(imp) => {
263-
obl.len() == imp.len()
264-
&& iter::zip(obl, imp).all(|(obl, imp)| self.types_may_unify(obl, imp))
262+
ty::Tuple(obl) => match candidate_ty.kind() {
263+
&ty::Tuple(cand) => {
264+
obl.len() == cand.len()
265+
&& iter::zip(obl, cand).all(|(obl, cand)| self.types_may_unify(obl, cand))
265266
}
266267
_ => false,
267268
},
268-
ty::RawPtr(obl_ty, obl_mutbl) => match *k {
269-
ty::RawPtr(imp_ty, imp_mutbl) => {
270-
obl_mutbl == imp_mutbl && self.types_may_unify(obl_ty, imp_ty)
269+
ty::RawPtr(obl_ty, obl_mutbl) => match *candidate_ty.kind() {
270+
ty::RawPtr(cand_ty, cand_mutbl) => {
271+
obl_mutbl == cand_mutbl && self.types_may_unify(obl_ty, cand_ty)
271272
}
272273
_ => false,
273274
},
274-
ty::Dynamic(obl_preds, ..) => {
275+
ty::Dynamic(obl_preds, ..) => match candidate_ty.kind() {
275276
// Ideally we would walk the existential predicates here or at least
276277
// compare their length. But considering that the relevant `Relate` impl
277278
// actually sorts and deduplicates these, that doesn't work.
278-
matches!(k, ty::Dynamic(impl_preds, ..) if
279-
obl_preds.principal_def_id() == impl_preds.principal_def_id()
280-
)
281-
}
282-
ty::FnPtr(obl_sig) => match k {
283-
ty::FnPtr(impl_sig) => {
279+
ty::Dynamic(cand_preds, ..) => {
280+
obl_preds.principal_def_id() == cand_preds.principal_def_id()
281+
}
282+
_ => false,
283+
},
284+
ty::FnPtr(obl_sig) => match candidate_ty.kind() {
285+
ty::FnPtr(cand_sig) => {
284286
let ty::FnSig { inputs_and_output, c_variadic, safety, abi } =
285287
obl_sig.skip_binder();
286-
let impl_sig = impl_sig.skip_binder();
288+
let cand_sig = cand_sig.skip_binder();
287289

288-
abi == impl_sig.abi
289-
&& c_variadic == impl_sig.c_variadic
290-
&& safety == impl_sig.safety
291-
&& inputs_and_output.len() == impl_sig.inputs_and_output.len()
292-
&& iter::zip(inputs_and_output, impl_sig.inputs_and_output)
290+
abi == cand_sig.abi
291+
&& c_variadic == cand_sig.c_variadic
292+
&& safety == cand_sig.safety
293+
&& inputs_and_output.len() == cand_sig.inputs_and_output.len()
294+
&& iter::zip(inputs_and_output, cand_sig.inputs_and_output)
293295
.all(|(obl, imp)| self.types_may_unify(obl, imp))
294296
}
295297
_ => false,
@@ -308,9 +310,9 @@ impl DeepRejectCtxt {
308310
TreatParams::AsCandidateKey => true,
309311
},
310312

311-
ty::Infer(ty::IntVar(_)) => impl_ty.is_integral(),
313+
ty::Infer(ty::IntVar(_)) => candidate_ty.is_integral(),
312314

313-
ty::Infer(ty::FloatVar(_)) => impl_ty.is_floating_point(),
315+
ty::Infer(ty::FloatVar(_)) => candidate_ty.is_floating_point(),
314316

315317
ty::Infer(_) => true,
316318

@@ -329,17 +331,22 @@ impl DeepRejectCtxt {
329331
}
330332
}
331333

332-
pub fn consts_may_unify(self, obligation_ct: ty::Const<'_>, impl_ct: ty::Const<'_>) -> bool {
333-
let impl_val = match impl_ct.kind() {
334+
pub fn consts_may_unify(
335+
self,
336+
obligation_ct: ty::Const<'_>,
337+
candidate_ct: ty::Const<'_>,
338+
) -> bool {
339+
let candidate_val = match candidate_ct.kind() {
334340
ty::ConstKind::Expr(_)
335341
| ty::ConstKind::Param(_)
336342
| ty::ConstKind::Unevaluated(_)
343+
| ty::ConstKind::Placeholder(_)
337344
| ty::ConstKind::Error(_) => {
338345
return true;
339346
}
340-
ty::ConstKind::Value(impl_val) => impl_val,
341-
ty::ConstKind::Infer(_) | ty::ConstKind::Bound(..) | ty::ConstKind::Placeholder(_) => {
342-
bug!("unexpected impl arg: {:?}", impl_ct)
347+
ty::ConstKind::Value(candidate_val) => candidate_val,
348+
ty::ConstKind::Infer(_) | ty::ConstKind::Bound(..) => {
349+
bug!("unexpected candidate arg: {:?}", candidate_ct)
343350
}
344351
};
345352

@@ -357,7 +364,7 @@ impl DeepRejectCtxt {
357364
ty::ConstKind::Expr(_) | ty::ConstKind::Unevaluated(_) | ty::ConstKind::Error(_) => {
358365
true
359366
}
360-
ty::ConstKind::Value(obl_val) => obl_val == impl_val,
367+
ty::ConstKind::Value(obl_val) => obl_val == candidate_val,
361368

362369
ty::ConstKind::Infer(_) => true,
363370

compiler/rustc_middle/src/ty/sty.rs

+5
Original file line numberDiff line numberDiff line change
@@ -1901,6 +1901,11 @@ impl<'tcx> Ty<'tcx> {
19011901
matches!(self.kind(), Infer(FreshTy(_) | FreshIntTy(_) | FreshFloatTy(_)))
19021902
}
19031903

1904+
#[inline]
1905+
pub fn is_placeholder(self) -> bool {
1906+
matches!(self.kind(), Placeholder(_))
1907+
}
1908+
19041909
#[inline]
19051910
pub fn is_char(self) -> bool {
19061911
matches!(self.kind(), Char)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
use rustc_middle::ty::{self, TyCtxt};
2+
use rustc_span::def_id::DefId;
3+
/// We early return `NoSolution` when trying to normalize associated types if
4+
/// we know them to be rigid. This is necessary if there are a huge amount of
5+
/// rigid associated types in the `ParamEnv` as we would otherwise get hangs
6+
/// when trying to normalize each associated type with all other associated types.
7+
///
8+
/// See trait-system-refactor-initiative#109 for an example.
9+
///
10+
/// ```
11+
/// is_rigid_alias(alias) :-
12+
/// is_placeholder(alias.self_ty),
13+
/// no_blanket_impls(alias.trait_def_id),
14+
/// not(may_normalize_via_env(alias)),
15+
///
16+
/// may_normalize_via_env(alias) :- exists<projection_clause> {
17+
/// projection_clause.def_id == alias.def_id,
18+
/// projection_clause.args may_unify alias.args,
19+
/// }
20+
/// ```
21+
#[instrument(level = "debug", skip(tcx), ret)]
22+
pub(super) fn is_rigid_alias<'tcx>(
23+
tcx: TyCtxt<'tcx>,
24+
param_env: ty::ParamEnv<'tcx>,
25+
alias: ty::AliasTerm<'tcx>,
26+
) -> bool {
27+
// FIXME: This could consider associated types as rigid as long
28+
// as it considers the *recursive* item bounds of the alias,
29+
// which is non-trivial. We may be forced to handle this case
30+
// in the future.
31+
alias.self_ty().is_placeholder()
32+
&& no_blanket_impls(tcx, alias.trait_def_id(tcx))
33+
&& !may_normalize_via_env(param_env, alias)
34+
}
35+
36+
#[instrument(level = "trace", skip(tcx), ret)]
37+
fn no_blanket_impls<'tcx>(tcx: TyCtxt<'tcx>, trait_def_id: DefId) -> bool {
38+
// FIXME(ptr_metadata): There's currently a builtin impl for `Pointee` which
39+
// applies for all `T` as long as `T: Sized` holds. THis impl should
40+
// get removed in favor of `Pointee` being a super trait of `Sized`.
41+
tcx.trait_impls_of(trait_def_id).blanket_impls().is_empty()
42+
&& !tcx.lang_items().pointee_trait().is_some_and(|def_id| trait_def_id == def_id)
43+
}
44+
45+
#[instrument(level = "trace", ret)]
46+
fn may_normalize_via_env<'tcx>(param_env: ty::ParamEnv<'tcx>, alias: ty::AliasTerm<'tcx>) -> bool {
47+
for clause in param_env.caller_bounds() {
48+
let Some(projection_pred) = clause.as_projection_clause() else {
49+
continue;
50+
};
51+
52+
if projection_pred.projection_def_id() != alias.def_id {
53+
continue;
54+
};
55+
56+
let drcx = ty::fast_reject::DeepRejectCtxt {
57+
treat_obligation_params: ty::fast_reject::TreatParams::ForLookup,
58+
};
59+
if drcx.args_may_unify(alias.args, projection_pred.skip_binder().projection_term.args) {
60+
return true;
61+
}
62+
}
63+
64+
false
65+
}

compiler/rustc_trait_selection/src/solve/normalizes_to/mod.rs

+9-3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use rustc_middle::{bug, span_bug};
2121
use rustc_span::{sym, ErrorGuaranteed, DUMMY_SP};
2222

2323
mod anon_const;
24+
mod fast_reject;
2425
mod inherent;
2526
mod opaque_types;
2627
mod weak_types;
@@ -54,10 +55,15 @@ impl<'tcx> EvalCtxt<'_, InferCtxt<'tcx>> {
5455
&mut self,
5556
goal: Goal<'tcx, NormalizesTo<'tcx>>,
5657
) -> QueryResult<'tcx> {
57-
match goal.predicate.alias.kind(self.tcx()) {
58+
let tcx = self.tcx();
59+
match goal.predicate.alias.kind(tcx) {
5860
ty::AliasTermKind::ProjectionTy | ty::AliasTermKind::ProjectionConst => {
59-
let candidates = self.assemble_and_evaluate_candidates(goal);
60-
self.merge_candidates(candidates)
61+
if fast_reject::is_rigid_alias(tcx, goal.param_env, goal.predicate.alias) {
62+
return Err(NoSolution);
63+
} else {
64+
let candidates = self.assemble_and_evaluate_candidates(goal);
65+
self.merge_candidates(candidates)
66+
}
6167
}
6268
ty::AliasTermKind::InherentTy => self.normalize_inherent_associated_type(goal),
6369
ty::AliasTermKind::OpaqueTy => self.normalize_opaque_type(goal),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
//@ check-pass
2+
//@ compile-flags: -Znext-solver
3+
//@ ignore-compare-mode-next-solver (explicit revisions)
4+
5+
// Minimization of a hang in rayon, cc trait-solver-refactor-initiative#109
6+
7+
pub trait ParallelIterator {
8+
type Item;
9+
}
10+
11+
macro_rules! multizip_impl {
12+
($($T:ident),+) => {
13+
impl<$( $T, )+> ParallelIterator for ($( $T, )+)
14+
where
15+
$(
16+
$T: ParallelIterator,
17+
$T::Item: ParallelIterator,
18+
)+
19+
{
20+
type Item = ();
21+
}
22+
}
23+
}
24+
25+
multizip_impl! { A, B, C, D, E, F, G, H, I, J, K, L }
26+
27+
fn main() {}

0 commit comments

Comments
 (0)