3
3
//! Currently, this pass only propagates scalar values.
4
4
5
5
use rustc_const_eval:: const_eval:: CheckAlignment ;
6
- use rustc_const_eval:: interpret:: { ImmTy , Immediate , InterpCx , OpTy , Projectable } ;
6
+ use rustc_const_eval:: interpret:: { ImmTy , Immediate , InterpCx , OpTy , PlaceTy , Projectable } ;
7
7
use rustc_data_structures:: fx:: FxHashMap ;
8
8
use rustc_hir:: def:: DefKind ;
9
9
use rustc_middle:: mir:: interpret:: { AllocId , ConstAllocation , InterpResult , Scalar } ;
10
10
use rustc_middle:: mir:: visit:: { MutVisitor , PlaceContext , Visitor } ;
11
11
use rustc_middle:: mir:: * ;
12
- use rustc_middle:: ty:: layout:: TyAndLayout ;
12
+ use rustc_middle:: ty:: layout:: { LayoutOf , TyAndLayout } ;
13
13
use rustc_middle:: ty:: { self , Ty , TyCtxt } ;
14
14
use rustc_mir_dataflow:: value_analysis:: {
15
15
Map , PlaceIndex , State , TrackElem , ValueAnalysis , ValueAnalysisWrapper , ValueOrPlace ,
16
16
} ;
17
17
use rustc_mir_dataflow:: { lattice:: FlatSet , Analysis , Results , ResultsVisitor } ;
18
18
use rustc_span:: def_id:: DefId ;
19
19
use rustc_span:: DUMMY_SP ;
20
- use rustc_target:: abi:: { Align , FieldIdx , VariantIdx } ;
20
+ use rustc_target:: abi:: { Align , FieldIdx , Size , VariantIdx , FIRST_VARIANT } ;
21
21
22
22
use crate :: MirPass ;
23
23
@@ -546,110 +546,134 @@ impl<'tcx, 'locals> Collector<'tcx, 'locals> {
546
546
547
547
fn try_make_constant (
548
548
& self ,
549
+ ecx : & mut InterpCx < ' tcx , ' tcx , DummyMachine > ,
549
550
place : Place < ' tcx > ,
550
551
state : & State < FlatSet < Scalar > > ,
551
552
map : & Map ,
552
553
) -> Option < ConstantKind < ' tcx > > {
553
554
let ty = place. ty ( self . local_decls , self . patch . tcx ) . ty ;
554
555
let place = map. find ( place. as_ref ( ) ) ?;
555
- if let FlatSet :: Elem ( Scalar :: Int ( value) ) = state. get_idx ( place, map) {
556
- Some ( ConstantKind :: Val ( ConstValue :: Scalar ( value. into ( ) ) , ty) )
556
+ let layout = ecx. layout_of ( ty) . ok ( ) ?;
557
+ if layout. is_zst ( ) {
558
+ Some ( ConstantKind :: Val ( ConstValue :: ZeroSized , ty) )
559
+ } else if layout. abi . is_scalar ( )
560
+ && let Some ( value) = propagatable_scalar ( place, state, map)
561
+ {
562
+ Some ( ConstantKind :: Val ( ConstValue :: Scalar ( value) , ty) )
563
+ } else if layout. is_sized ( ) && layout. size <= 4 * ecx. tcx . data_layout . pointer_size {
564
+ let alloc_id = ecx
565
+ . intern_with_temp_alloc ( layout, |ecx, dest| {
566
+ try_write_constant ( ecx, dest, place, ty, state, map)
567
+ } )
568
+ . ok ( ) ?;
569
+ Some ( ConstantKind :: Val ( ConstValue :: Indirect { alloc_id, offset : Size :: ZERO } , ty) )
557
570
} else {
558
- let valtree = self . try_make_valtree ( place, ty, state, map) ?;
559
- let constant = ty:: Const :: new_value ( self . patch . tcx , valtree, ty) ;
560
- Some ( ConstantKind :: Ty ( constant) )
571
+ None
561
572
}
562
573
}
574
+ }
563
575
564
- fn try_make_valtree (
565
- & self ,
566
- place : PlaceIndex ,
567
- ty : Ty < ' tcx > ,
568
- state : & State < FlatSet < Scalar > > ,
569
- map : & Map ,
570
- ) -> Option < ty:: ValTree < ' tcx > > {
571
- let tcx = self . patch . tcx ;
572
- match ty. kind ( ) {
573
- // ZSTs.
574
- ty:: FnDef ( ..) => Some ( ty:: ValTree :: zst ( ) ) ,
575
-
576
- // Scalars.
577
- ty:: Bool | ty:: Int ( _) | ty:: Uint ( _) | ty:: Float ( _) | ty:: Char => {
578
- if let FlatSet :: Elem ( Scalar :: Int ( value) ) = state. get_idx ( place, map) {
579
- Some ( ty:: ValTree :: Leaf ( value) )
580
- } else {
581
- None
582
- }
583
- }
576
+ fn propagatable_scalar (
577
+ place : PlaceIndex ,
578
+ state : & State < FlatSet < Scalar > > ,
579
+ map : & Map ,
580
+ ) -> Option < Scalar > {
581
+ if let FlatSet :: Elem ( value) = state. get_idx ( place, map) && value. try_to_int ( ) . is_ok ( ) {
582
+ // Do not attempt to propagate pointers, as we may fail to preserve their identity.
583
+ Some ( value)
584
+ } else {
585
+ None
586
+ }
587
+ }
588
+
589
+ #[ instrument( level = "trace" , skip( ecx, state, map) ) ]
590
+ fn try_write_constant < ' tcx > (
591
+ ecx : & mut InterpCx < ' _ , ' tcx , DummyMachine > ,
592
+ dest : & PlaceTy < ' tcx > ,
593
+ place : PlaceIndex ,
594
+ ty : Ty < ' tcx > ,
595
+ state : & State < FlatSet < Scalar > > ,
596
+ map : & Map ,
597
+ ) -> InterpResult < ' tcx > {
598
+ let layout = ecx. layout_of ( ty) ?;
599
+
600
+ // Fast path for ZSTs.
601
+ if layout. is_zst ( ) {
602
+ return Ok ( ( ) ) ;
603
+ }
604
+
605
+ // Fast path for scalars.
606
+ if layout. abi . is_scalar ( )
607
+ && let Some ( value) = propagatable_scalar ( place, state, map)
608
+ {
609
+ return ecx. write_immediate ( Immediate :: Scalar ( value) , dest) ;
610
+ }
584
611
585
- // Unsupported for now.
586
- ty:: Array ( _, _) => None ,
587
-
588
- ty:: Tuple ( elem_tys) => {
589
- let branches = elem_tys
590
- . iter ( )
591
- . enumerate ( )
592
- . map ( |( i, ty) | {
593
- let field = map. apply ( place, TrackElem :: Field ( FieldIdx :: from_usize ( i) ) ) ?;
594
- self . try_make_valtree ( field, ty, state, map)
595
- } )
596
- . collect :: < Option < Vec < _ > > > ( ) ?;
597
- Some ( ty:: ValTree :: Branch ( tcx. arena . alloc_from_iter ( branches. into_iter ( ) ) ) )
612
+ match ty. kind ( ) {
613
+ // ZSTs. Nothing to do.
614
+ ty:: FnDef ( ..) => { }
615
+
616
+ // Those are scalars, must be handled above.
617
+ ty:: Bool | ty:: Int ( _) | ty:: Uint ( _) | ty:: Float ( _) | ty:: Char => throw_inval ! ( ConstPropNonsense ) ,
618
+
619
+ ty:: Tuple ( elem_tys) => {
620
+ for ( i, elem) in elem_tys. iter ( ) . enumerate ( ) {
621
+ let field = map. apply ( place, TrackElem :: Field ( FieldIdx :: from_usize ( i) ) ) . ok_or ( err_inval ! ( ConstPropNonsense ) ) ?;
622
+ let field_dest = ecx. project_field ( dest, i) ?;
623
+ try_write_constant ( ecx, & field_dest, field, elem, state, map) ?;
598
624
}
625
+ }
599
626
600
- ty:: Adt ( def, args) => {
601
- if def. is_union ( ) {
602
- return None ;
603
- }
627
+ ty:: Adt ( def, args) => {
628
+ if def. is_union ( ) {
629
+ throw_inval ! ( ConstPropNonsense )
630
+ }
604
631
605
- let ( variant_idx, variant_def, variant_place) = if def. is_enum ( ) {
606
- let discr = map. apply ( place, TrackElem :: Discriminant ) ?;
607
- let FlatSet :: Elem ( Scalar :: Int ( discr) ) = state. get_idx ( discr, map) else {
608
- return None ;
609
- } ;
610
- let discr_bits = discr. assert_bits ( discr. size ( ) ) ;
611
- let ( variant, _) =
612
- def. discriminants ( tcx) . find ( |( _, var) | discr_bits == var. val ) ?;
613
- let variant_place = map. apply ( place, TrackElem :: Variant ( variant) ) ?;
614
- let variant_int = ty:: ValTree :: Leaf ( variant. as_u32 ( ) . into ( ) ) ;
615
- ( Some ( variant_int) , def. variant ( variant) , variant_place)
616
- } else {
617
- ( None , def. non_enum_variant ( ) , place)
632
+ let ( variant_idx, variant_def, variant_place, variant_dest) = if def. is_enum ( ) {
633
+ let discr = map. apply ( place, TrackElem :: Discriminant ) . ok_or ( err_inval ! ( ConstPropNonsense ) ) ?;
634
+ let FlatSet :: Elem ( Scalar :: Int ( discr) ) = state. get_idx ( discr, map) else {
635
+ throw_inval ! ( ConstPropNonsense )
618
636
} ;
619
-
620
- let branches = variant_def
621
- . fields
622
- . iter_enumerated ( )
623
- . map ( |( i, field) | {
624
- let ty = field. ty ( tcx, args) ;
625
- let field = map. apply ( variant_place, TrackElem :: Field ( i) ) ?;
626
- self . try_make_valtree ( field, ty, state, map)
627
- } )
628
- . collect :: < Option < Vec < _ > > > ( ) ?;
629
- Some ( ty:: ValTree :: Branch (
630
- tcx. arena . alloc_from_iter ( variant_idx. into_iter ( ) . chain ( branches) ) ,
631
- ) )
637
+ let discr_bits = discr. assert_bits ( discr. size ( ) ) ;
638
+ let ( variant, _) = def. discriminants ( * ecx. tcx ) . find ( |( _, var) | discr_bits == var. val ) . ok_or ( err_inval ! ( ConstPropNonsense ) ) ?;
639
+ let variant_place = map. apply ( place, TrackElem :: Variant ( variant) ) . ok_or ( err_inval ! ( ConstPropNonsense ) ) ?;
640
+ let variant_dest = ecx. project_downcast ( dest, variant) ?;
641
+ ( variant, def. variant ( variant) , variant_place, variant_dest)
642
+ } else {
643
+ ( FIRST_VARIANT , def. non_enum_variant ( ) , place, dest. clone ( ) )
644
+ } ;
645
+
646
+ for ( i, field) in variant_def. fields . iter_enumerated ( ) {
647
+ let ty = field. ty ( * ecx. tcx , args) ;
648
+ let field = map. apply ( variant_place, TrackElem :: Field ( i) ) . ok_or ( err_inval ! ( ConstPropNonsense ) ) ?;
649
+ let field_dest = ecx. project_field ( & variant_dest, i. as_usize ( ) ) ?;
650
+ try_write_constant ( ecx, & field_dest, field, ty, state, map) ?;
632
651
}
652
+ ecx. write_discriminant ( variant_idx, dest) ?;
653
+ }
633
654
634
- // Do not attempt to support indirection in constants.
635
- ty:: Ref ( ..) | ty:: RawPtr ( ..) | ty:: FnPtr ( ..) | ty:: Str | ty:: Slice ( _) => None ,
636
-
637
- ty:: Never
638
- | ty:: Foreign ( ..)
639
- | ty:: Alias ( ..)
640
- | ty:: Param ( _)
641
- | ty:: Bound ( ..)
642
- | ty:: Placeholder ( ..)
643
- | ty:: Closure ( ..)
644
- | ty:: Generator ( ..)
645
- | ty:: Dynamic ( ..) => None ,
646
-
647
- ty:: Error ( _)
648
- | ty:: Infer ( ..)
649
- | ty:: GeneratorWitness ( ..)
650
- | ty:: GeneratorWitnessMIR ( ..) => bug ! ( ) ,
655
+ // Unsupported for now.
656
+ ty:: Array ( _, _)
657
+
658
+ // Do not attempt to support indirection in constants.
659
+ | ty:: Ref ( ..) | ty:: RawPtr ( ..) | ty:: FnPtr ( ..) | ty:: Str | ty:: Slice ( _)
660
+
661
+ | ty:: Never
662
+ | ty:: Foreign ( ..)
663
+ | ty:: Alias ( ..)
664
+ | ty:: Param ( _)
665
+ | ty:: Bound ( ..)
666
+ | ty:: Placeholder ( ..)
667
+ | ty:: Closure ( ..)
668
+ | ty:: Generator ( ..)
669
+ | ty:: Dynamic ( ..) => throw_inval ! ( ConstPropNonsense ) ,
670
+
671
+ ty:: Error ( _) | ty:: Infer ( ..) | ty:: GeneratorWitness ( ..) | ty:: GeneratorWitnessMIR ( ..) => {
672
+ bug ! ( )
651
673
}
652
674
}
675
+
676
+ Ok ( ( ) )
653
677
}
654
678
655
679
impl < ' mir , ' tcx >
@@ -667,8 +691,13 @@ impl<'mir, 'tcx>
667
691
) {
668
692
match & statement. kind {
669
693
StatementKind :: Assign ( box ( _, rvalue) ) => {
670
- OperandCollector { state, visitor : self , map : & results. analysis . 0 . map }
671
- . visit_rvalue ( rvalue, location) ;
694
+ OperandCollector {
695
+ state,
696
+ visitor : self ,
697
+ ecx : & mut results. analysis . 0 . ecx ,
698
+ map : & results. analysis . 0 . map ,
699
+ }
700
+ . visit_rvalue ( rvalue, location) ;
672
701
}
673
702
_ => ( ) ,
674
703
}
@@ -686,7 +715,12 @@ impl<'mir, 'tcx>
686
715
// Don't overwrite the assignment if it already uses a constant (to keep the span).
687
716
}
688
717
StatementKind :: Assign ( box ( place, _) ) => {
689
- if let Some ( value) = self . try_make_constant ( place, state, & results. analysis . 0 . map ) {
718
+ if let Some ( value) = self . try_make_constant (
719
+ & mut results. analysis . 0 . ecx ,
720
+ place,
721
+ state,
722
+ & results. analysis . 0 . map ,
723
+ ) {
690
724
self . patch . assignments . insert ( location, value) ;
691
725
}
692
726
}
@@ -701,8 +735,13 @@ impl<'mir, 'tcx>
701
735
terminator : & ' mir Terminator < ' tcx > ,
702
736
location : Location ,
703
737
) {
704
- OperandCollector { state, visitor : self , map : & results. analysis . 0 . map }
705
- . visit_terminator ( terminator, location) ;
738
+ OperandCollector {
739
+ state,
740
+ visitor : self ,
741
+ ecx : & mut results. analysis . 0 . ecx ,
742
+ map : & results. analysis . 0 . map ,
743
+ }
744
+ . visit_terminator ( terminator, location) ;
706
745
}
707
746
}
708
747
@@ -757,6 +796,7 @@ impl<'tcx> MutVisitor<'tcx> for Patch<'tcx> {
757
796
struct OperandCollector < ' tcx , ' map , ' locals , ' a > {
758
797
state : & ' a State < FlatSet < Scalar > > ,
759
798
visitor : & ' a mut Collector < ' tcx , ' locals > ,
799
+ ecx : & ' map mut InterpCx < ' tcx , ' tcx , DummyMachine > ,
760
800
map : & ' map Map ,
761
801
}
762
802
@@ -769,15 +809,17 @@ impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> {
769
809
location : Location ,
770
810
) {
771
811
if let PlaceElem :: Index ( local) = elem
772
- && let Some ( value) = self . visitor . try_make_constant ( local. into ( ) , self . state , self . map )
812
+ && let Some ( value) = self . visitor . try_make_constant ( self . ecx , local. into ( ) , self . state , self . map )
773
813
{
774
814
self . visitor . patch . before_effect . insert ( ( location, local. into ( ) ) , value) ;
775
815
}
776
816
}
777
817
778
818
fn visit_operand ( & mut self , operand : & Operand < ' tcx > , location : Location ) {
779
819
if let Some ( place) = operand. place ( ) {
780
- if let Some ( value) = self . visitor . try_make_constant ( place, self . state , self . map ) {
820
+ if let Some ( value) =
821
+ self . visitor . try_make_constant ( self . ecx , place, self . state , self . map )
822
+ {
781
823
self . visitor . patch . before_effect . insert ( ( location, place) , value) ;
782
824
} else if !place. projection . is_empty ( ) {
783
825
// Try to propagate into `Index` projections.
@@ -802,8 +844,9 @@ impl<'mir, 'tcx: 'mir> rustc_const_eval::interpret::Machine<'mir, 'tcx> for Dumm
802
844
}
803
845
804
846
fn enforce_validity ( _ecx : & InterpCx < ' mir , ' tcx , Self > , _layout : TyAndLayout < ' tcx > ) -> bool {
805
- unimplemented ! ( )
847
+ false
806
848
}
849
+
807
850
fn alignment_check_failed (
808
851
_ecx : & InterpCx < ' mir , ' tcx , Self > ,
809
852
_has : Align ,
0 commit comments