Skip to content

Generate obligations when possible instead of rejecting with ambiguity #139955

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions compiler/rustc_infer/src/infer/at.rs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,34 @@ impl<'a, 'tcx> At<'a, 'tcx> {
}
}

// FIXME(arbitrary_self_types): remove this interface
// when the new solver is stabilised.
/// Almost like `eq_trace` except this type relating procedure will
/// also generate the obligations arising from equating projection
/// candidates.
pub fn eq_with_proj<T>(
self,
define_opaque_types: DefineOpaqueTypes,
expected: T,
actual: T,
) -> InferResult<'tcx, ()>
where
T: ToTrace<'tcx>,
{
assert!(!self.infcx.next_trait_solver);
let trace = ToTrace::to_trace(self.cause, expected, actual);
let mut op = TypeRelating::new(
self.infcx,
trace,
self.param_env,
define_opaque_types,
ty::Invariant,
)
.through_projections(true);
op.relate(expected, actual)?;
Ok(InferOk { value: (), obligations: op.into_obligations() })
}

pub fn relate<T>(
self,
define_opaque_types: DefineOpaqueTypes,
Expand Down Expand Up @@ -369,6 +397,18 @@ impl<'tcx> ToTrace<'tcx> for ty::Term<'tcx> {
}
}

impl<'tcx> ToTrace<'tcx> for ty::Binder<'tcx, ty::Term<'tcx>> {
fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
TypeTrace {
cause: cause.clone(),
values: ValuePairs::Terms(ExpectedFound {
expected: a.skip_binder(),
found: b.skip_binder(),
}),
}
}
}

impl<'tcx> ToTrace<'tcx> for ty::TraitRef<'tcx> {
fn to_trace(cause: &ObligationCause<'tcx>, a: Self, b: Self) -> TypeTrace<'tcx> {
TypeTrace { cause: cause.clone(), values: ValuePairs::TraitRefs(ExpectedFound::new(a, b)) }
Expand Down
42 changes: 41 additions & 1 deletion compiler/rustc_infer/src/infer/relate/type_relating.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ pub(crate) struct TypeRelating<'infcx, 'tcx> {
param_env: ty::ParamEnv<'tcx>,
define_opaque_types: DefineOpaqueTypes,

/// This indicates whether the relation should
/// report obligations arising from equating aliasing terms
/// involving associated types, instead of rejection.
through_projections: bool,

// Mutable fields.
ambient_variance: ty::Variance,
obligations: PredicateObligations<'tcx>,
Expand Down Expand Up @@ -67,9 +72,15 @@ impl<'infcx, 'tcx> TypeRelating<'infcx, 'tcx> {
ambient_variance,
obligations: PredicateObligations::new(),
cache: Default::default(),
through_projections: false,
}
}

pub(crate) fn through_projections(mut self, walk_through: bool) -> Self {
self.through_projections = walk_through;
self
}

pub(crate) fn into_obligations(self) -> PredicateObligations<'tcx> {
self.obligations
}
Expand Down Expand Up @@ -128,6 +139,7 @@ impl<'tcx> TypeRelation<TyCtxt<'tcx>> for TypeRelating<'_, 'tcx> {
if self.cache.contains(&(self.ambient_variance, a, b)) {
return Ok(a);
}
let mut relate_result = a;

match (a.kind(), b.kind()) {
(&ty::Infer(TyVar(a_id)), &ty::Infer(TyVar(b_id))) => {
Expand Down Expand Up @@ -201,14 +213,42 @@ impl<'tcx> TypeRelation<TyCtxt<'tcx>> for TypeRelating<'_, 'tcx> {
)?);
}

(
ty::Alias(ty::Projection | ty::Opaque, _),
ty::Alias(ty::Projection | ty::Opaque, _),
) => {
super_combine_tys(infcx, self, a, b)?;
}

(&ty::Alias(ty::Projection, ty::AliasTy { def_id, args, .. }), _)
if matches!(self.ambient_variance, ty::Variance::Invariant)
&& self.through_projections =>
{
self.register_predicates([ty::ProjectionPredicate {
projection_term: ty::AliasTerm::new(self.cx(), def_id, args),
term: b.into(),
}]);
relate_result = b;
}

(_, &ty::Alias(ty::Projection, ty::AliasTy { def_id, args, .. }))
if matches!(self.ambient_variance, ty::Variance::Invariant)
&& self.through_projections =>
{
self.register_predicates([ty::ProjectionPredicate {
projection_term: ty::AliasTerm::new(self.cx(), def_id, args),
term: a.into(),
}]);
}

_ => {
super_combine_tys(infcx, self, a, b)?;
}
}

assert!(self.cache.insert((self.ambient_variance, a, b)));

Ok(a)
Ok(relate_result)
}

