Skip to content

Uplift fast rejection to new solver #127146

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 2, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/rustc_hir_typeck/src/method/suggest.rs
Original file line number Diff line number Diff line change
@@ -2129,7 +2129,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
let target_ty = self
.autoderef(sugg_span, rcvr_ty)
.find(|(rcvr_ty, _)| {
DeepRejectCtxt { treat_obligation_params: TreatParams::AsCandidateKey }
DeepRejectCtxt::new(self.tcx, TreatParams::ForLookup)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm pretty sure this being AsCandidateKey was wrong, since we're trying to match an impl.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AsCandidateKey is strictly weaker than ForLookup as it's used when we'd later replace all params with infer vars.

Using ForLookup is correct here 👍

.types_may_unify(*rcvr_ty, impl_ty)
})
.map_or(impl_ty, |(ty, _)| ty)
11 changes: 0 additions & 11 deletions compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
@@ -373,17 +373,6 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
.map(|assoc_item| assoc_item.def_id)
}

fn args_may_unify_deep(
self,
obligation_args: ty::GenericArgsRef<'tcx>,
impl_args: ty::GenericArgsRef<'tcx>,
) -> bool {
ty::fast_reject::DeepRejectCtxt {
treat_obligation_params: ty::fast_reject::TreatParams::ForLookup,
}
.args_may_unify(obligation_args, impl_args)
}

// This implementation is a bit different from `TyCtxt::for_each_relevant_impl`,
// since we want to skip over blanket impls for non-rigid aliases, and also we
// only want to consider types that *actually* unify with float/int vars.
368 changes: 4 additions & 364 deletions compiler/rustc_middle/src/ty/fast_reject.rs
Original file line number Diff line number Diff line change
@@ -1,369 +1,9 @@
use crate::mir::Mutability;
use crate::ty::GenericArgKind;
use crate::ty::{self, GenericArgsRef, Ty, TyCtxt, TypeVisitableExt};
use rustc_hir::def_id::DefId;
use rustc_macros::{HashStable, TyDecodable, TyEncodable};
use std::fmt::Debug;
use std::hash::Hash;
use std::iter;

/// See `simplify_type`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, TyEncodable, TyDecodable, HashStable)]
pub enum SimplifiedType {
Bool,
Char,
Int(ty::IntTy),
Uint(ty::UintTy),
Float(ty::FloatTy),
Adt(DefId),
Foreign(DefId),
Str,
Array,
Slice,
Ref(Mutability),
Ptr(Mutability),
Never,
Tuple(usize),
/// A trait object, all of whose components are markers
/// (e.g., `dyn Send + Sync`).
MarkerTraitObject,
Trait(DefId),
Closure(DefId),
Coroutine(DefId),
CoroutineWitness(DefId),
Function(usize),
Placeholder,
Error,
}
use super::TyCtxt;

/// Generic parameters are pretty much just bound variables, e.g.
/// the type of `fn foo<'a, T>(x: &'a T) -> u32 { ... }` can be thought of as
/// `for<'a, T> fn(&'a T) -> u32`.
///
/// Typecheck of `foo` has to succeed for all possible generic arguments, so
/// during typeck, we have to treat its generic parameters as if they
/// were placeholders.
///
/// But when calling `foo` we only have to provide a specific generic argument.
/// In that case the generic parameters are instantiated with inference variables.
/// As we use `simplify_type` before that instantiation happens, we just treat
/// generic parameters as if they were inference variables in that case.
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub enum TreatParams {
/// Treat parameters as infer vars. This is the correct mode for caching
/// an impl's type for lookup.
AsCandidateKey,
/// Treat parameters as placeholders in the given environment. This is the
/// correct mode for *lookup*, as during candidate selection.
///
/// This also treats projections with inference variables as infer vars
/// since they could be further normalized.
ForLookup,
}
pub use rustc_type_ir::fast_reject::*;

