Skip to content

Structurally resolve projections (but actually) in the new solver #108833

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 60 additions & 25 deletions compiler/rustc_borrowck/src/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ use rustc_middle::ty::{
use rustc_span::def_id::CRATE_DEF_ID;
use rustc_span::{Span, DUMMY_SP};
use rustc_target::abi::VariantIdx;
use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt;
use rustc_trait_selection::traits::query::type_op::custom::scrape_region_constraints;
use rustc_trait_selection::traits::query::type_op::custom::CustomTypeOp;
use rustc_trait_selection::traits::query::type_op::{TypeOp, TypeOpOutput};
use rustc_trait_selection::traits::query::Fallible;
use rustc_trait_selection::traits::PredicateObligation;
use rustc_trait_selection::traits::{fully_solve_obligation, Obligation, ObligationCause};

use rustc_mir_dataflow::impls::MaybeInitializedPlaces;
use rustc_mir_dataflow::move_paths::MoveData;
Expand Down Expand Up @@ -1154,6 +1156,31 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
self.infcx.tcx
}

fn structurally_resolved_ty(&self, ty: Ty<'tcx>) -> Ty<'tcx> {
if self.tcx().trait_solver_next() && let ty::Alias(ty::Projection, projection_ty) = *ty.kind() {
let new_infer_ty = self.infcx.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::NormalizeProjectionType,
span: DUMMY_SP,
});
let obligation = Obligation::new(
self.tcx(),
ObligationCause::dummy(),
self.param_env,
ty::Binder::dummy(ty::ProjectionPredicate {
projection_ty,
term: new_infer_ty.into(),
}),
);

if self.infcx.predicate_may_hold(&obligation) {
fully_solve_obligation(&self.infcx, obligation);
return self.infcx.resolve_vars_if_possible(new_infer_ty);
}
}

ty
}

