Skip to content

Commit 18a5514

Browse files
committed
Add identity match branch mir-optimization
1 parent 2359ecc commit 18a5514

File tree

3 files changed

+92
-0
lines changed

3 files changed

+92
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
use crate::transform::MirPass;
2+
use rustc_middle::mir::*;
3+
use rustc_middle::ty::TyCtxt;
4+
use rustc_target::abi::VariantIdx;
5+
6+
pub struct MatchIdentitySimplification;
7+
8+
impl<'tcx> MirPass<'tcx> for MatchIdentitySimplification {
9+
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
10+
//let param_env = tcx.param_env(body.source.def_id());
11+
let (bbs, local_decls) = body.basic_blocks_and_local_decls_mut();
12+
for bb_idx in bbs.indices() {
13+
let (read_discr, og_match) = match &bbs[bb_idx].statements[..] {
14+
&[Statement {
15+
kind: StatementKind::Assign(box (dst, Rvalue::Discriminant(src))),
16+
..
17+
}] => (dst, src),
18+
_ => continue,
19+
};
20+
let (var_idx, fst, snd) = match bbs[bb_idx].terminator().kind {
21+
TerminatorKind::SwitchInt {
22+
discr: Operand::Copy(ref place) | Operand::Move(ref place),
23+
ref targets,
24+
ref values,
25+
..
26+
} if targets.len() == 2
27+
&& values.len() == 1
28+
&& targets[0] != targets[1]
29+
// check that we're switching on the read discr
30+
&& place == &read_discr
31+
// check that this is actually
32+
&& place.ty(local_decls, tcx).ty.is_enum() =>
33+
{
34+
(VariantIdx::from(values[0] as usize), targets[0], targets[1])
35+
}
36+
// Only optimize switch int statements
37+
_ => continue,
38+
};
39+
let stmts_ok = |stmts: &[Statement<'_>], expected_variant| match stmts {
40+
[Statement {
41+
kind:
42+
StatementKind::Assign(box (
43+
dst0,
44+
Rvalue::Use(Operand::Copy(from) | Operand::Move(from)),
45+
)),
46+
..
47+
}, Statement {
48+
kind: StatementKind::SetDiscriminant { place: box dst1, variant_index },
49+
..
50+
}] => *variant_index == expected_variant && dst0 == dst1 && og_match == *from,
51+
_ => false,
52+
};
53+
let bb1 = &bbs[fst];
54+
let bb2 = &bbs[snd];
55+
if bb1.terminator().kind != bb2.terminator().kind
56+
|| stmts_ok(&bb1.statements[..], var_idx)
57+
|| stmts_ok(&bb2.statements[..], var_idx + 1)
58+
{
59+
continue;
60+
}
61+
let dst = match (&bb1.statements[0], &bb2.statements[0]) {
62+
(
63+
Statement { kind: StatementKind::Assign(box (dst0, _)), .. },
64+
Statement { kind: StatementKind::Assign(box (dst1, _)), .. },
65+
) if dst0 == dst1 => dst0.clone(),
66+
_ => continue,
67+
};
68+
let term_kind = bb1.terminator().kind.clone();
69+
// Reassign the output to just be the original
70+
// Replace the terminator with the terminator of the output
71+
bbs[bb_idx].statements[0].kind =
72+
StatementKind::Assign(box (dst, Rvalue::Use(Operand::Copy(og_match))));
73+
bbs[bb_idx].terminator_mut().kind = term_kind;
74+
}
75+
}
76+
}

compiler/rustc_mir/src/transform/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ pub mod inline;
3333
pub mod instcombine;
3434
pub mod instrument_coverage;
3535
pub mod match_branches;
36+
pub mod match_identity;
3637
pub mod multiple_return_terminators;
3738
pub mod no_landing_pads;
3839
pub mod nrvo;

src/test/mir-opt/match_identity.rs

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// EMIT_MIR match_identity.id_result.match_identity.diff
2+
pub fn id_result(a: Result<u64, i64>) -> Result<u64, i64> {
3+
match a {
4+
Ok(x) => Ok(x),
5+
Err(y) => Err(y),
6+
}
7+
}
8+
9+
// EMIT_MIR match_identity.id_result.match_identity.diff
10+
pub fn flip_flop(a: Result<u64, i64>) -> Result<i64, u64> {
11+
match a {
12+
Ok(x) => Err(x),
13+
Err(y) => Ok(y),
14+
}
15+
}

0 commit comments

Comments
 (0)