Skip to content

Commit 3e1cce7

Browse files
committed
Initial UnsafePinned impl [Part 2: Lowering]
1 parent 40dacd5 commit 3e1cce7

21 files changed

+1382
-12
lines changed

compiler/rustc_ast_ir/src/lib.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,16 @@ use rustc_macros::{Decodable_NoContext, Encodable_NoContext, HashStable_NoContex
1717
pub mod visit;
1818

1919
/// The movability of a coroutine / closure literal:
20-
/// whether a coroutine contains self-references, causing it to be `!Unpin`.
20+
/// whether a coroutine contains self-references, causing it to be `![Unsafe]Unpin`.
2121
#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Copy)]
2222
#[cfg_attr(
2323
feature = "nightly",
2424
derive(Encodable_NoContext, Decodable_NoContext, HashStable_NoContext)
2525
)]
2626
pub enum Movability {
27-
/// May contain self-references, `!Unpin`.
27+
/// May contain self-references, `!Unpin + !UnsafeUnpin`.
2828
Static,
29-
/// Must not contain self-references, `Unpin`.
29+
/// Must not contain self-references, `Unpin + UnsafeUnpin`.
3030
Movable,
3131
}
3232

compiler/rustc_borrowck/src/diagnostics/conflict_errors.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -3385,7 +3385,7 @@ impl<'infcx, 'tcx> MirBorrowckCtxt<'_, 'infcx, 'tcx> {
33853385
Some(3)
33863386
} else if string.starts_with("static") {
33873387
// `static` is 6 chars long
3388-
// This is used for `!Unpin` coroutines
3388+
// This is used for immovable (self-referential) coroutines
33893389
Some(6)
33903390
} else {
33913391
None

compiler/rustc_middle/src/mir/query.rs

+3
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ pub struct CoroutineSavedTy<'tcx> {
2929
pub source_info: SourceInfo,
3030
/// Whether the local should be ignored for trait bound computations.
3131
pub ignore_for_traits: bool,
32+
/// If this local is borrowed across a suspension point and thus is
33+
/// "wrapped" in `UnsafePinned`. Always false for movable coroutines.
34+
pub pinned: bool,
3235
}
3336

3437
/// The layout of coroutine state.

compiler/rustc_middle/src/ty/context.rs

+5
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,10 @@ impl<'tcx> Interner for TyCtxt<'tcx> {
335335
self.coroutine_hidden_types(def_id)
336336
}
337337

338+
fn coroutine_has_pinned_fields(self, def_id: DefId) -> Option<bool> {
339+
self.coroutine_has_pinned_fields(def_id)
340+
}
341+
338342
fn fn_sig(self, def_id: DefId) -> ty::EarlyBinder<'tcx, ty::PolyFnSig<'tcx>> {
339343
self.fn_sig(def_id)
340344
}
@@ -734,6 +738,7 @@ bidirectional_lang_item_map! {
734738
TransmuteTrait,
735739
Tuple,
736740
Unpin,
741+
UnsafeUnpin,
737742
Unsize,
738743
// tidy-alphabetical-end
739744
}

compiler/rustc_middle/src/ty/util.rs

+7
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,13 @@ impl<'tcx> TyCtxt<'tcx> {
790790
))
791791
}
792792

