Skip to content

Commit 5f75825

Browse files
committed
Use a ConstValue instead.
1 parent e49e52b commit 5f75825

20 files changed

+688
-271
lines changed

compiler/rustc_mir_transform/src/dataflow_const_prop.rs

+139-96
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,21 @@
33
//! Currently, this pass only propagates scalar values.
44
55
use rustc_const_eval::const_eval::CheckAlignment;
6-
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, Projectable};
6+
use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, PlaceTy, Projectable};
77
use rustc_data_structures::fx::FxHashMap;
88
use rustc_hir::def::DefKind;
99
use rustc_middle::mir::interpret::{AllocId, ConstAllocation, InterpResult, Scalar};
1010
use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor};
1111
use rustc_middle::mir::*;
12-
use rustc_middle::ty::layout::TyAndLayout;
12+
use rustc_middle::ty::layout::{LayoutOf, TyAndLayout};
1313
use rustc_middle::ty::{self, Ty, TyCtxt};
1414
use rustc_mir_dataflow::value_analysis::{
1515
Map, PlaceIndex, State, TrackElem, ValueAnalysis, ValueAnalysisWrapper, ValueOrPlace,
1616
};
1717
use rustc_mir_dataflow::{lattice::FlatSet, Analysis, Results, ResultsVisitor};
1818
use rustc_span::def_id::DefId;
1919
use rustc_span::DUMMY_SP;
20-
use rustc_target::abi::{Align, FieldIdx, VariantIdx};
20+
use rustc_target::abi::{Align, FieldIdx, Size, VariantIdx, FIRST_VARIANT};
2121

2222
use crate::MirPass;
2323

@@ -546,110 +546,134 @@ impl<'tcx, 'locals> Collector<'tcx, 'locals> {
546546

547547
fn try_make_constant(
548548
&self,
549+
ecx: &mut InterpCx<'tcx, 'tcx, DummyMachine>,
549550
place: Place<'tcx>,
550551
state: &State<FlatSet<Scalar>>,
551552
map: &Map,
552553
) -> Option<ConstantKind<'tcx>> {
553554
let ty = place.ty(self.local_decls, self.patch.tcx).ty;
554555
let place = map.find(place.as_ref())?;
555-
if let FlatSet::Elem(Scalar::Int(value)) = state.get_idx(place, map) {
556-
Some(ConstantKind::Val(ConstValue::Scalar(value.into()), ty))
556+
let layout = ecx.layout_of(ty).ok()?;
557+
if layout.is_zst() {
558+
Some(ConstantKind::Val(ConstValue::ZeroSized, ty))
559+
} else if layout.abi.is_scalar()
560+
&& let Some(value) = propagatable_scalar(place, state, map)
561+
{
562+
Some(ConstantKind::Val(ConstValue::Scalar(value), ty))
563+
} else if layout.is_sized() && layout.size <= 4 * ecx.tcx.data_layout.pointer_size {
564+
let alloc_id = ecx
565+
.intern_with_temp_alloc(layout, |ecx, dest| {
566+
try_write_constant(ecx, dest, place, ty, state, map)
567+
})
568+
.ok()?;
569+
Some(ConstantKind::Val(ConstValue::Indirect { alloc_id, offset: Size::ZERO }, ty))
557570
} else {
558-
let valtree = self.try_make_valtree(place, ty, state, map)?;
559-
let constant = ty::Const::new_value(self.patch.tcx, valtree, ty);
560-
Some(ConstantKind::Ty(constant))
571+
None
561572
}
562573
}
574+
}
563575

