Skip to content

Commit 4ce2123

Browse files
yeet ConstInferUnifier
1 parent c270b0a commit 4ce2123

File tree

3 files changed

+81
-180
lines changed

3 files changed

+81
-180
lines changed

compiler/rustc_infer/src/infer/combine.rs

+17-160
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,17 @@ use super::equate::Equate;
2626
use super::glb::Glb;
2727
use super::lub::Lub;
2828
use super::sub::Sub;
29-
use super::type_variable::TypeVariableValue;
30-
use super::{DefineOpaqueTypes, InferCtxt, MiscVariable, TypeTrace};
31-
use crate::infer::generalize::{generalize, CombineDelegate, Generalization};
29+
use super::{DefineOpaqueTypes, InferCtxt, TypeTrace};
30+
use crate::infer::generalize::{self, CombineDelegate, Generalization};
3231
use crate::traits::{Obligation, PredicateObligations};
3332
use rustc_middle::infer::canonical::OriginalQueryValues;
3433
use rustc_middle::infer::unify_key::{ConstVarValue, ConstVariableValue};
3534
use rustc_middle::infer::unify_key::{ConstVariableOrigin, ConstVariableOriginKind};
3635
use rustc_middle::ty::error::{ExpectedFound, TypeError};
3736
use rustc_middle::ty::relate::{RelateResult, TypeRelation};
38-
use rustc_middle::ty::{
39-
self, AliasKind, FallibleTypeFolder, InferConst, ToPredicate, Ty, TyCtxt, TypeFoldable,
40-
TypeSuperFoldable, TypeVisitableExt,
41-
};
37+
use rustc_middle::ty::{self, AliasKind, InferConst, ToPredicate, Ty, TyCtxt, TypeVisitableExt};
4238
use rustc_middle::ty::{IntType, UintType};
43-
use rustc_span::{Span, DUMMY_SP};
39+
use rustc_span::DUMMY_SP;
4440

