Skip to content

Commit 0c9b5e2

Browse files
committed
Use an enum instead of manually tracking indices for target_blocks
1 parent c2564d0 commit 0c9b5e2

5 files changed

+112
-77
lines changed

compiler/rustc_mir_build/src/build/matches/mod.rs

+25-9
Original file line numberDiff line numberDiff line change
@@ -1182,6 +1182,19 @@ pub(crate) struct Test<'tcx> {
11821182
kind: TestKind<'tcx>,
11831183
}
11841184

1185+
/// The branch to be taken after a test.
1186+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
1187+
enum TestBranch<'tcx> {
1188+
/// Success branch, used for tests with two possible outcomes.
1189+
Success,
1190+
/// Branch corresponding to this constant.
1191+
Constant(Const<'tcx>, u128),
1192+
/// Branch corresponding to this variant.
1193+
Variant(VariantIdx),
1194+
/// Failure branch for tests with two possible outcomes, and "otherwise" branch for other tests.
1195+
Failure,
1196+
}
1197+
11851198
/// `ArmHasGuard` is a wrapper around a boolean flag. It indicates whether
11861199
/// a match arm has a guard expression attached to it.
11871200
#[derive(Copy, Clone, Debug)]
@@ -1659,23 +1672,26 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
16591672
match_place: &PlaceBuilder<'tcx>,
16601673
test: &Test<'tcx>,
16611674
mut candidates: &'b mut [&'c mut Candidate<'pat, 'tcx>],
1662-
) -> (&'b mut [&'c mut Candidate<'pat, 'tcx>], Vec<Vec<&'b mut Candidate<'pat, 'tcx>>>) {
1675+
) -> (
1676+
&'b mut [&'c mut Candidate<'pat, 'tcx>],
1677+
FxIndexMap<TestBranch<'tcx>, Vec<&'b mut Candidate<'pat, 'tcx>>>,
1678+
) {
16631679
// For each of the N possible outcomes, create a (initially empty) vector of candidates.
16641680
// Those are the candidates that apply if the test has that particular outcome.
1665-
let mut target_candidates: Vec<Vec<&mut Candidate<'pat, 'tcx>>> = vec![];
1666-
target_candidates.resize_with(test.targets(), Default::default);
1681+
let mut target_candidates: FxIndexMap<_, Vec<&mut Candidate<'pat, 'tcx>>> =
1682+
test.targets().into_iter().map(|branch| (branch, Vec::new())).collect();
16671683

16681684
let total_candidate_count = candidates.len();
16691685

16701686
// Sort the candidates into the appropriate vector in `target_candidates`. Note that at some
16711687
// point we may encounter a candidate where the test is not relevant; at that point, we stop
16721688
// sorting.
16731689
while let Some(candidate) = candidates.first_mut() {
1674-
let Some(idx) = self.sort_candidate(&match_place, &test, candidate) else {
1690+
let Some(branch) = self.sort_candidate(&match_place, &test, candidate) else {
16751691
break;
16761692
};
16771693
let (candidate, rest) = candidates.split_first_mut().unwrap();
1678-
target_candidates[idx].push(candidate);
1694+
target_candidates[&branch].push(candidate);
16791695
candidates = rest;
16801696
}
16811697

@@ -1820,9 +1836,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
18201836
// apply. Collect a list of blocks where control flow will
18211837
// branch if one of the `target_candidate` sets is not
18221838
// exhaustive.
1823-
let target_blocks: Vec<_> = target_candidates
1839+
let target_blocks: FxIndexMap<_, _> = target_candidates
18241840
.into_iter()
1825-
.map(|mut candidates| {
1841+
.map(|(branch, mut candidates)| {
18261842
if !candidates.is_empty() {
18271843
let candidate_start = self.cfg.start_new_block();
18281844
self.match_candidates(
@@ -1832,9 +1848,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
18321848
remainder_start,
18331849
&mut *candidates,
18341850
);
1835-
candidate_start
1851+
(branch, candidate_start)
18361852
} else {
1837-
remainder_start
1853+
(branch, remainder_start)
18381854
}
18391855
})
18401856
.collect();

compiler/rustc_mir_build/src/build/matches/test.rs

+67-50
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
// the candidates based on the result.
77

88
use crate::build::expr::as_place::PlaceBuilder;
9-
use crate::build::matches::{Candidate, MatchPair, Test, TestCase, TestKind};
9+
use crate::build::matches::{Candidate, MatchPair, Test, TestBranch, TestCase, TestKind};
1010
use crate::build::Builder;
1111
use rustc_data_structures::fx::FxIndexMap;
1212
use rustc_hir::{LangItem, RangeEnd};
@@ -129,32 +129,33 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
129129
block: BasicBlock,
130130
place_builder: &PlaceBuilder<'tcx>,
131131
test: &Test<'tcx>,
132-
target_blocks: Vec<BasicBlock>,
132+
target_blocks: FxIndexMap<TestBranch<'tcx>, BasicBlock>,
133133
) {
134134
let place = place_builder.to_place(self);
135135
let place_ty = place.ty(&self.local_decls, self.tcx);
136-
debug!(?place, ?place_ty,);
136+
debug!(?place, ?place_ty);
137+
let target_block = |branch| target_blocks[&branch];
137138

138139
let source_info = self.source_info(test.span);
139140
match test.kind {
140141
TestKind::Switch { adt_def, ref variants } => {
141142
// Variants is a BitVec of indexes into adt_def.variants.
142143
let num_enum_variants = adt_def.variants().len();
143144
debug_assert_eq!(target_blocks.len(), num_enum_variants + 1);
144-
let otherwise_block = *target_blocks.last().unwrap();
145+
let otherwise_block = target_block(TestBranch::Failure);
145146
let tcx = self.tcx;
146147
let switch_targets = SwitchTargets::new(
147148
adt_def.discriminants(tcx).filter_map(|(idx, discr)| {
148149
if variants.contains(idx) {
149150
debug_assert_ne!(
150-
target_blocks[idx.index()],
151+
target_block(TestBranch::Variant(idx)),
151152
otherwise_block,
152153
"no candidates for tested discriminant: {discr:?}",
153154
);
154-
Some((discr.val, target_blocks[idx.index()]))
155+
Some((discr.val, target_block(TestBranch::Variant(idx))))
155156
} else {
156157
debug_assert_eq!(
157-
target_blocks[idx.index()],
158+
target_block(TestBranch::Variant(idx)),
158159
otherwise_block,
159160
"found candidates for untested discriminant: {discr:?}",
160161
);
@@ -185,9 +186,11 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
185186
TestKind::SwitchInt { ref options } => {
186187
// The switch may be inexhaustive so we have a catch-all block
187188
debug_assert_eq!(options.len() + 1, target_blocks.len());
188-
let otherwise_block = *target_blocks.last().unwrap();
189+
let otherwise_block = target_block(TestBranch::Failure);
189190
let switch_targets = SwitchTargets::new(
190-
options.values().copied().zip(target_blocks),
191+
options
192+
.iter()
193+
.map(|(&val, &bits)| (bits, target_block(TestBranch::Constant(val, bits)))),
191194
otherwise_block,
192195
);
193196
let terminator = TerminatorKind::SwitchInt {
@@ -198,18 +201,19 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
198201
}
199202

200203
TestKind::If => {
201-
let [false_bb, true_bb] = *target_blocks else {
202-
bug!("`TestKind::If` should have two targets")
203-
};
204-
let terminator = TerminatorKind::if_(Operand::Copy(place), true_bb, false_bb);
204+
debug_assert_eq!(target_blocks.len(), 2);
205+
let success_block = target_block(TestBranch::Success);
206+
let fail_block = target_block(TestBranch::Failure);
207+
let terminator =
208+
TerminatorKind::if_(Operand::Copy(place), success_block, fail_block);
205209
self.cfg.terminate(block, self.source_info(match_start_span), terminator);
206210
}
207211

208212
TestKind::Eq { value, ty } => {
209213
let tcx = self.tcx;
210-
let [success_block, fail_block] = *target_blocks else {
211-
bug!("`TestKind::Eq` should have two target blocks")
212-
};
214+
debug_assert_eq!(target_blocks.len(), 2);
215+
let success_block = target_block(TestBranch::Success);
216+
let fail_block = target_block(TestBranch::Failure);
213217
if let ty::Adt(def, _) = ty.kind()
214218
&& Some(def.did()) == tcx.lang_items().string()
215219
{
@@ -286,9 +290,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
286290
}
287291

288292
TestKind::Range(ref range) => {
289-
let [success, fail] = *target_blocks else {
290-
bug!("`TestKind::Range` should have two target blocks");
291-
};
293+
debug_assert_eq!(target_blocks.len(), 2);
294+
let success = target_block(TestBranch::Success);
295+
let fail = target_block(TestBranch::Failure);
292296
// Test `val` by computing `lo <= val && val <= hi`, using primitive comparisons.
293297
let val = Operand::Copy(place);
294298

@@ -333,15 +337,15 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
333337
// expected = <N>
334338
let expected = self.push_usize(block, source_info, len);
335339

336-
let [true_bb, false_bb] = *target_blocks else {
337-
bug!("`TestKind::Len` should have two target blocks");
338-
};
340+
debug_assert_eq!(target_blocks.len(), 2);
341+
let success_block = target_block(TestBranch::Success);
342+
let fail_block = target_block(TestBranch::Failure);
339343
// result = actual == expected OR result = actual < expected
340344
// branch based on result
341345
self.compare(
342346
block,
343-
true_bb,
344-
false_bb,
347+
success_block,
348+
fail_block,
345349
source_info,
346350
op,
347351
Operand::Move(actual),
@@ -526,10 +530,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
526530

527531
/// Given that we are performing `test` against `test_place`, this job
528532
/// sorts out what the status of `candidate` will be after the test. See
529-
/// `test_candidates` for the usage of this function. The returned index is
530-
/// the index that this candidate should be placed in the
531-
/// `target_candidates` vec. The candidate may be modified to update its
532-
/// `match_pairs`.
533+
/// `test_candidates` for the usage of this function. The candidate may
534+
/// be modified to update its `match_pairs`.
533535
///
534536
/// So, for example, if this candidate is `x @ Some(P0)` and the `Test` is
535537
/// a variant test, then we would modify the candidate to be `(x as
@@ -556,7 +558,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
556558
test_place: &PlaceBuilder<'tcx>,
557559
test: &Test<'tcx>,
558560
candidate: &mut Candidate<'pat, 'tcx>,
559-
) -> Option<usize> {
561+
) -> Option<TestBranch<'tcx>> {
560562
// Find the match_pair for this place (if any). At present,
561563
// afaik, there can be at most one. (In the future, if we
562564
// adopted a more general `@` operator, there might be more
@@ -576,20 +578,20 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
576578
) => {
577579
assert_eq!(adt_def, tested_adt_def);
578580
fully_matched = true;
579-
Some(variant_index.as_usize())
581+
Some(TestBranch::Variant(variant_index))
580582
}
581583

582584
// If we are performing a switch over integers, then this informs integer
583585
// equality, but nothing else.
584586
//
585587
// FIXME(#29623) we could use PatKind::Range to rule
586588
// things out here, in some cases.
587-
(TestKind::SwitchInt { options }, TestCase::Constant { value })
589+
(TestKind::SwitchInt { options }, &TestCase::Constant { value })
588590
if is_switch_ty(match_pair.pattern.ty) =>
589591
{
590592
fully_matched = true;
591-
let index = options.get_index_of(value).unwrap();
592-
Some(index)
593+
let bits = options.get(&value).unwrap();
594+
Some(TestBranch::Constant(value, *bits))
593595
}
594596
(TestKind::SwitchInt { options }, TestCase::Range(range)) => {
595597
fully_matched = false;
@@ -599,7 +601,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
599601
not_contained.then(|| {
600602
// No switch values are contained in the pattern range,
601603
// so the pattern can be matched only if this test fails.
602-
options.len()
604+
TestBranch::Failure
603605
})
604606
}
605607

@@ -608,7 +610,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
608610
let value = value.try_eval_bool(self.tcx, self.param_env).unwrap_or_else(|| {
609611
span_bug!(test.span, "expected boolean value but got {value:?}")
610612
});
611-
Some(value as usize)
613+
Some(if value { TestBranch::Success } else { TestBranch::Failure })
612614
}
613615

614616
(
@@ -620,14 +622,14 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
620622
// on true, min_len = len = $actual_length,
621623
// on false, len != $actual_length
622624
fully_matched = true;
623-
Some(0)
625+
Some(TestBranch::Success)
624626
}
625627
(Ordering::Less, _) => {
626628
// test_len < pat_len. If $actual_len = test_len,
627629
// then $actual_len < pat_len and we don't have
628630
// enough elements.
629631
fully_matched = false;
630-
Some(1)
632+
Some(TestBranch::Failure)
631633
}
632634
(Ordering::Equal | Ordering::Greater, true) => {
633635
// This can match both if $actual_len = test_len >= pat_len,
@@ -639,7 +641,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
639641
// test_len != pat_len, so if $actual_len = test_len, then
640642
// $actual_len != pat_len.
641643
fully_matched = false;
642-
Some(1)
644+
Some(TestBranch::Failure)
643645
}
644646
}
645647
}
@@ -653,20 +655,20 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
653655
// $actual_len >= test_len = pat_len,
654656
// so we can match.
655657
fully_matched = true;
656-
Some(0)
658+
Some(TestBranch::Success)
657659
}
658660
(Ordering::Less, _) | (Ordering::Equal, false) => {
659661
// test_len <= pat_len. If $actual_len < test_len,
660662
// then it is also < pat_len, so the test passing is
661663
// necessary (but insufficient).
662664
fully_matched = false;
663-
Some(0)
665+
Some(TestBranch::Success)
664666
}
665667
(Ordering::Greater, false) => {
666668
// test_len > pat_len. If $actual_len >= test_len > pat_len,
667669
// then we know we won't have a match.
668670
fully_matched = false;
669-
Some(1)
671+
Some(TestBranch::Failure)
670672
}
671673
(Ordering::Greater, true) => {
672674
// test_len < pat_len, and is therefore less
@@ -680,20 +682,24 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
680682
(TestKind::Range(test), &TestCase::Range(pat)) => {
681683
if test.as_ref() == pat {
682684
fully_matched = true;
683-
Some(0)
685+
Some(TestBranch::Success)
684686
} else {
685687
fully_matched = false;
686688
// If the testing range does not overlap with pattern range,
687689
// the pattern can be matched only if this test fails.
688-
if !test.overlaps(pat, self.tcx, self.param_env)? { Some(1) } else { None }
690+
if !test.overlaps(pat, self.tcx, self.param_env)? {
691+
Some(TestBranch::Failure)
692+
} else {
693+
None
694+
}
689695
}
690696
}
691697
(TestKind::Range(range), &TestCase::Constant { value }) => {
692698
fully_matched = false;
693699
if !range.contains(value, self.tcx, self.param_env)? {
694700
// `value` is not contained in the testing range,
695701
// so `value` can be matched only if this test fails.
696-
Some(1)
702+
Some(TestBranch::Failure)
697703
} else {
698704
None
699705
}
@@ -704,7 +710,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
704710
if test_val == case_val =>
705711
{
706712
fully_matched = true;
707-
Some(0)
713+
Some(TestBranch::Success)
708714
}
709715

710716
(
@@ -747,18 +753,29 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
747753
}
748754
}
749755

750-
impl Test<'_> {
751-
pub(super) fn targets(&self) -> usize {
756+
impl<'tcx> Test<'tcx> {
757+
pub(super) fn targets(&self) -> Vec<TestBranch<'tcx>> {
752758
match self.kind {
753-
TestKind::Eq { .. } | TestKind::Range(_) | TestKind::Len { .. } | TestKind::If => 2,
759+
TestKind::Eq { .. } | TestKind::Range(_) | TestKind::Len { .. } | TestKind::If => {
760+
vec![TestBranch::Success, TestBranch::Failure]
761+
}
754762
TestKind::Switch { adt_def, .. } => {
755763
// While the switch that we generate doesn't test for all
756764
// variants, we have a target for each variant and the
757765
// otherwise case, and we make sure that all of the cases not
758766
// specified have the same block.
759-
adt_def.variants().len() + 1
767+
adt_def
768+
.variants()
769+
.indices()
770+
.map(|idx| TestBranch::Variant(idx))
771+
.chain([TestBranch::Failure])
772+
.collect()
760773
}
761-
TestKind::SwitchInt { ref options } => options.len() + 1,
774+
TestKind::SwitchInt { ref options } => options
775+
.iter()
776+
.map(|(val, bits)| TestBranch::Constant(*val, *bits))
777+
.chain([TestBranch::Failure])
778+
.collect(),
762779
}
763780
}
764781
}

0 commit comments

Comments
 (0)