Skip to content

Commit 3d3fbc3

Browse files
committed
Auto merge of rust-lang#123259 - scottmcm:tweak-if-const, r=<try>
Fixup `if T::CONST` in MIR r? ghost
2 parents 5f358a8 + 2fb9d27 commit 3d3fbc3

13 files changed

+467
-350
lines changed

compiler/rustc_codegen_ssa/src/mir/block.rs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -361,15 +361,24 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
361361
discr: &mir::Operand<'tcx>,
362362
targets: &SwitchTargets,
363363
) {
364-
let discr = self.codegen_operand(bx, discr);
365-
let discr_value = discr.immediate();
366-
let switch_ty = discr.layout.ty;
367364
// If our discriminant is a constant we can branch directly
368-
if let Some(const_discr) = bx.const_to_opt_u128(discr_value, false) {
365+
if let Some(const_op) = discr.constant() {
366+
let const_value = self.eval_mir_constant(const_op);
367+
let Some(const_discr) = const_value.try_to_bits_for_ty(
368+
self.cx.tcx(),
369+
ty::ParamEnv::reveal_all(),
370+
const_op.ty(),
371+
) else {
372+
bug!("Failed to evaluate constant {discr:?} for SwitchInt terminator")
373+
};
369374
let target = targets.target_for_value(const_discr);
370375
bx.br(helper.llbb_with_cleanup(self, target));
371376
return;
372-
};
377+
}
378+
379+
let discr = self.codegen_operand(bx, discr);
380+
let discr_value = discr.immediate();
381+
let switch_ty = discr.layout.ty;
373382

374383
let mut target_iter = targets.iter();
375384
if target_iter.len() == 1 {

compiler/rustc_middle/src/mir/mod.rs

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,7 @@ impl<'tcx> Body<'tcx> {
765765
};
766766

767767
// If this is a SwitchInt(const _), then we can just evaluate the constant and return.
768+
// (The `SwitchConst` transform pass tries to ensure this.)
768769
let discr = match discr {
769770
Operand::Constant(constant) => {
770771
let bits = eval_mono_const(constant);
@@ -773,24 +774,18 @@ impl<'tcx> Body<'tcx> {
773774
Operand::Move(place) | Operand::Copy(place) => place,
774775
};
775776

776-
// MIR for `if false` actually looks like this:
777-
// _1 = const _
778-
// SwitchInt(_1)
779-
//
780777
// And MIR for if intrinsics::debug_assertions() looks like this:
781778
// _1 = cfg!(debug_assertions)
782779
// SwitchInt(_1)
783780
//
784781
// So we're going to try to recognize this pattern.
785782
//
786-
// If we have a SwitchInt on a non-const place, we find the most recent statement that
787-
// isn't a storage marker. If that statement is an assignment of a const to our
788-
// discriminant place, we evaluate and return the const, as if we've const-propagated it
789-
// into the SwitchInt.
783+
// If we have a SwitchInt on a non-const place, we look at the last statement
784+
// in the block. If that statement is an assignment of UbChecks to our
785+
// discriminant place, we evaluate its value, as if we've
786+
// const-propagated it into the SwitchInt.
790787

791-
let last_stmt = block.statements.iter().rev().find(|stmt| {
792-
!matches!(stmt.kind, StatementKind::StorageDead(_) | StatementKind::StorageLive(_))
793-
})?;
788+
let last_stmt = block.statements.last()?;
794789

795790
let (place, rvalue) = last_stmt.kind.as_assign()?;
796791

@@ -802,10 +797,6 @@ impl<'tcx> Body<'tcx> {
802797
Rvalue::NullaryOp(NullOp::UbChecks, _) => {
803798
Some((tcx.sess.opts.debug_assertions as u128, targets))
804799
}
805-
Rvalue::Use(Operand::Constant(constant)) => {
806-
let bits = eval_mono_const(constant);
807-
Some((bits, targets))
808-
}
809800
_ => None,
810801
}
811802
}

compiler/rustc_mir_transform/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ pub mod simplify;
110110
mod simplify_branches;
111111
mod simplify_comparison_integral;
112112
mod sroa;
113+
mod switch_const;
113114
mod uninhabited_enum_branching;
114115
mod unreachable_prop;
115116

@@ -600,6 +601,8 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
600601
&simplify::SimplifyLocals::AfterGVN,
601602
&dataflow_const_prop::DataflowConstProp,
602603
&const_debuginfo::ConstDebugInfo,
604+
// GVN & ConstProp often don't fixup unevaluatable constants
605+
&switch_const::SwitchConst,
603606
&o1(simplify_branches::SimplifyConstCondition::AfterConstProp),
604607
&jump_threading::JumpThreading,
605608
&early_otherwise_branch::EarlyOtherwiseBranch,
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
//! A pass that makes `SwitchInt`-on-`const` more obvious to later code.
2+
3+
use rustc_middle::mir::*;
4+
use rustc_middle::ty::TyCtxt;
5+
6+
/// A `MirPass` for simplifying `if T::CONST`.
7+
///
8+
/// Today, MIR building for things like `if T::IS_ZST` introduce a constant
9+
/// for the copy of the bool, so it ends up in MIR as
10+
/// `_1 = CONST; switchInt (move _1)` or `_2 = CONST; switchInt (_2)`.
11+
///
12+
/// This pass is very specifically targeted at *exactly* those patterns.
13+
/// It can absolutely be replaced with a more general pass should we get one that
14+
/// we can run in low optimization levels, but at the time of writing even in
15+
/// optimized builds this wasn't simplified.
16+
#[derive(Default)]
17+
pub struct SwitchConst;
18+
19+
impl<'tcx> MirPass<'tcx> for SwitchConst {
20+
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
21+
for block in body.basic_blocks.as_mut_preserves_cfg() {
22+
let switch_local = if let TerminatorKind::SwitchInt { discr, .. } =
23+
&block.terminator().kind
24+
&& let Some(place) = discr.place()
25+
&& let Some(local) = place.as_local()
26+
{
27+
local
28+
} else {
29+
continue;
30+
};
31+
32+
let new_operand = if let Some(statement) = block.statements.last()
33+
&& let StatementKind::Assign(place_and_rvalue) = &statement.kind
34+
&& let Some(local) = place_and_rvalue.0.as_local()
35+
&& local == switch_local
36+
&& let Rvalue::Use(operand) = &place_and_rvalue.1
37+
&& let Operand::Constant(_) = operand
38+
{
39+
operand.clone()
40+
} else {
41+
continue;
42+
};
43+
44+
if !tcx.consider_optimizing(|| format!("SwitchConst: switchInt(move {switch_local:?}"))
45+
{
46+
break;
47+
}
48+
49+
let TerminatorKind::SwitchInt { discr, .. } = &mut block.terminator_mut().kind else {
50+
bug!("Somehow wasn't a switchInt any more?")
51+
};
52+
*discr = new_operand;
53+
}
54+
}
55+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// MIR for `check_bool` after PreCodegen
2+
3+
fn check_bool() -> u32 {
4+
let mut _0: u32;
5+
6+
bb0: {
7+
switchInt(const <T as TraitWithBool>::FLAG) -> [0: bb1, otherwise: bb2];
8+
}
9+
10+
bb1: {
11+
_0 = const 456_u32;
12+
goto -> bb3;
13+
}
14+
15+
bb2: {
16+
_0 = const 123_u32;
17+
goto -> bb3;
18+
}
19+
20+
bb3: {
21+
return;
22+
}
23+
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// MIR for `check_int` after PreCodegen
2+
3+
fn check_int() -> u32 {
4+
let mut _0: u32;
5+
6+
bb0: {
7+
switchInt(const <T as TraitWithInt>::VALUE) -> [1: bb1, 2: bb2, 3: bb3, otherwise: bb4];
8+
}
9+
10+
bb1: {
11+
_0 = const 123_u32;
12+
goto -> bb5;
13+
}
14+
15+
bb2: {
16+
_0 = const 456_u32;
17+
goto -> bb5;
18+
}
19+
20+
bb3: {
21+
_0 = const 789_u32;
22+
goto -> bb5;
23+
}
24+
25+
bb4: {
26+
_0 = const 0_u32;
27+
goto -> bb5;
28+
}
29+
30+
bb5: {
31+
return;
32+
}
33+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// skip-filecheck
2+
//@ compile-flags: -O -Zmir-opt-level=2 -Cdebuginfo=2
3+
4+
#![crate_type = "lib"]
5+
6+
pub trait TraitWithBool {
7+
const FLAG: bool;
8+
}
9+
10+
// EMIT_MIR if_associated_const.check_bool.PreCodegen.after.mir
11+
pub fn check_bool<T: TraitWithBool>() -> u32 {
12+
if T::FLAG { 123 } else { 456 }
13+
}
14+
15+
pub trait TraitWithInt {
16+
const VALUE: i32;
17+
}
18+
19+
// EMIT_MIR if_associated_const.check_int.PreCodegen.after.mir
20+
pub fn check_int<T: TraitWithInt>() -> u32 {
21+
match T::VALUE {
22+
1 => 123,
23+
2 => 456,
24+
3 => 789,
25+
_ => 0,
26+
}
27+
}

0 commit comments

Comments
 (0)