Skip to content

Commit 12ae3f2

Browse files
Use a helper to zip together parent and child captures for coroutine-closures
1 parent e908cfd commit 12ae3f2

File tree

3 files changed

+73
-64
lines changed

3 files changed

+73
-64
lines changed

compiler/rustc_middle/src/ty/closure.rs

+59
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use crate::{mir, ty};
66
use std::fmt::Write;
77

88
use crate::query::Providers;
9+
use rustc_data_structures::captures::Captures;
910
use rustc_data_structures::fx::FxIndexMap;
1011
use rustc_hir as hir;
1112
use rustc_hir::def_id::LocalDefId;
@@ -415,6 +416,64 @@ impl BorrowKind {
415416
}
416417
}
417418

419+
pub fn analyze_coroutine_closure_captures<'a, 'tcx: 'a, T>(
420+
parent_captures: impl IntoIterator<Item = &'a CapturedPlace<'tcx>>,
421+
child_captures: impl IntoIterator<Item = &'a CapturedPlace<'tcx>>,
422+
mut for_each: impl FnMut((usize, &'a CapturedPlace<'tcx>), (usize, &'a CapturedPlace<'tcx>)) -> T,
423+
) -> impl Iterator<Item = T> + Captures<'a> + Captures<'tcx> {
424+
std::iter::from_coroutine(move || {
425+
let mut child_captures = child_captures.into_iter().enumerate().peekable();
426+
427+
// One parent capture may correspond to several child captures if we end up
428+
// refining the set of captures via edition-2021 precise captures. We want to
429+
// match up any number of child captures with one parent capture, so we keep
430+
// peeking off this `Peekable` until the child doesn't match anymore.
431+
for (parent_field_idx, parent_capture) in parent_captures.into_iter().enumerate() {
432+
// Make sure we use every field at least once, b/c why are we capturing something
433+
// if it's not used in the inner coroutine.
434+
let mut field_used_at_least_once = false;
435+
436+
// A parent matches a child if they share the same prefix of projections.
437+
// The child may have more, if it is capturing sub-fields out of
438+
// something that is captured by-move in the parent closure.
439+
while child_captures.peek().map_or(false, |(_, child_capture)| {
440+
child_prefix_matches_parent_projections(parent_capture, child_capture)
441+
}) {
442+
let (child_field_idx, child_capture) = child_captures.next().unwrap();
443+
yield for_each(
444+
(parent_field_idx, parent_capture),
445+
(child_field_idx, child_capture),
446+
);
447+
field_used_at_least_once = true;
448+
}
449+
450+
// Make sure the field was used at least once.
451+
assert!(
452+
field_used_at_least_once,
453+
"we captured {parent_capture:#?} but it was not used in the child coroutine?"
454+
);
455+
}
456+
assert_eq!(child_captures.next(), None, "leftover child captures?");
457+
})
458+
}
459+
460+
fn child_prefix_matches_parent_projections(
461+
parent_capture: &ty::CapturedPlace<'_>,
462+
child_capture: &ty::CapturedPlace<'_>,
463+
) -> bool {
464+
let HirPlaceBase::Upvar(parent_base) = parent_capture.place.base else {
465+
bug!("expected capture to be an upvar");
466+
};
467+
let HirPlaceBase::Upvar(child_base) = child_capture.place.base else {
468+
bug!("expected capture to be an upvar");
469+
};
470+
471+
assert!(child_capture.place.projections.len() >= parent_capture.place.projections.len());
472+
parent_base.var_path.hir_id == child_base.var_path.hir_id
473+
&& std::iter::zip(&child_capture.place.projections, &parent_capture.place.projections)
474+
.all(|(child, parent)| child.kind == parent.kind)
475+
}
476+
418477
pub fn provide(providers: &mut Providers) {
419478
*providers = Providers { closure_typeinfo, ..*providers }
420479
}