/// Tries to simplify a type by only returning the outermost injective¹ layer, if one exists.
///
/// **This function should only be used if you need to store or retrieve the type from some
/// hashmap. If you want to quickly decide whether two types may unify, use the [DeepRejectCtxt]
/// instead.**
///
/// The idea is to get something simple that we can use to quickly decide if two types could unify,
/// for example during method lookup. If this function returns `Some(x)` it can only unify with
/// types for which this method returns either `Some(x)` as well or `None`.
///
/// A special case here are parameters and projections, which are only injective
/// if they are treated as placeholders.
///
/// For example when storing impls based on their simplified self type, we treat
/// generic parameters as if they were inference variables. We must not simplify them here,
/// as they can unify with any other type.
///
/// With projections we have to be even more careful, as treating them as placeholders
/// is only correct if they are fully normalized.
///
/// ¹ meaning that if the outermost layers are different, then the whole types are also different.
pub fn simplify_type<'tcx>(
tcx: TyCtxt<'tcx>,
ty: Ty<'tcx>,
treat_params: TreatParams,
) -> Option<SimplifiedType> {
match *ty.kind() {
ty::Bool => Some(SimplifiedType::Bool),
ty::Char => Some(SimplifiedType::Char),
ty::Int(int_type) => Some(SimplifiedType::Int(int_type)),
ty::Uint(uint_type) => Some(SimplifiedType::Uint(uint_type)),
ty::Float(float_type) => Some(SimplifiedType::Float(float_type)),
ty::Adt(def, _) => Some(SimplifiedType::Adt(def.did())),
ty::Str => Some(SimplifiedType::Str),
ty::Array(..) => Some(SimplifiedType::Array),
ty::Slice(..) => Some(SimplifiedType::Slice),
ty::Pat(ty, ..) => simplify_type(tcx, ty, treat_params),
ty::RawPtr(_, mutbl) => Some(SimplifiedType::Ptr(mutbl)),
ty::Dynamic(trait_info, ..) => match trait_info.principal_def_id() {
Some(principal_def_id) if !tcx.trait_is_auto(principal_def_id) => {
Some(SimplifiedType::Trait(principal_def_id))
}
_ => Some(SimplifiedType::MarkerTraitObject),
},
ty::Ref(_, _, mutbl) => Some(SimplifiedType::Ref(mutbl)),
ty::FnDef(def_id, _) | ty::Closure(def_id, _) | ty::CoroutineClosure(def_id, _) => {
Some(SimplifiedType::Closure(def_id))
}
ty::Coroutine(def_id, _) => Some(SimplifiedType::Coroutine(def_id)),
ty::CoroutineWitness(def_id, _) => Some(SimplifiedType::CoroutineWitness(def_id)),
ty::Never => Some(SimplifiedType::Never),
ty::Tuple(tys) => Some(SimplifiedType::Tuple(tys.len())),
ty::FnPtr(f) => Some(SimplifiedType::Function(f.skip_binder().inputs().len())),
ty::Placeholder(..) => Some(SimplifiedType::Placeholder),
ty::Param(_) => match treat_params {
TreatParams::ForLookup => Some(SimplifiedType::Placeholder),
TreatParams::AsCandidateKey => None,
},
ty::Alias(..) => match treat_params {
// When treating `ty::Param` as a placeholder, projections also
// don't unify with anything else as long as they are fully normalized.
// FIXME(-Znext-solver): Can remove this `if` and always simplify to `Placeholder`
// when the new solver is enabled by default.
TreatParams::ForLookup if !ty.has_non_region_infer() => {
Some(SimplifiedType::Placeholder)
}
TreatParams::ForLookup | TreatParams::AsCandidateKey => None,
},
ty::Foreign(def_id) => Some(SimplifiedType::Foreign(def_id)),
ty::Error(_) => Some(SimplifiedType::Error),
ty::Bound(..) | ty::Infer(_) => None,
}
}
pub type DeepRejectCtxt<'tcx> = rustc_type_ir::fast_reject::DeepRejectCtxt<TyCtxt<'tcx>>;

impl SimplifiedType {
pub fn def(self) -> Option<DefId> {
match self {
SimplifiedType::Adt(d)
| SimplifiedType::Foreign(d)
| SimplifiedType::Trait(d)
| SimplifiedType::Closure(d)
| SimplifiedType::Coroutine(d)
| SimplifiedType::CoroutineWitness(d) => Some(d),
_ => None,
}
}
}

/// Given generic arguments from an obligation and an impl,
/// could these two be unified after replacing parameters in the
/// the impl with inference variables.
///
/// For obligations, parameters won't be replaced by inference
/// variables and only unify with themselves. We treat them
/// the same way we treat placeholders.
///
/// We also use this function during coherence. For coherence the
/// impls only have to overlap for some value, so we treat parameters
/// on both sides like inference variables. This behavior is toggled
/// using the `treat_obligation_params` field.
#[derive(Debug, Clone, Copy)]
pub struct DeepRejectCtxt {
pub treat_obligation_params: TreatParams,
}