793+
/// True if the given coroutine has any pinned fields.
794+
/// `None` if the coroutine is tainted by errors.
795+
pub fn coroutine_has_pinned_fields(self, def_id: DefId) -> Option<bool> {
796+
self.mir_coroutine_witnesses(def_id)
797+
.map(|layout| layout.field_tys.iter().any(|ty| ty.pinned))
798+
}
799+
793800
/// Expands the given impl trait type, stopping if the type is recursive.
794801
#[instrument(skip(self), level = "debug", ret)]
795802
pub fn try_expand_impl_trait_type(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
use rustc_index::bit_set::DenseBitSet;
2+
use rustc_middle::mir::visit::Visitor;
3+
use rustc_middle::mir::*;
4+
use tracing::debug;
5+
6+
use crate::{Analysis, GenKill};
7+
8+
#[derive(Clone)]
9+
pub struct CoroutinePinnedLocals(pub Local);
10+
11+
impl CoroutinePinnedLocals {
12+
fn transfer_function<'a>(&self, domain: &'a mut DenseBitSet<Local>) -> TransferFunction<'a> {
13+
TransferFunction { local: self.0, trans: domain }
14+
}
15+
}
16+
17+
impl<'tcx> Analysis<'tcx> for CoroutinePinnedLocals {
18+
type Domain = DenseBitSet<Local>;
19+
const NAME: &'static str = "coro_pinned_locals";
20+
21+
fn bottom_value(&self, body: &Body<'tcx>) -> Self::Domain {
22+
// bottom = unborrowed
23+
DenseBitSet::new_empty(body.local_decls().len())
24+
}
25+
26+
fn initialize_start_block(&self, _: &Body<'tcx>, _: &mut Self::Domain) {
27+
// No locals are actively borrowing from other locals on function entry
28+
}
29+
30+
fn apply_primary_statement_effect(
31+
&mut self,
32+
state: &mut Self::Domain,
33+
statement: &Statement<'tcx>,
34+
location: Location,
35+
) {
36+
self.transfer_function(state).visit_statement(statement, location);
37+
}
38+
39+
fn apply_primary_terminator_effect<'mir>(
40+
&mut self,
41+
state: &mut Self::Domain,
42+
terminator: &'mir Terminator<'tcx>,
43+
location: Location,
44+
) -> TerminatorEdges<'mir, 'tcx> {
45+
self.transfer_function(state).visit_terminator(terminator, location);
46+
47+
terminator.edges()
48+
}
49+
}
50+
51+
/// A `Visitor` that defines the transfer function for `CoroutinePinnedLocals`.
52+
pub(super) struct TransferFunction<'a> {
53+
local: Local,
54+
trans: &'a mut DenseBitSet<Local>,
55+
}
56+
57+
impl<'tcx> Visitor<'tcx> for TransferFunction<'_> {
58+
fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
59+
self.super_statement(statement, location);
60+
61+
if let StatementKind::StorageDead(local) = statement.kind {
62+
debug!(for_ = ?self.local, KILL = ?local, ?statement, ?location);
63+
self.trans.kill(local);
64+
}
65+
}
66+
67+
fn visit_assign(
68+
&mut self,
69+
assigned_place: &Place<'tcx>,
70+
rvalue: &Rvalue<'tcx>,
71+
location: Location,
72+
) {
73+
self.super_assign(assigned_place, rvalue, location);
74+
75+
match rvalue {
76+
Rvalue::Ref(_, BorrowKind::Mut { .. } | BorrowKind::Shared, place)
77+
| Rvalue::RawPtr(RawPtrKind::Const | RawPtrKind::Mut, place) => {
78+
if (!place.is_indirect() && place.local == self.local)
79+
|| self.trans.contains(place.local)
80+
{
81+
if assigned_place.is_indirect() {
82+
debug!(for_ = ?self.local, GEN_ptr_indirect = ?assigned_place, borrowed_place = ?place, ?rvalue, ?location);
83+
self.trans.gen_(self.local);
84+
} else {
85+
debug!(for_ = ?self.local, GEN_ptr_direct = ?assigned_place, borrowed_place = ?place, ?rvalue, ?location);
86+
self.trans.gen_(assigned_place.local);
87+
}
88+
}
89+
}
90+
91+
// fake pointers don't count
92+
Rvalue::Ref(_, BorrowKind::Fake(_), _)
93+
| Rvalue::RawPtr(RawPtrKind::FakeForPtrMetadata, _) => {}
94+
95+
Rvalue::Use(..)
96+
| Rvalue::Repeat(..)
97+
| Rvalue::ThreadLocalRef(..)
98+
| Rvalue::Len(..)
99+
| Rvalue::Cast(..)
100+
| Rvalue::BinaryOp(..)
101+
| Rvalue::NullaryOp(..)
102+
| Rvalue::UnaryOp(..)
103+
| Rvalue::Discriminant(..)
104+
| Rvalue::Aggregate(..)
105+
| Rvalue::ShallowInitBox(..)
106+
| Rvalue::CopyForDeref(..)
107+
| Rvalue::WrapUnsafeBinder(..) => {}
108+
}
109+
}
110+
111+
fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
112+
self.super_terminator(terminator, location);
113+
114+
match terminator.kind {
115+
TerminatorKind::Drop { place: dropped_place, .. } => {
116+
// Drop terminators may call custom drop glue (`Drop::drop`), which takes `&mut
117+
// self` as a parameter. In the general case, a drop impl could launder that
118+
// reference into the surrounding environment through a raw pointer, thus creating
119+
// a valid `*mut` pointing to the dropped local. We are not yet willing to declare
120+
// this particular case UB, so we must treat all dropped locals as mutably borrowed
121+
// for now. See discussion on [#61069].
122+
//
123+
// [#61069]: https://github.com/rust-lang/rust/pull/61069
124+
if !dropped_place.is_indirect() && dropped_place.local == self.local {
125+
debug!(for_ = ?self.local, GEN_drop = ?dropped_place, ?terminator, ?location);
126+
self.trans.gen_(self.local);
127+
}
128+
}
129+
130+
TerminatorKind::Goto { .. }
131+
| TerminatorKind::SwitchInt { .. }
132+
| TerminatorKind::UnwindResume
133+
| TerminatorKind::UnwindTerminate(_)
134+
| TerminatorKind::Return
135+
| TerminatorKind::Unreachable
136+
| TerminatorKind::Call { .. }
137+
| TerminatorKind::TailCall { .. }
138+
| TerminatorKind::Assert { .. }
139+
| TerminatorKind::Yield { .. }
140+
| TerminatorKind::CoroutineDrop
141+
| TerminatorKind::FalseEdge { .. }
142+
| TerminatorKind::FalseUnwind { .. }
143+
| TerminatorKind::InlineAsm { .. } => {}
144+
}
145+
}
146+
}