compiler/rustc_middle/src/ty/mod.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,10 @@ pub use rustc_type_ir::ConstKind::{
7777
pub use rustc_type_ir::*;
7878

7979
pub use self::closure::{
80-
is_ancestor_or_same_capture, place_to_string_for_capture, BorrowKind, CaptureInfo,
81-
CapturedPlace, ClosureTypeInfo, MinCaptureInformationMap, MinCaptureList,
82-
RootVariableMinCaptureList, UpvarCapture, UpvarId, UpvarPath, CAPTURE_STRUCT_LOCAL,
80+
analyze_coroutine_closure_captures, is_ancestor_or_same_capture, place_to_string_for_capture,
81+
BorrowKind, CaptureInfo, CapturedPlace, ClosureTypeInfo, MinCaptureInformationMap,
82+
MinCaptureList, RootVariableMinCaptureList, UpvarCapture, UpvarId, UpvarPath,
83+
CAPTURE_STRUCT_LOCAL,
8384
};
8485
pub use self::consts::{
8586
Const, ConstData, ConstInt, ConstKind, Expr, ScalarInt, UnevaluatedConst, ValTree,

compiler/rustc_mir_transform/src/coroutine/by_move_body.rs

+10-61
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
7272
use rustc_data_structures::unord::UnordMap;
7373
use rustc_hir as hir;
74-
use rustc_middle::hir::place::{PlaceBase, Projection, ProjectionKind};
74+
use rustc_middle::hir::place::{Projection, ProjectionKind};
7575
use rustc_middle::mir::visit::MutVisitor;
7676
use rustc_middle::mir::{self, dump_mir, MirPass};
7777
use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt, TypeVisitableExt};
@@ -124,36 +124,10 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
124124
.tuple_fields()
125125
.len();
126126

127-
let mut field_remapping = UnordMap::default();
128-
129-
let mut child_captures = tcx
130-
.closure_captures(coroutine_def_id)
131-
.iter()
132-
.copied()
133-
// By construction we capture all the args first.
134-
.skip(num_args)
135-
.enumerate()
136-
.peekable();
137-
138-
// One parent capture may correspond to several child captures if we end up
139-
// refining the set of captures via edition-2021 precise captures. We want to
140-
// match up any number of child captures with one parent capture, so we keep
141-
// peeking off this `Peekable` until the child doesn't match anymore.
142-
for (parent_field_idx, parent_capture) in
143-
tcx.closure_captures(parent_def_id).iter().copied().enumerate()
144-
{
145-
// Make sure we use every field at least once, b/c why are we capturing something
146-
// if it's not used in the inner coroutine.
147-
let mut field_used_at_least_once = false;
148-
149-
// A parent matches a child if they share the same prefix of projections.
150-
// The child may have more, if it is capturing sub-fields out of
151-
// something that is captured by-move in the parent closure.
152-
while child_captures.peek().map_or(false, |(_, child_capture)| {
153-
child_prefix_matches_parent_projections(parent_capture, child_capture)
154-
}) {
155-
let (child_field_idx, child_capture) = child_captures.next().unwrap();
156-
127+
let field_remapping: UnordMap<_, _> = ty::analyze_coroutine_closure_captures(
128+
tcx.closure_captures(parent_def_id).iter().copied(),
129+
tcx.closure_captures(coroutine_def_id).iter().skip(num_args).copied(),
130+
|(parent_field_idx, parent_capture), (child_field_idx, child_capture)| {
157131
// Store this set of additional projections (fields and derefs).
158132
// We need to re-apply them later.
159133
let child_precise_captures =
@@ -184,26 +158,18 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
184158
),
185159
};
186160

187-
field_remapping.insert(
161+
(
188162
FieldIdx::from_usize(child_field_idx + num_args),
189163
(
190164
FieldIdx::from_usize(parent_field_idx + num_args),
191165
parent_capture_ty,
192166
needs_deref,
193167
child_precise_captures,
194168
),
195-
);
196-
197-
field_used_at_least_once = true;
198-
}
199-
200-
// Make sure the field was used at least once.
201-
assert!(
202-
field_used_at_least_once,
203-
"we captured {parent_capture:#?} but it was not used in the child coroutine?"
204-
);
205-
}
206-
assert_eq!(child_captures.next(), None, "leftover child captures?");
169+
)
170+
},
171+
)
172+
.collect();
207173

208174
if coroutine_kind == ty::ClosureKind::FnOnce {
209175
assert_eq!(field_remapping.len(), tcx.closure_captures(parent_def_id).len());
@@ -233,23 +199,6 @@ impl<'tcx> MirPass<'tcx> for ByMoveBody {
233199
}
234200
}
235201

236-
fn child_prefix_matches_parent_projections(
237-
parent_capture: &ty::CapturedPlace<'_>,
238-
child_capture: &ty::CapturedPlace<'_>,
239-
) -> bool {
240-
let PlaceBase::Upvar(parent_base) = parent_capture.place.base else {
241-
bug!("expected capture to be an upvar");
242-
};
243-
let PlaceBase::Upvar(child_base) = child_capture.place.base else {
244-
bug!("expected capture to be an upvar");
245-
};
246-
247-
assert!(child_capture.place.projections.len() >= parent_capture.place.projections.len());
248-
parent_base.var_path.hir_id == child_base.var_path.hir_id
249-
&& std::iter::zip(&child_capture.place.projections, &parent_capture.place.projections)
250-
.all(|(child, parent)| child.kind == parent.kind)
251-
}
252-
253202
struct MakeByMoveBody<'tcx> {
254203
tcx: TyCtxt<'tcx>,
255204
field_remapping: UnordMap<FieldIdx, (FieldIdx, Ty<'tcx>, bool, &'tcx [Projection<'tcx>])>,

0 commit comments

Comments
 (0)