impl DeepRejectCtxt {
pub fn args_may_unify<'tcx>(
self,
obligation_args: GenericArgsRef<'tcx>,
impl_args: GenericArgsRef<'tcx>,
) -> bool {
iter::zip(obligation_args, impl_args).all(|(obl, imp)| {
match (obl.unpack(), imp.unpack()) {
// We don't fast reject based on regions.
(GenericArgKind::Lifetime(_), GenericArgKind::Lifetime(_)) => true,
(GenericArgKind::Type(obl), GenericArgKind::Type(imp)) => {
self.types_may_unify(obl, imp)
}
(GenericArgKind::Const(obl), GenericArgKind::Const(imp)) => {
self.consts_may_unify(obl, imp)
}
_ => bug!("kind mismatch: {obl} {imp}"),
}
})
}

pub fn types_may_unify<'tcx>(self, obligation_ty: Ty<'tcx>, impl_ty: Ty<'tcx>) -> bool {
match impl_ty.kind() {
// Start by checking whether the type in the impl may unify with
// pretty much everything. Just return `true` in that case.
ty::Param(_) | ty::Error(_) | ty::Alias(..) => return true,
// These types only unify with inference variables or their own
// variant.
ty::Bool
| ty::Char
| ty::Int(_)
| ty::Uint(_)
| ty::Float(_)
| ty::Adt(..)
| ty::Str
| ty::Array(..)
| ty::Slice(..)
| ty::RawPtr(..)
| ty::Dynamic(..)
| ty::Pat(..)
| ty::Ref(..)
| ty::Never
| ty::Tuple(..)
| ty::FnPtr(..)
| ty::Foreign(..) => debug_assert!(impl_ty.is_known_rigid()),
ty::FnDef(..)
| ty::Closure(..)
| ty::CoroutineClosure(..)
| ty::Coroutine(..)
| ty::CoroutineWitness(..)
| ty::Placeholder(..)
| ty::Bound(..)
| ty::Infer(_) => bug!("unexpected impl_ty: {impl_ty}"),
}

let k = impl_ty.kind();
match *obligation_ty.kind() {
// Purely rigid types, use structural equivalence.
ty::Bool
| ty::Char
| ty::Int(_)
| ty::Uint(_)
| ty::Float(_)
| ty::Str
| ty::Never
| ty::Foreign(_) => obligation_ty == impl_ty,
ty::Ref(_, obl_ty, obl_mutbl) => match k {
&ty::Ref(_, impl_ty, impl_mutbl) => {
obl_mutbl == impl_mutbl && self.types_may_unify(obl_ty, impl_ty)
}
_ => false,
},
ty::Adt(obl_def, obl_args) => match k {
&ty::Adt(impl_def, impl_args) => {
obl_def == impl_def && self.args_may_unify(obl_args, impl_args)
}
_ => false,
},
ty::Pat(obl_ty, _) => {
// FIXME(pattern_types): take pattern into account
matches!(k, &ty::Pat(impl_ty, _) if self.types_may_unify(obl_ty, impl_ty))
}
ty::Slice(obl_ty) => {
matches!(k, &ty::Slice(impl_ty) if self.types_may_unify(obl_ty, impl_ty))
}
ty::Array(obl_ty, obl_len) => match k {
&ty::Array(impl_ty, impl_len) => {
self.types_may_unify(obl_ty, impl_ty)
&& self.consts_may_unify(obl_len, impl_len)
}
_ => false,
},
ty::Tuple(obl) => match k {
&ty::Tuple(imp) => {
obl.len() == imp.len()
&& iter::zip(obl, imp).all(|(obl, imp)| self.types_may_unify(obl, imp))
}
_ => false,
},
ty::RawPtr(obl_ty, obl_mutbl) => match *k {
ty::RawPtr(imp_ty, imp_mutbl) => {
obl_mutbl == imp_mutbl && self.types_may_unify(obl_ty, imp_ty)
}
_ => false,
},
ty::Dynamic(obl_preds, ..) => {
// Ideally we would walk the existential predicates here or at least
// compare their length. But considering that the relevant `Relate` impl
// actually sorts and deduplicates these, that doesn't work.
matches!(k, ty::Dynamic(impl_preds, ..) if
obl_preds.principal_def_id() == impl_preds.principal_def_id()
)
}
ty::FnPtr(obl_sig) => match k {
ty::FnPtr(impl_sig) => {
let ty::FnSig { inputs_and_output, c_variadic, safety, abi } =
obl_sig.skip_binder();
let impl_sig = impl_sig.skip_binder();

abi == impl_sig.abi
&& c_variadic == impl_sig.c_variadic
&& safety == impl_sig.safety
&& inputs_and_output.len() == impl_sig.inputs_and_output.len()
&& iter::zip(inputs_and_output, impl_sig.inputs_and_output)
.all(|(obl, imp)| self.types_may_unify(obl, imp))
}
_ => false,
},

// Impls cannot contain these types as these cannot be named directly.
ty::FnDef(..) | ty::Closure(..) | ty::CoroutineClosure(..) | ty::Coroutine(..) => false,

// Placeholder types don't unify with anything on their own
ty::Placeholder(..) | ty::Bound(..) => false,

// Depending on the value of `treat_obligation_params`, we either
// treat generic parameters like placeholders or like inference variables.
ty::Param(_) => match self.treat_obligation_params {
TreatParams::ForLookup => false,
TreatParams::AsCandidateKey => true,
},

ty::Infer(ty::IntVar(_)) => impl_ty.is_integral(),

ty::Infer(ty::FloatVar(_)) => impl_ty.is_floating_point(),

ty::Infer(_) => true,

// As we're walking the whole type, it may encounter projections
// inside of binders and what not, so we're just going to assume that
// projections can unify with other stuff.
//
// Looking forward to lazy normalization this is the safer strategy anyways.
ty::Alias(..) => true,

ty::Error(_) => true,

ty::CoroutineWitness(..) => {
bug!("unexpected obligation type: {:?}", obligation_ty)
}
}
}

