Skip to content

Commit 5742b9a

Browse files
committed
Move THIR pattern traversal out of Pat and into Thir
1 parent 2484498 commit 5742b9a

File tree

4 files changed

+38
-28
lines changed

4 files changed

+38
-28
lines changed

compiler/rustc_middle/src/thir.rs

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -647,11 +647,17 @@ impl<'tcx> Pat<'tcx> {
647647
_ => None,
648648
}
649649
}
650+
}
650651

652+
impl<'tcx> Thir<'tcx> {
651653
/// Call `f` on every "binding" in a pattern, e.g., on `a` in
652654
/// `match foo() { Some(a) => (), None => () }`
653-
pub fn each_binding(&self, mut f: impl FnMut(Symbol, ByRef, Ty<'tcx>, Span)) {
654-
self.walk_always(|p| {
655+
pub fn each_pat_binding(
656+
&self,
657+
pat: &Pat<'tcx>,
658+
mut f: impl FnMut(Symbol, ByRef, Ty<'tcx>, Span),
659+
) {
660+
self.walk_pat_always(pat, |p| {
655661
if let PatKind::Binding { name, mode, ty, .. } = p.kind {
656662
f(name, mode.0, ty, p.span);
657663
}
@@ -661,17 +667,17 @@ impl<'tcx> Pat<'tcx> {
661667
/// Walk the pattern in left-to-right order.
662668
///
663669
/// If `it(pat)` returns `false`, the children are not visited.
664-
pub fn walk(&self, mut it: impl FnMut(&Pat<'tcx>) -> bool) {
665-
self.walk_(&mut it)
670+
pub fn walk_pat(&self, pat: &Pat<'tcx>, mut it: impl FnMut(&Pat<'tcx>) -> bool) {
671+
self.walk_pat_inner(pat, &mut it);
666672
}
667673

668-
fn walk_(&self, it: &mut impl FnMut(&Pat<'tcx>) -> bool) {
669-
if !it(self) {
674+
fn walk_pat_inner(&self, pat: &Pat<'tcx>, it: &mut impl FnMut(&Pat<'tcx>) -> bool) {
675+
if !it(pat) {
670676
return;
671677
}
672678

673679
use PatKind::*;
674-
match &self.kind {
680+
match &pat.kind {
675681
Wild
676682
| Never
677683
| Range(..)
@@ -682,22 +688,24 @@ impl<'tcx> Pat<'tcx> {
682688
| Binding { subpattern: Some(subpattern), .. }
683689
| Deref { subpattern }
684690
| DerefPattern { subpattern, .. }
685-
| ExpandedConstant { subpattern, .. } => subpattern.walk_(it),
691+
| ExpandedConstant { subpattern, .. } => self.walk_pat_inner(subpattern, it),
686692
Leaf { subpatterns } | Variant { subpatterns, .. } => {
687-
subpatterns.iter().for_each(|field| field.pattern.walk_(it))
693+
subpatterns.iter().for_each(|field| self.walk_pat_inner(&field.pattern, it))
688694
}
689-
Or { pats } => pats.iter().for_each(|p| p.walk_(it)),
695+
Or { pats } => pats.iter().for_each(|p| self.walk_pat_inner(p, it)),
690696
Array { box ref prefix, ref slice, box ref suffix }
691-
| Slice { box ref prefix, ref slice, box ref suffix } => {
692-
prefix.iter().chain(slice.iter()).chain(suffix.iter()).for_each(|p| p.walk_(it))
693-
}
697+
| Slice { box ref prefix, ref slice, box ref suffix } => prefix
698+
.iter()
699+
.chain(slice.iter())
700+
.chain(suffix.iter())
701+
.for_each(|p| self.walk_pat_inner(p, it)),
694702
}
695703
}
696704

697705
/// Whether the pattern has a `PatKind::Error` nested within.
698-
pub fn pat_error_reported(&self) -> Result<(), ErrorGuaranteed> {
706+
pub fn pat_error_reported(&self, pat: &Pat<'tcx>) -> Result<(), ErrorGuaranteed> {
699707
let mut error = None;
700-
self.walk(|pat| {
708+
self.walk_pat(pat, |pat| {
701709
if let PatKind::Error(e) = pat.kind
702710
&& error.is_none()
703711
{
@@ -714,23 +722,23 @@ impl<'tcx> Pat<'tcx> {
714722
/// Walk the pattern in left-to-right order.
715723
///
716724
/// If you always want to recurse, prefer this method over `walk`.
717-
pub fn walk_always(&self, mut it: impl FnMut(&Pat<'tcx>)) {
718-
self.walk(|p| {
725+
pub fn walk_pat_always(&self, pat: &Pat<'tcx>, mut it: impl FnMut(&Pat<'tcx>)) {
726+
self.walk_pat(pat, |p| {
719727
it(p);
720728
true
721729
})
722730
}
723731

724732
/// Whether this a never pattern.
725-
pub fn is_never_pattern(&self) -> bool {
733+
pub fn is_never_pattern(&self, pat: &Pat<'tcx>) -> bool {
726734
let mut is_never_pattern = false;
727-
self.walk(|pat| match &pat.kind {
735+
self.walk_pat(pat, |pat| match &pat.kind {
728736
PatKind::Never => {
729737
is_never_pattern = true;
730738
false
731739
}
732740
PatKind::Or { pats } => {
733-
is_never_pattern = pats.iter().all(|p| p.is_never_pattern());
741+
is_never_pattern = pats.iter().all(|p| self.is_never_pattern(p));
734742
false
735743
}
736744
_ => true,

compiler/rustc_mir_build/src/builder/matches/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1012,7 +1012,7 @@ impl<'tcx, 'pat> FlatPat<'pat, 'tcx> {
10121012
span: pattern.span,
10131013
bindings: Vec::new(),
10141014
ascriptions: Vec::new(),
1015-
is_never: pattern.is_never_pattern(),
1015+
is_never: cx.thir.is_never_pattern(pattern),
10161016
};
10171017
// Recursively remove irrefutable match pairs, while recording their
10181018
// bindings/ascriptions, and sort or-patterns after other match pairs.

compiler/rustc_mir_build/src/thir/pattern/check_match.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -278,14 +278,14 @@ impl<'p, 'tcx> MatchVisitor<'p, 'tcx> {
278278
cx: &PatCtxt<'p, 'tcx>,
279279
pat: &'p Pat<'tcx>,
280280
) -> Result<&'p DeconstructedPat<'p, 'tcx>, ErrorGuaranteed> {
281-
if let Err(err) = pat.pat_error_reported() {
281+
if let Err(err) = cx.thir.pat_error_reported(pat) {
282282
self.error = Err(err);
283283
Err(err)
284284
} else {
285285
// Check the pattern for some things unrelated to exhaustiveness.
286286
let refutable = if cx.refutable { Refutable } else { Irrefutable };
287287
let mut err = Ok(());
288-
pat.walk_always(|pat| {
288+
cx.thir.walk_pat_always(pat, |pat| {
289289
check_borrow_conflicts_in_at_patterns(self, pat);
290290
check_for_bindings_named_same_as_variants(self, pat, refutable);
291291
err = err.and(check_never_pattern(cx, pat));
@@ -386,6 +386,7 @@ impl<'p, 'tcx> MatchVisitor<'p, 'tcx> {
386386
scrutinee.map(|scrut| self.is_known_valid_scrutinee(scrut)).unwrap_or(true);
387387
PatCtxt {
388388
tcx: self.tcx,
389+
thir: self.thir,
389390
typeck_results: self.typeck_results,
390391
typing_env: self.typing_env,
391392
module: self.tcx.parent_module(self.lint_level).to_def_id(),
@@ -714,7 +715,7 @@ impl<'p, 'tcx> MatchVisitor<'p, 'tcx> {
714715
&& scrut.is_some()
715716
{
716717
let mut bindings = vec![];
717-
pat.each_binding(|name, _, _, _| bindings.push(name));
718+
self.thir.each_pat_binding(pat, |name, _, _, _| bindings.push(name));
718719

719720
let semi_span = span.shrink_to_hi();
720721
let start_span = span.shrink_to_lo();
@@ -790,7 +791,7 @@ fn check_borrow_conflicts_in_at_patterns<'tcx>(cx: &MatchVisitor<'_, 'tcx>, pat:
790791
ByRef::No if is_binding_by_move(ty) => {
791792
// We have `x @ pat` where `x` is by-move. Reject all borrows in `pat`.
792793
let mut conflicts_ref = Vec::new();
793-
sub.each_binding(|_, mode, _, span| {
794+
cx.thir.each_pat_binding(sub, |_, mode, _, span| {
794795
if matches!(mode, ByRef::Yes(_)) {
795796
conflicts_ref.push(span)
796797
}
@@ -819,7 +820,7 @@ fn check_borrow_conflicts_in_at_patterns<'tcx>(cx: &MatchVisitor<'_, 'tcx>, pat:
819820
let mut conflicts_move = Vec::new();
820821
let mut conflicts_mut_mut = Vec::new();
821822
let mut conflicts_mut_ref = Vec::new();
822-
sub.each_binding(|name, mode, ty, span| {
823+
cx.thir.each_pat_binding(sub, |name, mode, ty, span| {
823824
match mode {
824825
ByRef::Yes(mut_inner) => match (mut_outer, mut_inner) {
825826
// Both sides are `ref`.

compiler/rustc_pattern_analysis/src/rustc.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use rustc_hir::def_id::DefId;
88
use rustc_index::{Idx, IndexVec};
99
use rustc_middle::middle::stability::EvalResult;
1010
use rustc_middle::mir::{self, Const};
11-
use rustc_middle::thir::{self, Pat, PatKind, PatRange, PatRangeBoundary};
11+
use rustc_middle::thir::{self, Pat, PatKind, PatRange, PatRangeBoundary, Thir};
1212
use rustc_middle::ty::layout::IntegerExt;
1313
use rustc_middle::ty::{
1414
self, FieldDef, OpaqueTypeKey, ScalarInt, Ty, TyCtxt, TypeVisitableExt, VariantDef,
@@ -76,8 +76,9 @@ impl<'tcx> RevealedTy<'tcx> {
7676
}
7777

7878
#[derive(Clone)]
79-
pub struct RustcPatCtxt<'p, 'tcx: 'p> {
79+
pub struct RustcPatCtxt<'p, 'tcx> {
8080
pub tcx: TyCtxt<'tcx>,
81+
pub thir: &'p Thir<'tcx>,
8182
pub typeck_results: &'tcx ty::TypeckResults<'tcx>,
8283
/// The module in which the match occurs. This is necessary for
8384
/// checking inhabited-ness of types because whether a type is (visibly)

0 commit comments

Comments
 (0)