compiler/rustc_mir_dataflow/src/impls/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
mod borrowed_locals;
2+
mod coro_pinned_locals;
23
mod initialized;
34
mod liveness;
45
mod storage_liveness;
56

67
pub use self::borrowed_locals::{MaybeBorrowedLocals, borrowed_locals};
8+
pub use self::coro_pinned_locals::CoroutinePinnedLocals;
79
pub use self::initialized::{
810
EverInitializedPlaces, EverInitializedPlacesDomain, MaybeInitializedPlaces,
911
MaybeUninitializedPlaces, MaybeUninitializedPlacesDomain,

compiler/rustc_mir_transform/src/coroutine.rs

+52-5
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ use rustc_middle::ty::{
6969
};
7070
use rustc_middle::{bug, span_bug};
7171
use rustc_mir_dataflow::impls::{
72-
MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive,
73-
always_storage_live_locals,
72+
CoroutinePinnedLocals, MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage,
73+
MaybeStorageLive, always_storage_live_locals,
7474
};
7575
use rustc_mir_dataflow::{Analysis, Results, ResultsVisitor};
7676
use rustc_span::def_id::{DefId, LocalDefId};
@@ -639,6 +639,15 @@ struct LivenessInfo {
639639
/// Parallel vec to the above with SourceInfo for each yield terminator.
640640
source_info_at_suspension_points: Vec<SourceInfo>,
641641

642+
/// Coroutine saved locals that are borrowed across a suspension point.
643+
/// This corresponds to locals that are "wrapped" with `UnsafePinned`.
644+
///
645+
/// Note that movable coroutines do not allow borrowing locals across
646+
/// suspension points and thus will always have this set empty.
647+
///
648+
/// For more information, see [RFC 3467](https://rust-lang.github.io/rfcs/3467-unsafe-pinned.html).
649+
saved_locals_borrowed_across_suspension_points: DenseBitSet<CoroutineSavedLocal>,
650+
642651
/// For every saved local, the set of other saved locals that are
643652
/// storage-live at the same time as this local. We cannot overlap locals in
644653
/// the layout which have conflicting storage.
@@ -657,6 +666,9 @@ struct LivenessInfo {
657666
/// case none exist, the local is considered to be always live.
658667
/// - a local has to be stored if it is either directly used after the
659668
/// the suspend point, or if it is live and has been previously borrowed.
669+
///
670+
/// We also compute locals which are "pinned" (borrowed across a suspension point).
671+
/// These are "wrapped" in `UnsafePinned` and have their niche opts disabled.
660672
fn locals_live_across_suspend_points<'tcx>(
661673
tcx: TyCtxt<'tcx>,
662674
body: &Body<'tcx>,
@@ -686,10 +698,12 @@ fn locals_live_across_suspend_points<'tcx>(
686698
let mut liveness =
687699
MaybeLiveLocals.iterate_to_fixpoint(tcx, body, Some("coroutine")).into_results_cursor(body);
688700

701+
let mut pinned_locals_cache = IndexVec::from_fn_n(|_| None, body.local_decls.len());
689702
let mut storage_liveness_map = IndexVec::from_elem(None, &body.basic_blocks);
690703
let mut live_locals_at_suspension_points = Vec::new();
691704
let mut source_info_at_suspension_points = Vec::new();
692705
let mut live_locals_at_any_suspension_point = DenseBitSet::new_empty(body.local_decls.len());
706+
let mut pinned_locals = DenseBitSet::new_empty(body.local_decls.len());
693707

694708
for (block, data) in body.basic_blocks.iter_enumerated() {
695709
if let TerminatorKind::Yield { .. } = data.terminator().kind {
@@ -729,6 +743,27 @@ fn locals_live_across_suspend_points<'tcx>(
729743

730744
debug!("loc = {:?}, live_locals = {:?}", loc, live_locals);
731745

746+
for live_local in live_locals.iter() {
747+
let pinned_cursor = pinned_locals_cache[live_local].get_or_insert_with(|| {
748+
CoroutinePinnedLocals(live_local)
749+
.iterate_to_fixpoint(tcx, body, None)
750+
.into_results_cursor(body)
751+
});
752+
pinned_cursor.seek_to_block_end(block);
753+
let mut pinned_by = pinned_cursor.get().clone();
754+
pinned_by.intersect(&live_locals);
755+
756+
if !pinned_by.is_empty() {
757+
assert!(
758+
!movable,
759+
"local {live_local:?} of movable coro shouldn't be pinned, yet it is pinned by {pinned_by:?}"
760+
);
761+
762+
debug!("{live_local:?} pinned by {pinned_by:?} in {block:?}");
763+
pinned_locals.insert(live_local);
764+
}
765+
}
766+
732767
// Add the locals live at this suspension point to the set of locals which live across
733768
// any suspension points
734769
live_locals_at_any_suspension_point.union(&live_locals);
@@ -738,7 +773,8 @@ fn locals_live_across_suspend_points<'tcx>(
738773
}
739774
}
740775

741-
debug!("live_locals_anywhere = {:?}", live_locals_at_any_suspension_point);
776+
debug!(?pinned_locals);
777+
debug!(live_locals_anywhere = ?live_locals_at_any_suspension_point);
742778
let saved_locals = CoroutineSavedLocals(live_locals_at_any_suspension_point);
743779

744780
// Renumber our liveness_map bitsets to include only the locals we are
@@ -748,6 +784,9 @@ fn locals_live_across_suspend_points<'tcx>(
748784
.map(|live_here| saved_locals.renumber_bitset(live_here))
749785
.collect();
750786

787+
let saved_locals_borrowed_across_suspension_points =
788+
saved_locals.renumber_bitset(&pinned_locals);
789+
751790
let storage_conflicts = compute_storage_conflicts(
752791
body,
753792
&saved_locals,
@@ -759,6 +798,7 @@ fn locals_live_across_suspend_points<'tcx>(
759798
saved_locals,
760799
live_locals_at_suspension_points,
761800
source_info_at_suspension_points,
801+
saved_locals_borrowed_across_suspension_points,
762802
storage_conflicts,
763803
storage_liveness: storage_liveness_map,
764804
}
@@ -931,6 +971,7 @@ fn compute_layout<'tcx>(
931971
saved_locals,
932972
live_locals_at_suspension_points,
933973
source_info_at_suspension_points,
974+
saved_locals_borrowed_across_suspension_points,
934975
storage_conflicts,
935976
storage_liveness,
936977
} = liveness;
@@ -960,8 +1001,14 @@ fn compute_layout<'tcx>(
9601001
ClearCrossCrate::Set(box LocalInfo::FakeBorrow) => true,
9611002
_ => false,
9621003
};
963-
let decl =
964-
CoroutineSavedTy { ty: decl.ty, source_info: decl.source_info, ignore_for_traits };
1004+
let pinned = saved_locals_borrowed_across_suspension_points.contains(saved_local);
1005+
1006+
let decl = CoroutineSavedTy {
1007+
ty: decl.ty,
1008+
source_info: decl.source_info,
1009+
ignore_for_traits,
1010+
pinned,
1011+
};
9651012
debug!(?decl);
9661013

9671014
tys.push(decl);

0 commit comments

Comments
 (0)