pub fn consts_may_unify(self, obligation_ct: ty::Const<'_>, impl_ct: ty::Const<'_>) -> bool {
let impl_val = match impl_ct.kind() {
ty::ConstKind::Expr(_)
| ty::ConstKind::Param(_)
| ty::ConstKind::Unevaluated(_)
| ty::ConstKind::Error(_) => {
return true;
}
ty::ConstKind::Value(_, impl_val) => impl_val,
ty::ConstKind::Infer(_) | ty::ConstKind::Bound(..) | ty::ConstKind::Placeholder(_) => {
bug!("unexpected impl arg: {:?}", impl_ct)
}
};

match obligation_ct.kind() {
ty::ConstKind::Param(_) => match self.treat_obligation_params {
TreatParams::ForLookup => false,
TreatParams::AsCandidateKey => true,
},

// Placeholder consts don't unify with anything on their own
ty::ConstKind::Placeholder(_) => false,

// As we don't necessarily eagerly evaluate constants,
// they might unify with any value.
ty::ConstKind::Expr(_) | ty::ConstKind::Unevaluated(_) | ty::ConstKind::Error(_) => {
true
}
ty::ConstKind::Value(_, obl_val) => obl_val == impl_val,

ty::ConstKind::Infer(_) => true,

ty::ConstKind::Bound(..) => {
bug!("unexpected obl const: {:?}", obligation_ct)
}
}
}
}
pub type SimplifiedType = rustc_type_ir::fast_reject::SimplifiedType<DefId>;
13 changes: 0 additions & 13 deletions compiler/rustc_middle/src/ty/impls_ty.rs
Original file line number Diff line number Diff line change
@@ -4,7 +4,6 @@
use crate::middle::region;
use crate::mir;
use crate::ty;
use crate::ty::fast_reject::SimplifiedType;
use rustc_data_structures::fingerprint::Fingerprint;
use rustc_data_structures::fx::FxHashMap;
use rustc_data_structures::stable_hasher::HashingControls;
@@ -57,18 +56,6 @@ where
}
}

impl<'a> ToStableHashKey<StableHashingContext<'a>> for SimplifiedType {
type KeyType = Fingerprint;

#[inline]
fn to_stable_hash_key(&self, hcx: &StableHashingContext<'a>) -> Fingerprint {
let mut hasher = StableHasher::new();
let mut hcx: StableHashingContext<'a> = hcx.clone();
self.hash_stable(&mut hcx, &mut hasher);
hasher.finish()
}
}

