|
| 1 | +use crate::transform::MirPass; |
| 2 | +use rustc_middle::mir::*; |
| 3 | +use rustc_middle::ty::TyCtxt; |
| 4 | +use rustc_target::abi::VariantIdx; |
| 5 | + |
| 6 | +pub struct MatchIdentitySimplification; |
| 7 | + |
| 8 | +impl<'tcx> MirPass<'tcx> for MatchIdentitySimplification { |
| 9 | + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { |
| 10 | + //let param_env = tcx.param_env(body.source.def_id()); |
| 11 | + let (bbs, local_decls) = body.basic_blocks_and_local_decls_mut(); |
| 12 | + for bb_idx in bbs.indices() { |
| 13 | + let (read_discr, og_match) = match &bbs[bb_idx].statements[..] { |
| 14 | + &[Statement { |
| 15 | + kind: StatementKind::Assign(box (dst, Rvalue::Discriminant(src))), |
| 16 | + .. |
| 17 | + }] => (dst, src), |
| 18 | + _ => continue, |
| 19 | + }; |
| 20 | + let (var_idx, fst, snd) = match bbs[bb_idx].terminator().kind { |
| 21 | + TerminatorKind::SwitchInt { |
| 22 | + discr: Operand::Copy(ref place) | Operand::Move(ref place), |
| 23 | + ref targets, |
| 24 | + ref values, |
| 25 | + .. |
| 26 | + } if targets.len() == 2 |
| 27 | + && values.len() == 1 |
| 28 | + && targets[0] != targets[1] |
| 29 | + // check that we're switching on the read discr |
| 30 | + && place == &read_discr |
| 31 | + // check that this is actually |
| 32 | + && place.ty(local_decls, tcx).ty.is_enum() => |
| 33 | + { |
| 34 | + (VariantIdx::from(values[0] as usize), targets[0], targets[1]) |
| 35 | + } |
| 36 | + // Only optimize switch int statements |
| 37 | + _ => continue, |
| 38 | + }; |
| 39 | + let stmts_ok = |stmts: &[Statement<'_>], expected_variant| match stmts { |
| 40 | + [Statement { |
| 41 | + kind: |
| 42 | + StatementKind::Assign(box ( |
| 43 | + dst0, |
| 44 | + Rvalue::Use(Operand::Copy(from) | Operand::Move(from)), |
| 45 | + )), |
| 46 | + .. |
| 47 | + }, Statement { |
| 48 | + kind: StatementKind::SetDiscriminant { place: box dst1, variant_index }, |
| 49 | + .. |
| 50 | + }] => *variant_index == expected_variant && dst0 == dst1 && og_match == *from, |
| 51 | + _ => false, |
| 52 | + }; |
| 53 | + let bb1 = &bbs[fst]; |
| 54 | + let bb2 = &bbs[snd]; |
| 55 | + if bb1.terminator().kind != bb2.terminator().kind |
| 56 | + || stmts_ok(&bb1.statements[..], var_idx) |
| 57 | + || stmts_ok(&bb2.statements[..], var_idx + 1) |
| 58 | + { |
| 59 | + continue; |
| 60 | + } |
| 61 | + let dst = match (&bb1.statements[0], &bb2.statements[0]) { |
| 62 | + ( |
| 63 | + Statement { kind: StatementKind::Assign(box (dst0, _)), .. }, |
| 64 | + Statement { kind: StatementKind::Assign(box (dst1, _)), .. }, |
| 65 | + ) if dst0 == dst1 => dst0.clone(), |
| 66 | + _ => continue, |
| 67 | + }; |
| 68 | + let term_kind = bb1.terminator().kind.clone(); |
| 69 | + // Reassign the output to just be the original |
| 70 | + // Replace the terminator with the terminator of the output |
| 71 | + bbs[bb_idx].statements[0].kind = |
| 72 | + StatementKind::Assign(box (dst, Rvalue::Use(Operand::Copy(og_match)))); |
| 73 | + bbs[bb_idx].terminator_mut().kind = term_kind; |
| 74 | + } |
| 75 | + } |
| 76 | +} |
0 commit comments