564-
fn try_make_valtree(
565-
&self,
566-
place: PlaceIndex,
567-
ty: Ty<'tcx>,
568-
state: &State<FlatSet<Scalar>>,
569-
map: &Map,
570-
) -> Option<ty::ValTree<'tcx>> {
571-
let tcx = self.patch.tcx;
572-
match ty.kind() {
573-
// ZSTs.
574-
ty::FnDef(..) => Some(ty::ValTree::zst()),
575-
576-
// Scalars.
577-
ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => {
578-
if let FlatSet::Elem(Scalar::Int(value)) = state.get_idx(place, map) {
579-
Some(ty::ValTree::Leaf(value))
580-
} else {
581-
None
582-
}
583-
}
576+
fn propagatable_scalar(
577+
place: PlaceIndex,
578+
state: &State<FlatSet<Scalar>>,
579+
map: &Map,
580+
) -> Option<Scalar> {
581+
if let FlatSet::Elem(value) = state.get_idx(place, map) && value.try_to_int().is_ok() {
582+
// Do not attempt to propagate pointers, as we may fail to preserve their identity.
583+
Some(value)
584+
} else {
585+
None
586+
}
587+
}
588+
589+
#[instrument(level = "trace", skip(ecx, state, map))]
590+
fn try_write_constant<'tcx>(
591+
ecx: &mut InterpCx<'_, 'tcx, DummyMachine>,
592+
dest: &PlaceTy<'tcx>,
593+
place: PlaceIndex,
594+
ty: Ty<'tcx>,
595+
state: &State<FlatSet<Scalar>>,
596+
map: &Map,
597+
) -> InterpResult<'tcx> {
598+
let layout = ecx.layout_of(ty)?;
599+
600+
// Fast path for ZSTs.
601+
if layout.is_zst() {
602+
return Ok(());
603+
}
604+
605+
// Fast path for scalars.
606+
if layout.abi.is_scalar()
607+
&& let Some(value) = propagatable_scalar(place, state, map)
608+
{
609+
return ecx.write_immediate(Immediate::Scalar(value), dest);
610+
}
584611