impl<'a, 'tcx> HashStable<StableHashingContext<'a>> for ty::GenericArg<'tcx> {
fn hash_stable(&self, hcx: &mut StableHashingContext<'a>, hasher: &mut StableHasher) {
self.unpack().hash_stable(hcx, hasher);
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@ mod inherent;
mod opaque_types;
mod weak_types;

use rustc_type_ir::fast_reject::{DeepRejectCtxt, TreatParams};
use rustc_type_ir::inherent::*;
use rustc_type_ir::lang_items::TraitSolverLangItem;
use rustc_type_ir::Upcast as _;
@@ -144,7 +145,7 @@ where

let goal_trait_ref = goal.predicate.alias.trait_ref(cx);
let impl_trait_ref = cx.impl_trait_ref(impl_def_id);
if !ecx.cx().args_may_unify_deep(
if !DeepRejectCtxt::new(ecx.cx(), TreatParams::ForLookup).args_may_unify(
goal.predicate.alias.trait_ref(cx).args,
impl_trait_ref.skip_binder().args,
) {
4 changes: 3 additions & 1 deletion compiler/rustc_next_trait_solver/src/solve/trait_goals.rs
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@
use rustc_ast_ir::Movability;
use rustc_type_ir::data_structures::IndexSet;
use rustc_type_ir::fast_reject::{DeepRejectCtxt, TreatParams};
use rustc_type_ir::inherent::*;
use rustc_type_ir::lang_items::TraitSolverLangItem;
use rustc_type_ir::visit::TypeVisitableExt as _;
@@ -46,7 +47,8 @@ where
let cx = ecx.cx();

let impl_trait_ref = cx.impl_trait_ref(impl_def_id);
if !cx.args_may_unify_deep(goal.predicate.trait_ref.args, impl_trait_ref.skip_binder().args)
if !DeepRejectCtxt::new(ecx.cx(), TreatParams::ForLookup)
.args_may_unify(goal.predicate.trait_ref.args, impl_trait_ref.skip_binder().args)
{
return Err(NoSolution);
}
2 changes: 1 addition & 1 deletion compiler/rustc_trait_selection/src/traits/coherence.rs
Original file line number Diff line number Diff line change
@@ -121,7 +121,7 @@ pub fn overlapping_impls(
// Before doing expensive operations like entering an inference context, do
// a quick check via fast_reject to tell if the impl headers could possibly
// unify.
let drcx = DeepRejectCtxt { treat_obligation_params: TreatParams::AsCandidateKey };
let drcx = DeepRejectCtxt::new(tcx, TreatParams::AsCandidateKey);
let impl1_ref = tcx.impl_trait_ref(impl1_def_id);
let impl2_ref = tcx.impl_trait_ref(impl2_def_id);
let may_overlap = match (impl1_ref, impl2_ref) {
Original file line number Diff line number Diff line change
@@ -571,7 +571,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
return;
}

let drcx = DeepRejectCtxt { treat_obligation_params: TreatParams::ForLookup };
let drcx = DeepRejectCtxt::new(self.tcx(), TreatParams::ForLookup);
let obligation_args = obligation.predicate.skip_binder().trait_ref.args;
self.tcx().for_each_relevant_impl(
obligation.predicate.def_id(),
397 changes: 397 additions & 0 deletions compiler/rustc_type_ir/src/fast_reject.rs

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions compiler/rustc_type_ir/src/inherent.rs
Original file line number Diff line number Diff line change
@@ -120,6 +120,14 @@ pub trait Ty<I: Interner<Ty = Self>>:
matches!(self.kind(), ty::Infer(ty::TyVar(_)))
}

fn is_floating_point(self) -> bool {
matches!(self.kind(), ty::Float(_) | ty::Infer(ty::FloatVar(_)))
}

fn is_integral(self) -> bool {
matches!(self.kind(), ty::Infer(ty::IntVar(_)) | ty::Int(_) | ty::Uint(_))
}

fn is_fn_ptr(self) -> bool {
matches!(self.kind(), ty::FnPtr(_))
}
7 changes: 0 additions & 7 deletions compiler/rustc_type_ir/src/interner.rs
Original file line number Diff line number Diff line change
@@ -222,13 +222,6 @@ pub trait Interner:

fn associated_type_def_ids(self, def_id: Self::DefId) -> impl IntoIterator<Item = Self::DefId>;

// FIXME: move `fast_reject` into `rustc_type_ir`.
fn args_may_unify_deep(
self,
obligation_args: Self::GenericArgs,
impl_args: Self::GenericArgs,
) -> bool;

fn for_each_relevant_impl(
self,
trait_def_id: Self::DefId,
1 change: 1 addition & 0 deletions compiler/rustc_type_ir/src/lib.rs
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@ pub mod visit;
pub mod codec;
pub mod data_structures;
pub mod error;
pub mod fast_reject;
pub mod fold;
pub mod inherent;
pub mod ir_print;
3 changes: 1 addition & 2 deletions src/librustdoc/html/render/write_shared.rs
Original file line number Diff line number Diff line change
@@ -507,8 +507,7 @@ else if (window.initSearch) window.initSearch(searchIndex);
// Be aware of `tests/rustdoc/type-alias/deeply-nested-112515.rs` which might regress.
let Some(impl_did) = impl_item_id.as_def_id() else { continue };
let for_ty = self.cx.tcx().type_of(impl_did).skip_binder();
let reject_cx =
DeepRejectCtxt { treat_obligation_params: TreatParams::AsCandidateKey };
let reject_cx = DeepRejectCtxt::new(self.cx.tcx(), TreatParams::AsCandidateKey);
if !reject_cx.types_may_unify(aliased_ty, for_ty) {
continue;
}