4541
#[derive(Clone)]
4642
pub struct CombineFields<'infcx, 'tcx> {
@@ -208,11 +204,11 @@ impl<'tcx> InferCtxt<'tcx> {
208204
// matching in the solver.
209205
let a_error = self.tcx.const_error(a.ty(), guar);
210206
if let ty::ConstKind::Infer(InferConst::Var(vid)) = a.kind() {
211-
return self.unify_const_variable(vid, a_error);
207+
return self.unify_const_variable(vid, a_error, relation.param_env());
212208
}
213209
let b_error = self.tcx.const_error(b.ty(), guar);
214210
if let ty::ConstKind::Infer(InferConst::Var(vid)) = b.kind() {
215-
return self.unify_const_variable(vid, b_error);
211+
return self.unify_const_variable(vid, b_error, relation.param_env());
216212
}
217213

218214
return Ok(if relation.a_is_expected() { a_error } else { b_error });
@@ -234,11 +230,11 @@ impl<'tcx> InferCtxt<'tcx> {
234230
}
235231

236232
(ty::ConstKind::Infer(InferConst::Var(vid)), _) => {
237-
return self.unify_const_variable(vid, b);
233+
return self.unify_const_variable(vid, b, relation.param_env());
238234
}
239235

240236
(_, ty::ConstKind::Infer(InferConst::Var(vid))) => {
241-
return self.unify_const_variable(vid, a);
237+
return self.unify_const_variable(vid, a, relation.param_env());
242238
}
243239
(ty::ConstKind::Unevaluated(..), _) | (_, ty::ConstKind::Unevaluated(..))
244240
if self.tcx.lazy_normalization() =>
@@ -291,24 +287,17 @@ impl<'tcx> InferCtxt<'tcx> {
291287
&self,
292288
target_vid: ty::ConstVid<'tcx>,
293289
ct: ty::Const<'tcx>,
290+
param_env: ty::ParamEnv<'tcx>,
294291
) -> RelateResult<'tcx, ty::Const<'tcx>> {
295-
let (for_universe, span) = {
296-
let mut inner = self.inner.borrow_mut();
297-
let variable_table = &mut inner.const_unification_table();
298-
let var_value = variable_table.probe_value(target_vid);
299-
match var_value.val {
300-
ConstVariableValue::Known { value } => {
301-
bug!("instantiating {:?} which has a known value {:?}", target_vid, value)
302-
}
303-
ConstVariableValue::Unknown { universe } => (universe, var_value.origin.span),
304-
}
305-
};
306-
let value = ct.try_fold_with(&mut ConstInferUnifier {
307-
infcx: self,
308-
span,
309-
for_universe,
292+
let span =
293+
self.inner.borrow_mut().const_unification_table().probe_value(target_vid).origin.span;
294+
let Generalization { value, needs_wf: _ } = generalize::generalize(
295+
self,
296+
&mut CombineDelegate { infcx: self, span, param_env },
297+
ct,
310298
target_vid,
311-
})?;
299+
ty::Variance::Invariant,
300+
)?;
312301

313302
self.inner.borrow_mut().const_unification_table().union_value(
314303
target_vid,
@@ -547,135 +536,3 @@ fn float_unification_error<'tcx>(
547536
let (ty::FloatVarValue(a), ty::FloatVarValue(b)) = v;
548537
TypeError::FloatMismatch(ExpectedFound::new(a_is_expected, a, b))
549538
}
550-
551-
struct ConstInferUnifier<'cx, 'tcx> {
552-
infcx: &'cx InferCtxt<'tcx>,
553-
554-
span: Span,
555-
556-
for_universe: ty::UniverseIndex,
557-
558-
/// The vid of the const variable that is in the process of being
559-
/// instantiated; if we find this within the const we are folding,
560-
/// that means we would have created a cyclic const.
561-
target_vid: ty::ConstVid<'tcx>,
562-
}
563-
564-
impl<'tcx> FallibleTypeFolder<TyCtxt<'tcx>> for ConstInferUnifier<'_, 'tcx> {
565-
type Error = TypeError<'tcx>;
566-
567-
fn interner(&self) -> TyCtxt<'tcx> {
568-
self.infcx.tcx
569-
}
570-
571-
#[instrument(level = "debug", skip(self), ret)]
572-
fn try_fold_ty(&mut self, t: Ty<'tcx>) -> Result<Ty<'tcx>, TypeError<'tcx>> {
573-
match t.kind() {
574-
&ty::Infer(ty::TyVar(vid)) => {
575-
let vid = self.infcx.inner.borrow_mut().type_variables().root_var(vid);
576-
let probe = self.infcx.inner.borrow_mut().type_variables().probe(vid);
577-
match probe {
578-
TypeVariableValue::Known { value: u } => {
579-
debug!("ConstOccursChecker: known value {:?}", u);
580-
u.try_fold_with(self)
581-
}
582-
TypeVariableValue::Unknown { universe } => {
583-
if self.for_universe.can_name(universe) {
584-
return Ok(t);
585-
}
586-
587-
let origin =
588-
*self.infcx.inner.borrow_mut().type_variables().var_origin(vid);
589-
let new_var_id = self
590-
.infcx
591-
.inner
592-
.borrow_mut()
593-
.type_variables()
594-
.new_var(self.for_universe, origin);
595-
Ok(self.interner().mk_ty_var(new_var_id))
596-
}
597-
}
598-
}
599-
ty::Infer(ty::IntVar(_) | ty::FloatVar(_)) => Ok(t),
600-
_ => t.try_super_fold_with(self),
601-
}
602-
}
603-
604-
#[instrument(level = "debug", skip(self), ret)]
605-
fn try_fold_region(
606-
&mut self,
607-
r: ty::Region<'tcx>,
608-
) -> Result<ty::Region<'tcx>, TypeError<'tcx>> {
609-
debug!("ConstInferUnifier: r={:?}", r);
610-
611-
match *r {
612-
// Never make variables for regions bound within the type itself,
613-
// nor for erased regions.
614-
ty::ReLateBound(..) | ty::ReErased | ty::ReError(_) => {
615-
return Ok(r);
616-
}
617-
618-
ty::RePlaceholder(..)
619-
| ty::ReVar(..)
620-
| ty::ReStatic
621-
| ty::ReEarlyBound(..)
622-
| ty::ReFree(..) => {
623-
// see common code below
624-
}
625-
}
626-
627-
let r_universe = self.infcx.universe_of_region(r);
628-
if self.for_universe.can_name(r_universe) {
629-
return Ok(r);
630-
} else {
631-
// FIXME: This is non-ideal because we don't give a
632-
// very descriptive origin for this region variable.
633-
Ok(self.infcx.next_region_var_in_universe(MiscVariable(self.span), self.for_universe))
634-
}
635-
}
636-
637-
#[instrument(level = "debug", skip(self), ret)]
638-
fn try_fold_const(&mut self, c: ty::Const<'tcx>) -> Result<ty::Const<'tcx>, TypeError<'tcx>> {
639-
match c.kind() {
640-
ty::ConstKind::Infer(InferConst::Var(vid)) => {
641-
// Check if the current unification would end up
642-
// unifying `target_vid` with a const which contains
643-
// an inference variable which is unioned with `target_vid`.
644-
//
645-
// Not doing so can easily result in stack overflows.
646-
if self
647-
.infcx
648-
.inner
649-
.borrow_mut()
650-
.const_unification_table()
651-
.unioned(self.target_vid, vid)
652-
{
653-
return Err(TypeError::CyclicConst(c));
654-
}
655-
656-
let var_value =
657-
self.infcx.inner.borrow_mut().const_unification_table().probe_value(vid);
658-
match var_value.val {
659-
ConstVariableValue::Known { value: u } => u.try_fold_with(self),
660-
ConstVariableValue::Unknown { universe } => {
661-
if self.for_universe.can_name(universe) {
662-
Ok(c)
663-
} else {
664-
let new_var_id =
665-
self.infcx.inner.borrow_mut().const_unification_table().new_key(
666-
ConstVarValue {
667-
origin: var_value.origin,
668-
val: ConstVariableValue::Unknown {
669-
universe: self.for_universe,
670-
},
671-
},
672-
);
673-
Ok(self.interner().mk_const(new_var_id, c.ty()))
674-
}
675-
}
676-
}
677-
}
678-
_ => c.try_super_fold_with(self),
679-
}
680-
}
681-
}

compiler/rustc_infer/src/infer/generalize.rs

+46-20
Original file line numberDiff line numberDiff line change
@@ -3,36 +3,44 @@ use rustc_hir::def_id::DefId;
33
use rustc_middle::infer::unify_key::{ConstVarValue, ConstVariableValue};
44
use rustc_middle::ty::error::TypeError;
55
use rustc_middle::ty::relate::{self, Relate, RelateResult, TypeRelation};
6-
use rustc_middle::ty::{self, InferConst, Ty, TyCtxt, TypeVisitableExt};
6+
use rustc_middle::ty::{self, InferConst, Term, Ty, TyCtxt, TypeVisitableExt};
77
use rustc_span::Span;
88

99
use crate::infer::nll_relate::TypeRelatingDelegate;
1010
use crate::infer::type_variable::TypeVariableValue;
1111
use crate::infer::{InferCtxt, RegionVariableOrigin};
1212

13-
pub(super) fn generalize<'tcx, D: GeneralizerDelegate<'tcx>>(
13+
pub(super) fn generalize<'tcx, D: GeneralizerDelegate<'tcx>, T: Into<Term<'tcx>> + Relate<'tcx>>(
1414
infcx: &InferCtxt<'tcx>,
1515
delegate: &mut D,
16-
ty: Ty<'tcx>,
17-
for_vid: ty::TyVid,
16+
term: T,
17+
for_vid: impl Into<ty::TermVid<'tcx>>,
1818
ambient_variance: ty::Variance,
19-
) -> RelateResult<'tcx, Generalization<Ty<'tcx>>> {
20-
let for_universe = infcx.probe_ty_var(for_vid).unwrap_err();
21-
let for_vid_sub_root = infcx.inner.borrow_mut().type_variables().sub_root_var(for_vid);
19+
) -> RelateResult<'tcx, Generalization<T>> {
20+
let (for_universe, root_vid) = match for_vid.into() {
21+
ty::TermVid::Ty(ty_vid) => (
22+
infcx.probe_ty_var(ty_vid).unwrap_err(),
23+
ty::TermVid::Ty(infcx.inner.borrow_mut().type_variables().sub_root_var(ty_vid)),
24+
),
25+
ty::TermVid::Const(ct_vid) => (
26+
infcx.probe_const_var(ct_vid).unwrap_err(),
27+
ty::TermVid::Const(infcx.inner.borrow_mut().const_unification_table().find(ct_vid)),
28+
),
29+
};
2230

2331
let mut generalizer = Generalizer {
2432
infcx,
2533
delegate,
2634
ambient_variance,
27-
for_vid_sub_root,
35+
root_vid,
2836
for_universe,
29-
root_ty: ty,
37+
root_term: term.into(),
3038
needs_wf: false,
3139
cache: Default::default(),
3240
};
3341

34-
assert!(!ty.has_escaping_bound_vars());
35-
let value = generalizer.relate(ty, ty)?;
42+
assert!(!term.has_escaping_bound_vars());
43+
let value = generalizer.relate(term, term)?;
3644
let needs_wf = generalizer.needs_wf;
3745
Ok(Generalization { value, needs_wf })
3846
}
@@ -99,11 +107,8 @@ where
99107
/// establishes `'0: 'x` as a constraint.
100108
///
101109
/// [blog post]: https://is.gd/0hKvIr
102-
struct Generalizer<'me, 'tcx, D>
103-
where
104-
D: GeneralizerDelegate<'tcx>,
105-
{
106-
pub infcx: &'me InferCtxt<'tcx>,
110+
struct Generalizer<'me, 'tcx, D> {
111+
infcx: &'me InferCtxt<'tcx>,
107112

108113
// An delegate used to abstract the behaviors of the three previous
109114
// generalizer-like implementations.
@@ -116,21 +121,31 @@ where
116121
/// The vid of the type variable that is in the process of being
117122
/// instantiated. If we find this within the value we are folding,
118123
/// that means we would have created a cyclic value.
119-
pub for_vid_sub_root: ty::TyVid,
124+
root_vid: ty::TermVid<'tcx>,
120125

121126
/// The universe of the type variable that is in the process of being
122127
/// instantiated. If we find anything that this universe cannot name,
123128
/// we reject the relation.
124129
for_universe: ty::UniverseIndex,
125130

126-
pub root_ty: Ty<'tcx>,
131+
/// The root term (const or type) we're generalizing. Used for cycle errors.
132+
root_term: Term<'tcx>,
127133

128134
cache: SsoHashMap<Ty<'tcx>, Ty<'tcx>>,
129135

130136
/// See the field `needs_wf` in `Generalization`.
131137
needs_wf: bool,
132138
}
133139

140+
impl<'tcx, D> Generalizer<'_, 'tcx, D> {
141+
fn cyclic_term_error(&self) -> TypeError<'tcx> {
142+
match self.root_term.unpack() {
143+
ty::TermKind::Ty(ty) => TypeError::CyclicTy(ty),
144+
ty::TermKind::Const(ct) => TypeError::CyclicConst(ct),
145+
}
146+
}
147+
}
148+
134149
impl<'tcx, D> TypeRelation<'tcx> for Generalizer<'_, 'tcx, D>
135150
where
136151
D: GeneralizerDelegate<'tcx>,
@@ -226,10 +241,10 @@ where
226241
let mut inner = self.infcx.inner.borrow_mut();
227242
let vid = inner.type_variables().root_var(vid);
228243
let sub_vid = inner.type_variables().sub_root_var(vid);
229-
if sub_vid == self.for_vid_sub_root {
244+
if TermVid::Ty(sub_vid) == self.root_vid {
230245
// If sub-roots are equal, then `for_vid` and
231246
// `vid` are related via subtyping.
232-
Err(TypeError::CyclicTy(self.root_ty))
247+
Err(self.cyclic_term_error())
233248
} else {
234249
let probe = inner.type_variables().probe(vid);
235250
match probe {
@@ -363,6 +378,17 @@ where
363378
bug!("unexpected inference variable encountered in NLL generalization: {:?}", c);
364379
}
365380
ty::ConstKind::Infer(InferConst::Var(vid)) => {
381+
// Check if the current unification would end up
382+
// unifying `target_vid` with a const which contains
383+
// an inference variable which is unioned with `target_vid`.
384+
//
385+
// Not doing so can easily result in stack overflows.
386+
if TermVid::Const(self.infcx.inner.borrow_mut().const_unification_table().find(vid))
387+
== self.root_vid
388+
{
389+
return Err(self.cyclic_term_error());
390+
}
391+
366392
let mut inner = self.infcx.inner.borrow_mut();
367393
let variable_table = &mut inner.const_unification_table();
368394
let var_value = variable_table.probe_value(vid);

compiler/rustc_middle/src/ty/mod.rs

+18
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,24 @@ impl ParamTerm {
10701070
}
10711071
}
10721072

1073+
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
1074+
pub enum TermVid<'tcx> {
1075+
Ty(ty::TyVid),
1076+
Const(ty::ConstVid<'tcx>),
1077+
}
1078+
1079+
impl From<ty::TyVid> for TermVid<'_> {
1080+
fn from(value: ty::TyVid) -> Self {
1081+
TermVid::Ty(value)
1082+
}
1083+
}
1084+
1085+
impl<'tcx> From<ty::ConstVid<'tcx>> for TermVid<'tcx> {
1086+
fn from(value: ty::ConstVid<'tcx>) -> Self {
1087+
TermVid::Const(value)
1088+
}
1089+
}
1090+
10731091
/// This kind of predicate has no *direct* correspondent in the
10741092
/// syntax, but it roughly corresponds to the syntactic forms:
10751093
///

0 commit comments

Comments
 (0)