585-
// Unsupported for now.
586-
ty::Array(_, _) => None,
587-
588-
ty::Tuple(elem_tys) => {
589-
let branches = elem_tys
590-
.iter()
591-
.enumerate()
592-
.map(|(i, ty)| {
593-
let field = map.apply(place, TrackElem::Field(FieldIdx::from_usize(i)))?;
594-
self.try_make_valtree(field, ty, state, map)
595-
})
596-
.collect::<Option<Vec<_>>>()?;
597-
Some(ty::ValTree::Branch(tcx.arena.alloc_from_iter(branches.into_iter())))
612+
match ty.kind() {
613+
// ZSTs. Nothing to do.
614+
ty::FnDef(..) => {}
615+
616+
// Those are scalars, must be handled above.
617+
ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => throw_inval!(ConstPropNonsense),
618+
619+
ty::Tuple(elem_tys) => {
620+
for (i, elem) in elem_tys.iter().enumerate() {
621+
let field = map.apply(place, TrackElem::Field(FieldIdx::from_usize(i))).ok_or(err_inval!(ConstPropNonsense))?;
622+
let field_dest = ecx.project_field(dest, i)?;
623+
try_write_constant(ecx, &field_dest, field, elem, state, map)?;
598624
}
625+
}
599626

600-
ty::Adt(def, args) => {
601-
if def.is_union() {
602-
return None;
603-
}
627+
ty::Adt(def, args) => {
628+
if def.is_union() {
629+
throw_inval!(ConstPropNonsense)
630+
}
604631

605-
let (variant_idx, variant_def, variant_place) = if def.is_enum() {
606-
let discr = map.apply(place, TrackElem::Discriminant)?;
607-
let FlatSet::Elem(Scalar::Int(discr)) = state.get_idx(discr, map) else {
608-
return None;
609-
};
610-
let discr_bits = discr.assert_bits(discr.size());
611-
let (variant, _) =
612-
def.discriminants(tcx).find(|(_, var)| discr_bits == var.val)?;
613-
let variant_place = map.apply(place, TrackElem::Variant(variant))?;
614-
let variant_int = ty::ValTree::Leaf(variant.as_u32().into());
615-
(Some(variant_int), def.variant(variant), variant_place)
616-
} else {
617-
(None, def.non_enum_variant(), place)
632+
let (variant_idx, variant_def, variant_place, variant_dest) = if def.is_enum() {
633+
let discr = map.apply(place, TrackElem::Discriminant).ok_or(err_inval!(ConstPropNonsense))?;
634+
let FlatSet::Elem(Scalar::Int(discr)) = state.get_idx(discr, map) else {
635+
throw_inval!(ConstPropNonsense)
618636
};
619-
620-
let branches = variant_def
621-
.fields
622-
.iter_enumerated()
623-
.map(|(i, field)| {
624-
let ty = field.ty(tcx, args);
625-
let field = map.apply(variant_place, TrackElem::Field(i))?;
626-
self.try_make_valtree(field, ty, state, map)
627-
})
628-
.collect::<Option<Vec<_>>>()?;
629-
Some(ty::ValTree::Branch(
630-
tcx.arena.alloc_from_iter(variant_idx.into_iter().chain(branches)),
631-
))
637+
let discr_bits = discr.assert_bits(discr.size());
638+
let (variant, _) = def.discriminants(*ecx.tcx).find(|(_, var)| discr_bits == var.val).ok_or(err_inval!(ConstPropNonsense))?;
639+
let variant_place = map.apply(place, TrackElem::Variant(variant)).ok_or(err_inval!(ConstPropNonsense))?;
640+
let variant_dest = ecx.project_downcast(dest, variant)?;
641+
(variant, def.variant(variant), variant_place, variant_dest)
642+
} else {
643+
(FIRST_VARIANT, def.non_enum_variant(), place, dest.clone())
644+
};
645+
646+
for (i, field) in variant_def.fields.iter_enumerated() {
647+
let ty = field.ty(*ecx.tcx, args);
648+
let field = map.apply(variant_place, TrackElem::Field(i)).ok_or(err_inval!(ConstPropNonsense))?;
649+
let field_dest = ecx.project_field(&variant_dest, i.as_usize())?;
650+
try_write_constant(ecx, &field_dest, field, ty, state, map)?;
632651
}
652+
ecx.write_discriminant(variant_idx, dest)?;
653+
}
633654

634-
// Do not attempt to support indirection in constants.
635-
ty::Ref(..) | ty::RawPtr(..) | ty::FnPtr(..) | ty::Str | ty::Slice(_) => None,
636-
637-
ty::Never
638-
| ty::Foreign(..)
639-
| ty::Alias(..)
640-
| ty::Param(_)
641-
| ty::Bound(..)
642-
| ty::Placeholder(..)
643-
| ty::Closure(..)
644-
| ty::Generator(..)
645-
| ty::Dynamic(..) => None,
646-
647-
ty::Error(_)
648-
| ty::Infer(..)
649-
| ty::GeneratorWitness(..)
650-
| ty::GeneratorWitnessMIR(..) => bug!(),
655+
// Unsupported for now.
656+
ty::Array(_, _)
657+
658+
// Do not attempt to support indirection in constants.
659+
| ty::Ref(..) | ty::RawPtr(..) | ty::FnPtr(..) | ty::Str | ty::Slice(_)
660+
661+
| ty::Never
662+
| ty::Foreign(..)
663+
| ty::Alias(..)
664+
| ty::Param(_)
665+
| ty::Bound(..)
666+
| ty::Placeholder(..)
667+
| ty::Closure(..)
668+
| ty::Generator(..)
669+
| ty::Dynamic(..) => throw_inval!(ConstPropNonsense),
670+
671+
ty::Error(_) | ty::Infer(..) | ty::GeneratorWitness(..) | ty::GeneratorWitnessMIR(..) => {
672+
bug!()
651673
}
652674
}
675+
676+
Ok(())
653677
}
654678

655679
impl<'mir, 'tcx>
@@ -667,8 +691,13 @@ impl<'mir, 'tcx>
667691
) {
668692
match &statement.kind {
669693
StatementKind::Assign(box (_, rvalue)) => {
670-
OperandCollector { state, visitor: self, map: &results.analysis.0.map }
671-
.visit_rvalue(rvalue, location);
694+
OperandCollector {
695+
state,
696+
visitor: self,
697+
ecx: &mut results.analysis.0.ecx,
698+
map: &results.analysis.0.map,
699+
}
700+
.visit_rvalue(rvalue, location);
672701
}
673702
_ => (),
674703
}
@@ -686,7 +715,12 @@ impl<'mir, 'tcx>
686715
// Don't overwrite the assignment if it already uses a constant (to keep the span).
687716
}
688717
StatementKind::Assign(box (place, _)) => {
689-
if let Some(value) = self.try_make_constant(place, state, &results.analysis.0.map) {
718+
if let Some(value) = self.try_make_constant(
719+
&mut results.analysis.0.ecx,
720+
place,
721+
state,
722+
&results.analysis.0.map,
723+
) {
690724
self.patch.assignments.insert(location, value);
691725
}
692726
}
@@ -701,8 +735,13 @@ impl<'mir, 'tcx>
701735
terminator: &'mir Terminator<'tcx>,
702736
location: Location,
703737
) {
704-
OperandCollector { state, visitor: self, map: &results.analysis.0.map }
705-
.visit_terminator(terminator, location);
738+
OperandCollector {
739+
state,
740+
visitor: self,
741+
ecx: &mut results.analysis.0.ecx,
742+
map: &results.analysis.0.map,
743+
}
744+
.visit_terminator(terminator, location);
706745
}
707746
}
708747