#[instrument(skip(self, body, location), level = "debug")]
fn check_stmt(&mut self, body: &Body<'tcx>, stmt: &Statement<'tcx>, location: Location) {
let tcx = self.tcx();
Expand Down Expand Up @@ -1871,6 +1898,13 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
Rvalue::Cast(cast_kind, op, ty) => {
self.check_operand(op, location);

let structurally_resolved_cast_tys = || {
(
self.structurally_resolved_ty(op.ty(body, tcx)),
self.structurally_resolved_ty(*ty),
)
};

match cast_kind {
CastKind::Pointer(PointerCast::ReifyFnPointer) => {
let fn_sig = op.ty(body, tcx).fn_sig(tcx);
Expand Down Expand Up @@ -1902,7 +1936,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}

CastKind::Pointer(PointerCast::ClosureFnPointer(unsafety)) => {
let sig = match op.ty(body, tcx).kind() {
let sig = match self.structurally_resolved_ty(op.ty(body, tcx)).kind() {
ty::Closure(_, substs) => substs.as_closure().sig(),
_ => bug!(),
};
Expand Down Expand Up @@ -1971,10 +2005,11 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
// get the constraints from the target type (`dyn* Clone`)
//
// apply them to prove that the source type `Foo` implements `Clone` etc
let (existential_predicates, region) = match ty.kind() {
Dynamic(predicates, region, ty::DynStar) => (predicates, region),
_ => panic!("Invalid dyn* cast_ty"),
};
let (existential_predicates, region) =
match self.structurally_resolved_ty(*ty).kind() {
Dynamic(predicates, region, ty::DynStar) => (predicates, region),
_ => panic!("Invalid dyn* cast_ty"),
};

let self_ty = op.ty(body, tcx);

Expand All @@ -2001,7 +2036,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
let ty::RawPtr(ty::TypeAndMut {
ty: ty_from,
mutbl: hir::Mutability::Mut,
}) = op.ty(body, tcx).kind() else {
}) = self.structurally_resolved_ty(op.ty(body, tcx)).kind() else {
span_mirbug!(
self,
rvalue,
Expand All @@ -2013,7 +2048,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
let ty::RawPtr(ty::TypeAndMut {
ty: ty_to,
mutbl: hir::Mutability::Not,
}) = ty.kind() else {
}) = self.structurally_resolved_ty(*ty).kind() else {
span_mirbug!(
self,
rvalue,
Expand Down Expand Up @@ -2042,7 +2077,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
CastKind::Pointer(PointerCast::ArrayToPointer) => {
let ty_from = op.ty(body, tcx);

let opt_ty_elem_mut = match ty_from.kind() {
let opt_ty_elem_mut = match self.structurally_resolved_ty(ty_from).kind() {
ty::RawPtr(ty::TypeAndMut { mutbl: array_mut, ty: array_ty }) => {
match array_ty.kind() {
ty::Array(ty_elem, _) => Some((ty_elem, *array_mut)),
Expand All @@ -2062,7 +2097,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
return;
};

let (ty_to, ty_to_mut) = match ty.kind() {
let (ty_to, ty_to_mut) = match self.structurally_resolved_ty(*ty).kind() {
ty::RawPtr(ty::TypeAndMut { mutbl: ty_to_mut, ty: ty_to }) => {
(ty_to, *ty_to_mut)
}
Expand Down Expand Up @@ -2106,9 +2141,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}

CastKind::PointerExposeAddress => {
let ty_from = op.ty(body, tcx);
let (ty_from, ty) = structurally_resolved_cast_tys();
let cast_ty_from = CastTy::from_ty(ty_from);
let cast_ty_to = CastTy::from_ty(*ty);
let cast_ty_to = CastTy::from_ty(ty);
match (cast_ty_from, cast_ty_to) {
(Some(CastTy::Ptr(_) | CastTy::FnPtr), Some(CastTy::Int(_))) => (),
_ => {
Expand All @@ -2124,9 +2159,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}

CastKind::PointerFromExposedAddress => {
let ty_from = op.ty(body, tcx);
let (ty_from, ty) = structurally_resolved_cast_tys();
let cast_ty_from = CastTy::from_ty(ty_from);
let cast_ty_to = CastTy::from_ty(*ty);
let cast_ty_to = CastTy::from_ty(ty);
match (cast_ty_from, cast_ty_to) {
(Some(CastTy::Int(_)), Some(CastTy::Ptr(_))) => (),
_ => {
Expand All @@ -2141,9 +2176,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}
}
CastKind::IntToInt => {
let ty_from = op.ty(body, tcx);
let (ty_from, ty) = structurally_resolved_cast_tys();
let cast_ty_from = CastTy::from_ty(ty_from);
let cast_ty_to = CastTy::from_ty(*ty);
let cast_ty_to = CastTy::from_ty(ty);
match (cast_ty_from, cast_ty_to) {
(Some(CastTy::Int(_)), Some(CastTy::Int(_))) => (),
_ => {
Expand All @@ -2158,9 +2193,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}
}
CastKind::IntToFloat => {
let ty_from = op.ty(body, tcx);
let (ty_from, ty) = structurally_resolved_cast_tys();
let cast_ty_from = CastTy::from_ty(ty_from);
let cast_ty_to = CastTy::from_ty(*ty);
let cast_ty_to = CastTy::from_ty(ty);
match (cast_ty_from, cast_ty_to) {
(Some(CastTy::Int(_)), Some(CastTy::Float)) => (),
_ => {
Expand All @@ -2175,9 +2210,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}
}
CastKind::FloatToInt => {
let ty_from = op.ty(body, tcx);
let (ty_from, ty) = structurally_resolved_cast_tys();
let cast_ty_from = CastTy::from_ty(ty_from);
let cast_ty_to = CastTy::from_ty(*ty);
let cast_ty_to = CastTy::from_ty(ty);
match (cast_ty_from, cast_ty_to) {
(Some(CastTy::Float), Some(CastTy::Int(_))) => (),
_ => {
Expand All @@ -2192,9 +2227,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}
}
CastKind::FloatToFloat => {
let ty_from = op.ty(body, tcx);
let (ty_from, ty) = structurally_resolved_cast_tys();
let cast_ty_from = CastTy::from_ty(ty_from);
let cast_ty_to = CastTy::from_ty(*ty);
let cast_ty_to = CastTy::from_ty(ty);
match (cast_ty_from, cast_ty_to) {
(Some(CastTy::Float), Some(CastTy::Float)) => (),
_ => {
Expand All @@ -2209,9 +2244,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}
}
CastKind::FnPtrToPtr => {
let ty_from = op.ty(body, tcx);
let (ty_from, ty) = structurally_resolved_cast_tys();
let cast_ty_from = CastTy::from_ty(ty_from);
let cast_ty_to = CastTy::from_ty(*ty);
let cast_ty_to = CastTy::from_ty(ty);
match (cast_ty_from, cast_ty_to) {
(Some(CastTy::FnPtr), Some(CastTy::Ptr(_))) => (),
_ => {
Expand All @@ -2226,9 +2261,9 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
}
}
CastKind::PtrToPtr => {
let ty_from = op.ty(body, tcx);
let (ty_from, ty) = structurally_resolved_cast_tys();
let cast_ty_from = CastTy::from_ty(ty_from);
let cast_ty_to = CastTy::from_ty(*ty);
let cast_ty_to = CastTy::from_ty(ty);
match (cast_ty_from, cast_ty_to) {
(Some(CastTy::Ptr(_)), Some(CastTy::Ptr(_))) => (),
_ => {
Expand Down
13 changes: 12 additions & 1 deletion compiler/rustc_hir_typeck/src/coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ use rustc_span::{self, BytePos, DesugaringKind, Span};
use rustc_target::spec::abi::Abi;
use rustc_trait_selection::infer::InferCtxtExt as _;
use rustc_trait_selection::traits::error_reporting::TypeErrCtxtExt as _;
use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt;
use rustc_trait_selection::traits::{
self, NormalizeExt, ObligationCause, ObligationCauseCode, ObligationCtxt,
};
Expand Down Expand Up @@ -144,12 +145,22 @@ impl<'f, 'tcx> Coerce<'f, 'tcx> {
debug!("unify(a: {:?}, b: {:?}, use_lub: {})", a, b, self.use_lub);
self.commit_if_ok(|_| {
let at = self.at(&self.cause, self.fcx.param_env).define_opaque_types(true);
if self.use_lub {
let result = if self.use_lub {
at.lub(b, a)
} else {
at.sup(b, a)
.map(|InferOk { value: (), obligations }| InferOk { value: a, obligations })
}?;

if self.tcx.trait_solver_next() {
for obligation in &result.obligations {
if !self.predicate_may_hold(obligation) {
return Err(TypeError::Mismatch);
}
}
}

Ok(result)
})
}

Expand Down
24 changes: 23 additions & 1 deletion compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use rustc_hir_analysis::astconv::{
};
use rustc_infer::infer::canonical::{Canonical, OriginalQueryValues, QueryResponse};
use rustc_infer::infer::error_reporting::TypeAnnotationNeeded::E0282;
use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
use rustc_infer::infer::InferResult;
use rustc_middle::ty::adjustment::{Adjust, Adjustment, AutoBorrow, AutoBorrowMutability};
use rustc_middle::ty::error::TypeError;
Expand All @@ -34,6 +35,7 @@ use rustc_span::hygiene::DesugaringKind;
use rustc_span::symbol::{kw, sym, Ident};
use rustc_span::Span;
use rustc_trait_selection::traits::error_reporting::TypeErrCtxtExt as _;
use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt;
use rustc_trait_selection::traits::{self, NormalizeExt, ObligationCauseCode, ObligationCtxt};

use std::collections::hash_map::Entry;
Expand Down Expand Up @@ -1408,7 +1410,27 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
/// Resolves `typ` by a single level if `typ` is a type variable.
/// If no resolution is possible, then an error is reported.
/// Numeric inference variables may be left unresolved.
pub fn structurally_resolved_type(&self, sp: Span, ty: Ty<'tcx>) -> Ty<'tcx> {
pub fn structurally_resolved_type(&self, sp: Span, mut ty: Ty<'tcx>) -> Ty<'tcx> {
if self.tcx.trait_solver_next() && let ty::Alias(ty::Projection, projection_ty) = *ty.kind() {
let new_infer_ty = self.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::NormalizeProjectionType,
span: sp,
});
let obligation = traits::Obligation::new(
self.tcx,
self.misc(sp),
self.param_env,
ty::Binder::dummy(ty::ProjectionPredicate {
projection_ty,
term: new_infer_ty.into(),
}),
);
if self.predicate_may_hold(&obligation) {
self.register_predicate(obligation);
ty = new_infer_ty;
}
}

let ty = self.resolve_vars_with_obligations(ty);
if !ty.is_ty_var() {
ty
Expand Down
12 changes: 6 additions & 6 deletions compiler/rustc_mir_build/src/build/expr/as_rvalue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use rustc_middle::mir::AssertKind;
use rustc_middle::mir::Place;
use rustc_middle::mir::*;
use rustc_middle::thir::*;
use rustc_middle::ty::cast::{mir_cast_kind, CastTy};
use rustc_middle::ty::cast::{mir_cast_kind};
use rustc_middle::ty::{self, Ty, UpvarSubsts};
use rustc_span::Span;

Expand Down Expand Up @@ -263,11 +263,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
);
(source, ty)
};
let from_ty = CastTy::from_ty(ty);
let cast_ty = CastTy::from_ty(expr.ty);
debug!("ExprKind::Cast from_ty={from_ty:?}, cast_ty={:?}/{cast_ty:?}", expr.ty,);
let cast_kind = mir_cast_kind(ty, expr.ty);
block.and(Rvalue::Cast(cast_kind, source, expr.ty))

let ty = this.structurally_resolved_ty(ty);
let cast_ty = this.structurally_resolved_ty(expr.ty);
let cast_kind = mir_cast_kind(ty, cast_ty);
block.and(Rvalue::Cast(cast_kind, source, cast_ty))
}
ExprKind::Pointer { cast, source } => {
let source = unpack!(
Expand Down
30 changes: 29 additions & 1 deletion compiler/rustc_mir_build/src/build/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use rustc_hir::def::DefKind;
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_hir::{GeneratorKind, Node};
use rustc_index::vec::{Idx, IndexVec};
use rustc_infer::infer::type_variable::{TypeVariableOriginKind, TypeVariableOrigin};
use rustc_infer::infer::{InferCtxt, TyCtxtInferExt};
use rustc_middle::hir::place::PlaceBase as HirPlaceBase;
use rustc_middle::middle::region;
Expand All @@ -22,9 +23,11 @@ use rustc_middle::thir::{
};
use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt};
use rustc_span::symbol::sym;
use rustc_span::Span;
use rustc_span::{Span, DUMMY_SP};
use rustc_span::Symbol;
use rustc_target::spec::abi::Abi;
use rustc_trait_selection::traits;
use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt;

use super::lints;

Expand Down Expand Up @@ -229,6 +232,31 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
fn var_local_id(&self, id: LocalVarId, for_guard: ForGuard) -> Local {
self.var_indices[&id].local_id(for_guard)
}

fn structurally_resolved_ty(&self, ty: Ty<'tcx>) -> Ty<'tcx> {
if self.tcx.trait_solver_next() && let ty::Alias(ty::Projection, projection_ty) = *ty.kind() {
let new_infer_ty = self.infcx.next_ty_var(TypeVariableOrigin {
kind: TypeVariableOriginKind::NormalizeProjectionType,
span: DUMMY_SP,
});
let obligation = traits::Obligation::new(
self.tcx,
traits::ObligationCause::dummy(),
self.param_env,
ty::Binder::dummy(ty::ProjectionPredicate {
projection_ty,
term: new_infer_ty.into(),
}),
);

if self.infcx.predicate_may_hold(&obligation) {
traits::fully_solve_obligation(&self.infcx, obligation);
return self.infcx.resolve_vars_if_possible(new_infer_ty);
}
}

ty
}
}

impl BlockContext {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// compile-flags: -Ztrait-solver=next
// known-bug: unknown
// check-pass

fn main() {
(0u8 + 0u8) as char;
Expand Down

This file was deleted.