#[instrument(skip(self), level = "trace")]
Expand Down
82 changes: 56 additions & 26 deletions compiler/rustc_trait_selection/src/traits/project.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Code for projecting associated types out of trait references.

use std::iter::Extend;
use std::ops::ControlFlow;

use rustc_data_structures::sso::SsoHashSet;
Expand Down Expand Up @@ -65,6 +66,7 @@ enum ProjectionCandidate<'tcx> {
Select(Selection<'tcx>),
}

#[derive(Debug)]
enum ProjectionCandidateSet<'tcx> {
None,
Single(ProjectionCandidate<'tcx>),
Expand Down Expand Up @@ -648,15 +650,26 @@ fn project<'cx, 'tcx>(
}

let mut candidates = ProjectionCandidateSet::None;
let mut derived_obligations = PredicateObligations::default();

// Make sure that the following procedures are kept in order. ParamEnv
// needs to be first because it has highest priority, and Select checks
// the return value of push_candidate which assumes it's ran at last.
assemble_candidates_from_param_env(selcx, obligation, &mut candidates);
assemble_candidates_from_param_env(
selcx,
obligation,
&mut candidates,
&mut derived_obligations,
);

assemble_candidates_from_trait_def(selcx, obligation, &mut candidates);

assemble_candidates_from_object_ty(selcx, obligation, &mut candidates);
assemble_candidates_from_object_ty(
selcx,
obligation,
&mut candidates,
&mut derived_obligations,
);

if let ProjectionCandidateSet::Single(ProjectionCandidate::Object(_)) = candidates {
// Avoid normalization cycle from selection (see
Expand All @@ -669,7 +682,13 @@ fn project<'cx, 'tcx>(

match candidates {
ProjectionCandidateSet::Single(candidate) => {
confirm_candidate(selcx, obligation, candidate)
confirm_candidate(selcx, obligation, candidate).map(move |proj| {
if let Projected::Progress(progress) = proj {
Projected::Progress(progress.with_addl_obligations(derived_obligations))
} else {
proj
}
})
}
ProjectionCandidateSet::None => {
let tcx = selcx.tcx();
Expand All @@ -691,14 +710,15 @@ fn assemble_candidates_from_param_env<'cx, 'tcx>(
selcx: &mut SelectionContext<'cx, 'tcx>,
obligation: &ProjectionTermObligation<'tcx>,
candidate_set: &mut ProjectionCandidateSet<'tcx>,
derived_obligations: &mut impl Extend<PredicateObligation<'tcx>>,
) {
assemble_candidates_from_predicates(
selcx,
obligation,
candidate_set,
ProjectionCandidate::ParamEnv,
obligation.param_env.caller_bounds().iter(),
false,
derived_obligations,
);
}

Expand All @@ -712,6 +732,7 @@ fn assemble_candidates_from_param_env<'cx, 'tcx>(
/// ```
///
/// Here, for example, we could conclude that the result is `i32`.
#[instrument(level = "debug", skip(selcx))]
fn assemble_candidates_from_trait_def<'cx, 'tcx>(
selcx: &mut SelectionContext<'cx, 'tcx>,
obligation: &ProjectionTermObligation<'tcx>,
Expand Down Expand Up @@ -774,6 +795,7 @@ fn assemble_candidates_from_object_ty<'cx, 'tcx>(
selcx: &mut SelectionContext<'cx, 'tcx>,
obligation: &ProjectionTermObligation<'tcx>,
candidate_set: &mut ProjectionCandidateSet<'tcx>,
derived_obligations: &mut impl Extend<PredicateObligation<'tcx>>,
) {
debug!("assemble_candidates_from_object_ty(..)");

Expand Down Expand Up @@ -806,21 +828,18 @@ fn assemble_candidates_from_object_ty<'cx, 'tcx>(
candidate_set,
ProjectionCandidate::Object,
env_predicates,
false,
derived_obligations,
);
}

#[instrument(
level = "debug",
skip(selcx, candidate_set, ctor, env_predicates, potentially_unnormalized_candidates)
)]
#[instrument(level = "debug", skip(selcx, env_predicates, derived_obligations))]
fn assemble_candidates_from_predicates<'cx, 'tcx>(
selcx: &mut SelectionContext<'cx, 'tcx>,
obligation: &ProjectionTermObligation<'tcx>,
candidate_set: &mut ProjectionCandidateSet<'tcx>,
ctor: fn(ty::PolyProjectionPredicate<'tcx>) -> ProjectionCandidate<'tcx>,
env_predicates: impl Iterator<Item = ty::Clause<'tcx>>,
potentially_unnormalized_candidates: bool,
derived_obligations: &mut impl Extend<PredicateObligation<'tcx>>,
) {
let infcx = selcx.infcx;
let drcx = DeepRejectCtxt::relate_rigid_rigid(selcx.tcx());
Expand All @@ -838,28 +857,39 @@ fn assemble_candidates_from_predicates<'cx, 'tcx>(
continue;
}

let is_match = infcx.probe(|_| {
selcx.match_projection_projections(
obligation,
data,
potentially_unnormalized_candidates,
)
});
let is_match =
infcx.probe(|_| selcx.match_projection_projections(obligation, data, false));

match is_match {
ProjectionMatchesProjection::Yes => {
candidate_set.push_candidate(ctor(data));

if potentially_unnormalized_candidates
&& !obligation.predicate.has_non_region_infer()
debug!(?data, "push");
if let ProjectionCandidateSet::Single(
ProjectionCandidate::ParamEnv(proj)
| ProjectionCandidate::Object(proj)
| ProjectionCandidate::TraitDef(proj),
) = candidate_set
{
// HACK: Pick the first trait def candidate for a fully
// inferred predicate. This is to allow duplicates that
// differ only in normalization.
return;
match infcx.commit_if_ok(|_| {
infcx.at(&obligation.cause, obligation.param_env).eq_with_proj(
DefineOpaqueTypes::No,
data.term(),
proj.term(),
)
}) {
Ok(InferOk { value: (), obligations }) => {
derived_obligations.extend(obligations);
}
Err(e) => {
debug!(?e, "refuse to unify candidates");
candidate_set.push_candidate(ctor(data));
}
}
} else {
candidate_set.push_candidate(ctor(data));
}
}
ProjectionMatchesProjection::Ambiguous => {
debug!("mark ambiguous");
candidate_set.mark_ambiguous();
}
ProjectionMatchesProjection::No => {}
Expand All @@ -868,7 +898,7 @@ fn assemble_candidates_from_predicates<'cx, 'tcx>(
}
}

#[instrument(level = "debug", skip(selcx, obligation, candidate_set))]
#[instrument(level = "debug", skip(selcx))]
fn assemble_candidates_from_impls<'cx, 'tcx>(
selcx: &mut SelectionContext<'cx, 'tcx>,
obligation: &ProjectionTermObligation<'tcx>,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
//@ revisions: traditional next_solver
//@ [next_solver] compile-flags: -Znext-solver
//@ check-pass

use std::marker::PhantomData;

pub trait Receiver {
type Target: ?Sized;
}

pub trait Deref: Receiver<Target = <Self as Deref>::Target> {
type Target: ?Sized;
fn deref(&self) -> &<Self as Deref>::Target;
}

impl<T: Deref> Receiver for T {
type Target = <T as Deref>::Target;
}

// ===
pub struct Type<Id, T>(PhantomData<(Id, T)>);
pub struct AliasRef<Id, T: TypePtr<Id = Id>>(PhantomData<(Id, T)>);

pub trait TypePtr: Deref<Target = Type<<Self as TypePtr>::Id, Self>> + Sized {
// ^ the impl head here provides the first candidate
// <T as Deref>::Target := Type<<T as TypePtr>::Id>
type Id;
}

pub struct Alias<Id, T>(PhantomData<(Id, T)>);

impl<Id, T> Deref for Alias<Id, T>
where
T: TypePtr<Id = Id> + Deref<Target = Type<Id, T>>,
// ^ the impl head here provides the second candidate
// <T as Deref>::Target := Type<Id, T>
// and additionally a normalisation is mandatory due to
// the following supertrait relation trait
// Deref: Receiver<Target = <Self as Deref>::Target>
{
type Target = AliasRef<Id, T>;

fn deref(&self) -> &<Self as Deref>::Target {
todo!()
}
}

fn main() {}
Loading