@@ -757,6 +796,7 @@ impl<'tcx> MutVisitor<'tcx> for Patch<'tcx> {
757796
struct OperandCollector<'tcx, 'map, 'locals, 'a> {
758797
state: &'a State<FlatSet<Scalar>>,
759798
visitor: &'a mut Collector<'tcx, 'locals>,
799+
ecx: &'map mut InterpCx<'tcx, 'tcx, DummyMachine>,
760800
map: &'map Map,
761801
}
762802

@@ -769,15 +809,17 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> {
769809
location: Location,
770810
) {
771811
if let PlaceElem::Index(local) = elem
772-
&& let Some(value) = self.visitor.try_make_constant(local.into(), self.state, self.map)
812+
&& let Some(value) = self.visitor.try_make_constant(self.ecx, local.into(), self.state, self.map)
773813
{
774814
self.visitor.patch.before_effect.insert((location, local.into()), value);
775815
}
776816
}
777817

778818
fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) {
779819
if let Some(place) = operand.place() {
780-
if let Some(value) = self.visitor.try_make_constant(place, self.state, self.map) {
820+
if let Some(value) =
821+
self.visitor.try_make_constant(self.ecx, place, self.state, self.map)
822+
{
781823
self.visitor.patch.before_effect.insert((location, place), value);
782824
} else if !place.projection.is_empty() {
783825
// Try to propagate into `Index` projections.
@@ -802,8 +844,9 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
802844
}
803845

804846
fn enforce_validity(_ecx: &InterpCx<'mir, 'tcx, Self>, _layout: TyAndLayout<'tcx>) -> bool {
805-
unimplemented!()
847+
false
806848
}
849+
807850
fn alignment_check_failed(
808851
_ecx: &InterpCx<'mir, 'tcx, Self>,
809852
_has: Align,

tests/mir-opt/const_debuginfo.main.ConstDebugInfo.diff

+6-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
let _10: std::option::Option<u16>;
4343
scope 7 {
4444
- debug o => _10;
45-
+ debug o => const Option::<u16>::Some(99);
45+
+ debug o => const Option::<u16>::Some(99_u16);
4646
let _17: u32;
4747
let _18: u32;
4848
scope 8 {
@@ -82,7 +82,7 @@
8282
_15 = const false;
8383
_16 = const 123_u32;
8484
StorageLive(_10);
85-
_10 = const Option::<u16>::Some(99);
85+
_10 = const Option::<u16>::Some(99_u16);
8686
_17 = const 32_u32;
8787
_18 = const 32_u32;
8888
StorageLive(_11);
@@ -98,3 +98,7 @@
9898
}
9999
}
100100

101+
alloc10 (size: 4, align: 2) {
102+
01 00 63 00 │ ..c.
103+
}
104+

tests/mir-opt/dataflow-const-prop/checked.main.DataflowConstProp.panic-abort.diff

+9-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
- _6 = CheckedAdd(_4, _5);
4444
- assert(!move (_6.1: bool), "attempt to compute `{} + {}`, which would overflow", move _4, move _5) -> [success: bb1, unwind unreachable];
4545
+ _5 = const 2_i32;
46-
+ _6 = const (3, false);
46+
+ _6 = const (3_i32, false);
4747
+ assert(!const false, "attempt to compute `{} + {}`, which would overflow", const 1_i32, const 2_i32) -> [success: bb1, unwind unreachable];
4848
}
4949

@@ -76,5 +76,13 @@
7676
StorageDead(_1);
7777
return;
7878
}
79+
+ }
80+
+
81+
+ alloc5 (size: 8, align: 4) {
82+
+ 00 00 00 80 01 __ __ __ │ .....░░░
83+
+ }
84+
+
85+
+ alloc4 (size: 8, align: 4) {
86+
+ 03 00 00 00 00 __ __ __ │ .....░░░
7987
}
8088

0 commit comments

Comments
 (0)