Skip to content

Commit 2b3a9be

Browse files
authored
Refactor coercion logic for future usage inside codegen (rust-lang#1837)
We currently have a few issues with how we are generating code for casting (rust-lang#566, and rust-lang#1528). The structure of the code is also hard to understand and maintain (see rust-lang#1531 for more details). This PR is the first part of the fix I developed. This change moves the coercion specific code to its own module and it introduces an iterator that traverses the coercion path.
1 parent fba67da commit 2b3a9be

File tree

3 files changed

+246
-112
lines changed

3 files changed

+246
-112
lines changed
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
// Copyright Kani Contributors
2+
// SPDX-License-Identifier: Apache-2.0 OR MIT
3+
4+
//! This module contains methods that help us process coercions.
5+
//! There are many types of coercions in rust, they are described in the
6+
//! [RFC 401 Coercions](https://rust-lang.github.io/rfcs/0401-coercions.html).
7+
//!
8+
//! The more complicated coercions are DST Coercions (aka Unsized Coercions). These coercions
9+
//! allow rust to create references to dynamically sized types (such as Traits and Slices) by
10+
//! casting concrete sized types. Unsized coercions can also be used to cast unsized to unsized
11+
//! types. These casts work not only on the top of references, but it also handle
12+
//! references inside structures, allowing the unsized coercions of smart pointers. The
13+
//! definition of custom coercions for smart pointers can be found in the
14+
//! [RFC 982 DST Coercion](https://rust-lang.github.io/rfcs/0982-dst-coercion.html).
15+
16+
use rustc_hir::lang_items::LangItem;
17+
use rustc_middle::traits::{ImplSource, ImplSourceUserDefinedData};
18+
use rustc_middle::ty::adjustment::CustomCoerceUnsized;
19+
use rustc_middle::ty::TypeAndMut;
20+
use rustc_middle::ty::{self, ParamEnv, TraitRef, Ty, TyCtxt};
21+
use rustc_span::symbol::Symbol;
22+
use tracing::trace;
23+
24+
/// Given an unsized coercion (e.g. from `&u8` to `&dyn Debug`), extract the pair of
25+
/// corresponding base types `T`, `U` (e.g. `u8`, `dyn Debug`), where the source base type `T` must
26+
/// implement `Unsize<U>` and `U` is either a trait or slice.
27+
///
28+
/// For more details, please refer to:
29+
/// <https://doc.rust-lang.org/reference/type-coercions.html#unsized-coercions>
30+
///
31+
/// This is used to determine the vtable implementation that must be tracked by the fat pointer.
32+
///
33+
/// For example, if `&u8` is being converted to `&dyn Debug`, this method would return:
34+
/// `(u8, dyn Debug)`.
35+
///
36+
/// There are a few interesting cases (references / pointers are handled the same way):
37+
/// 1. Coercion between `&T` to `&U`.
38+
/// - This is the base case.
39+
/// - In this case, we extract the types that are pointed to.
40+
/// 2. Coercion between smart pointers like `NonNull<T>` to `NonNull<U>`.
41+
/// - Smart pointers implement the `CoerceUnsize` trait.
42+
/// - Use CustomCoerceUnsized information to traverse the smart pointer structure and find the
43+
/// underlying pointer.
44+
/// - Use base case to extract `T` and `U`.
45+
/// 3. Coercion between a pointer to a structure whose tail is being coerced.
46+
///
47+
/// E.g.: A user may want to define a type like:
48+
/// ```
49+
/// struct Message<T> {
50+
/// header: &str,
51+
/// content: T,
52+
/// }
53+
/// ```
54+
/// They may want to abstract only the content of a message. So one could coerce a
55+
/// `&Message<String>` into a `&Message<dyn Display>`. In this case, this would:
56+
/// - Apply base case to extract the pair `(Message<T>, Message<U>)`.
57+
/// - Extract the tail element of the struct which are of type `T` and `U`, respectively.
58+
/// 4. Coercion between smart pointers of wrapper structs.
59+
/// - Apply the logic from item 2 then item 3.
60+
pub fn extract_unsize_casting<'tcx>(
61+
tcx: TyCtxt<'tcx>,
62+
src_ty: Ty<'tcx>,
63+
dst_ty: Ty<'tcx>,
64+
) -> CoercionBase<'tcx> {
65+
trace!(?src_ty, ?dst_ty, "extract_unsize_casting");
66+
// Iterate over the pointer structure to find the builtin pointer that will store the metadata.
67+
let coerce_info = CoerceUnsizedIterator::new(tcx, src_ty, dst_ty).last().unwrap();
68+
// Extract the pointee type that is being coerced.
69+
let src_pointee_ty = extract_pointee(coerce_info.src_ty).expect(&format!(
70+
"Expected source to be a pointer. Found {:?} instead",
71+
coerce_info.src_ty
72+
));
73+
let dst_pointee_ty = extract_pointee(coerce_info.dst_ty).expect(&format!(
74+
"Expected destination to be a pointer. Found {:?} instead",
75+
coerce_info.dst_ty
76+
));
77+
// Find the tail of the coercion that determines the type of metadata to be stored.
78+
let (src_base_ty, dst_base_ty) = tcx.struct_lockstep_tails_erasing_lifetimes(
79+
src_pointee_ty,
80+
dst_pointee_ty,
81+
ParamEnv::reveal_all(),
82+
);
83+
trace!(?src_base_ty, ?dst_base_ty, "extract_unsize_casting result");
84+
assert!(
85+
dst_base_ty.is_trait() || dst_base_ty.is_slice(),
86+
"Expected trait or slice as destination of unsized cast, but found {dst_base_ty:?}"
87+
);
88+
CoercionBase { src_ty: src_base_ty, dst_ty: dst_base_ty }
89+
}
90+
91+
/// This structure represents the base of a coercion.
92+
///
93+
/// This base is used to determine the information that will be stored in the metadata.
94+
/// E.g.: In order to convert an `Rc<String>` into an `Rc<dyn Debug>`, we need to generate a
95+
/// vtable that represents the `impl Debug for String`. So this type will carry the `String` type
96+
/// as the `src_ty` and the `dyn Debug` trait as `dst_ty`.
97+
#[derive(Debug)]
98+
pub struct CoercionBase<'tcx> {
99+
pub src_ty: Ty<'tcx>,
100+
pub dst_ty: Ty<'tcx>,
101+
}
102+
103+
/// Iterates over the coercion path of a structure that implements `CoerceUnsized<T>` trait.
104+
/// The `CoerceUnsized<T>` trait indicates that this is a pointer or a wrapper for one, where
105+
/// unsizing can be performed on the pointee. More details:
106+
/// <https://doc.rust-lang.org/std/ops/trait.CoerceUnsized.html>
107+
///
108+
/// Given an unsized coercion between `impl CoerceUnsized<T>` to `impl CoerceUnsized<U>` where
109+
/// `T` is sized and `U` is unsized, this iterator will walk over the fields that lead to a
110+
/// pointer to `T`, which shall be converted from a thin pointer to a fat pointer.
111+
///
112+
/// Each iteration will also include an optional name of the field that differs from the current
113+
/// pair of types.
114+
///
115+
/// The first element of the iteration will always be the starting types.
116+
/// The last element of the iteration will always be pointers to `T` and `U`.
117+
/// After unsized element has been found, the iterator will return `None`.
118+
pub struct CoerceUnsizedIterator<'tcx> {
119+
tcx: TyCtxt<'tcx>,
120+
src_ty: Option<Ty<'tcx>>,
121+
dst_ty: Option<Ty<'tcx>>,
122+
}
123+
124+
/// Represent the information about a coercion.
125+
#[derive(Debug, Clone, PartialOrd, PartialEq)]
126+
pub struct CoerceUnsizedInfo<'tcx> {
127+
/// The name of the field from the current types that differs between each other.
128+
pub field: Option<Symbol>,
129+
/// The type being coerced.
130+
pub src_ty: Ty<'tcx>,
131+
/// The type that is the result of the coercion.
132+
pub dst_ty: Ty<'tcx>,
133+
}
134+
135+
impl<'tcx> CoerceUnsizedIterator<'tcx> {
136+
pub fn new(
137+
tcx: TyCtxt<'tcx>,
138+
src_ty: Ty<'tcx>,
139+
dst_ty: Ty<'tcx>,
140+
) -> CoerceUnsizedIterator<'tcx> {
141+
CoerceUnsizedIterator { tcx, src_ty: Some(src_ty), dst_ty: Some(dst_ty) }
142+
}
143+
}
144+
145+
/// Iterate over the coercion path. At each iteration, it returns the name of the field that must
146+
/// be coerced, as well as the current source and the destination.
147+
/// E.g.: The first iteration of casting `NonNull<String>` -> `NonNull<&dyn Debug>` will return
148+
/// ```rust,ignore
149+
/// CoerceUnsizedInfo {
150+
/// field: Some("ptr"),
151+
/// src_ty, // NonNull<String>
152+
/// dst_ty // NonNull<&dyn Debug>
153+
/// }
154+
/// ```
155+
/// while the last iteration will return:
156+
/// ```rust,ignore
157+
/// CoerceUnsizedInfo {
158+
/// field: None,
159+
/// src_ty: Ty, // *const String
160+
/// dst_ty: Ty, // *const &dyn Debug
161+
/// }
162+
/// ```
163+
impl<'tcx> Iterator for CoerceUnsizedIterator<'tcx> {
164+
type Item = CoerceUnsizedInfo<'tcx>;
165+
166+
fn next(&mut self) -> Option<Self::Item> {
167+
if self.src_ty.is_none() {
168+
assert_eq!(self.dst_ty, None, "Expected no dst type.");
169+
return None;
170+
}
171+
172+
// Extract the pointee types from pointers (including smart pointers) that form the base of
173+
// the conversion.
174+
let src_ty = self.src_ty.take().unwrap();
175+
let dst_ty = self.dst_ty.take().unwrap();
176+
let field = match (&src_ty.kind(), &dst_ty.kind()) {
177+
(&ty::Adt(src_def, src_substs), &ty::Adt(dst_def, dst_substs)) => {
178+
// Handle smart pointers by using CustomCoerceUnsized to find the field being
179+
// coerced.
180+
assert_eq!(src_def, dst_def);
181+
let src_fields = &src_def.non_enum_variant().fields;
182+
let dst_fields = &dst_def.non_enum_variant().fields;
183+
assert_eq!(src_fields.len(), dst_fields.len());
184+
185+
let CustomCoerceUnsized::Struct(coerce_index) =
186+
custom_coerce_unsize_info(self.tcx, src_ty, dst_ty);
187+
assert!(coerce_index < src_fields.len());
188+
189+
self.src_ty = Some(src_fields[coerce_index].ty(self.tcx, src_substs));
190+
self.dst_ty = Some(dst_fields[coerce_index].ty(self.tcx, dst_substs));
191+
Some(src_fields[coerce_index].name)
192+
}
193+
_ => {
194+
// Base case is always a pointer (Box, raw_pointer or reference).
195+
assert!(
196+
extract_pointee(src_ty).is_some(),
197+
"Expected a pointer, but found {src_ty:?}"
198+
);
199+
None
200+
}
201+
};
202+
Some(CoerceUnsizedInfo { field, src_ty, dst_ty })
203+
}
204+
}
205+
206+
/// Get information about an unsized coercion.
207+
/// This code was extracted from `rustc_monomorphize` crate.
208+
/// <https://github.com/rust-lang/rust/blob/4891d57f7aab37b5d6a84f2901c0bb8903111d53/compiler/rustc_monomorphize/src/lib.rs#L25-L46>
209+
fn custom_coerce_unsize_info<'tcx>(
210+
tcx: TyCtxt<'tcx>,
211+
source_ty: Ty<'tcx>,
212+
target_ty: Ty<'tcx>,
213+
) -> CustomCoerceUnsized {
214+
let def_id = tcx.require_lang_item(LangItem::CoerceUnsized, None);
215+
216+
let trait_ref = ty::Binder::dummy(TraitRef {
217+
def_id,
218+
substs: tcx.mk_substs_trait(source_ty, &[target_ty.into()]),
219+
});
220+
221+
match tcx.codegen_select_candidate((ParamEnv::reveal_all(), trait_ref)) {
222+
Ok(ImplSource::UserDefined(ImplSourceUserDefinedData { impl_def_id, .. })) => {
223+
tcx.coerce_unsized_info(impl_def_id).custom_kind.unwrap()
224+
}
225+
impl_source => {
226+
unreachable!("invalid `CoerceUnsized` impl_source: {:?}", impl_source);
227+
}
228+
}
229+
}
230+
231+
/// Extract pointee type from builtin pointer types.
232+
fn extract_pointee(typ: Ty) -> Option<Ty> {
233+
typ.builtin_deref(true).map(|TypeAndMut { ty, .. }| ty)
234+
}

kani-compiler/src/kani_middle/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
33
//! This module contains code that are backend agnostic. For example, MIR analysis
44
//! and transformations.
55
pub mod attributes;
6+
pub mod coercion;
67
pub mod reachability;

kani-compiler/src/kani_middle/reachability.rs

Lines changed: 11 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -13,27 +13,26 @@
1313
//! - For every static, collect initializer and drop functions.
1414
//!
1515
//! We have kept this module agnostic of any Kani code in case we can contribute this back to rustc.
16+
use tracing::{debug, debug_span, trace, warn};
17+
1618
use rustc_data_structures::fingerprint::Fingerprint;
1719
use rustc_data_structures::fx::FxHashSet;
1820
use rustc_data_structures::stable_hasher::{HashStable, StableHasher};
19-
use rustc_hir::lang_items::LangItem;
2021
use rustc_middle::mir::interpret::{AllocId, ConstValue, ErrorHandled, GlobalAlloc, Scalar};
2122
use rustc_middle::mir::mono::MonoItem;
2223
use rustc_middle::mir::visit::Visitor as MirVisitor;
2324
use rustc_middle::mir::{
2425
Body, CastKind, Constant, ConstantKind, Location, Rvalue, Terminator, TerminatorKind,
2526
};
2627
use rustc_middle::span_bug;
27-
use rustc_middle::traits::{ImplSource, ImplSourceUserDefinedData};
28-
use rustc_middle::ty::adjustment::CustomCoerceUnsized;
2928
use rustc_middle::ty::adjustment::PointerCast;
30-
use rustc_middle::ty::TypeAndMut;
3129
use rustc_middle::ty::{
32-
self, Closure, ClosureKind, ConstKind, Instance, InstanceDef, ParamEnv, TraitRef, Ty, TyCtxt,
33-
TyKind, TypeFoldable, VtblEntry,
30+
Closure, ClosureKind, ConstKind, Instance, InstanceDef, ParamEnv, Ty, TyCtxt, TyKind,
31+
TypeFoldable, VtblEntry,
3432
};
3533
use rustc_span::def_id::DefId;
36-
use tracing::{debug, debug_span, trace, warn};
34+
35+
use crate::kani_middle::coercion;
3736

3837
/// Collect all reachable items starting from the given starting points.
3938
pub fn collect_reachable_items<'tcx>(
@@ -275,10 +274,11 @@ impl<'a, 'tcx> MirVisitor<'tcx> for MonoItemsFnCollector<'a, 'tcx> {
275274
// If so, collect items from the impl `Trait for Concrete {}`.
276275
let target_ty = self.monomorphize(target);
277276
let source_ty = self.monomorphize(operand.ty(self.body, self.tcx));
278-
let (src_inner, dst_inner) = extract_trait_casting(self.tcx, source_ty, target_ty);
279-
if !src_inner.is_trait() && dst_inner.is_trait() {
280-
debug!(concrete_ty=?src_inner, trait_ty=?dst_inner, "collect_vtable_methods");
281-
self.collect_vtable_methods(src_inner, dst_inner);
277+
let base_coercion =
278+
coercion::extract_unsize_casting(self.tcx, source_ty, target_ty);
279+
if !base_coercion.src_ty.is_trait() && base_coercion.dst_ty.is_trait() {
280+
debug!(?base_coercion, "collect_vtable_methods");
281+
self.collect_vtable_methods(base_coercion.src_ty, base_coercion.dst_ty);
282282
}
283283
}
284284
Rvalue::Cast(CastKind::Pointer(PointerCast::ReifyFnPointer), ref operand, _) => {
@@ -449,107 +449,6 @@ fn should_codegen_locally<'tcx>(tcx: TyCtxt<'tcx>, instance: &Instance<'tcx>) ->
449449
}
450450
}
451451

452-
/// Extract the pair (`T`, `U`) for a unsized coercion where type `T` implements `Unsize<U>`.
453-
/// I.e., `U` is either a trait or a slice.
454-
/// For more details, please refer to:
455-
/// <https://doc.rust-lang.org/reference/type-coercions.html#unsized-coercions>
456-
///
457-
/// This is used to determine the vtable implementation that must be tracked by the fat pointer.
458-
///
459-
/// For example, if `&u8` is being converted to `&dyn Debug`, this method would return:
460-
/// `(u8, dyn Debug)`.
461-
///
462-
/// There are a few interesting cases (references / pointers are handled the same way):
463-
/// 1. Coercion between `&T` to `&U`.
464-
/// - This is the base case.
465-
/// - In this case, we extract the type that is pointed to.
466-
/// 2. Coercion between smart pointers like `Rc<T>` to `Rc<U>`.
467-
/// - Smart pointers implement the `CoerceUnsize` trait.
468-
/// - Use CustomCoerceUnsized information to traverse the smart pointer structure and find the
469-
/// underlying pointer.
470-
/// - Use base case to extract `T` and `U`.
471-
/// 3. Coercion between `&Wrapper<T>` to `&Wrapper<U>`.
472-
/// - Apply base case to extract the pair `(Wrapper<T>, Wrapper<U>)`.
473-
/// - Extract the tail element of the struct which are of type `T` and `U`, respectively.
474-
/// 4. Coercion between smart pointers of wrapper structs.
475-
/// - Apply the logic from item 2 then item 3.
476-
fn extract_trait_casting<'tcx>(
477-
tcx: TyCtxt<'tcx>,
478-
src_ty: Ty<'tcx>,
479-
dst_ty: Ty<'tcx>,
480-
) -> (Ty<'tcx>, Ty<'tcx>) {
481-
trace!(?dst_ty, ?src_ty, "find_trait_conversion");
482-
let mut src_inner_ty = src_ty;
483-
let mut dst_inner_ty = dst_ty;
484-
(src_inner_ty, dst_inner_ty) = loop {
485-
// Extract the pointee types from pointers (including smart pointers) that form the base of
486-
// the conversion.
487-
match (&src_inner_ty.kind(), &dst_inner_ty.kind()) {
488-
(&ty::Adt(src_def, src_substs), &ty::Adt(dst_def, dst_substs))
489-
if !src_def.is_box() || !dst_def.is_box() =>
490-
{
491-
// Handle smart pointers by using CustomCoerceUnsized to find the field being
492-
// coerced.
493-
assert_eq!(src_def, dst_def);
494-
let src_fields = &src_def.non_enum_variant().fields;
495-
let dst_fields = &dst_def.non_enum_variant().fields;
496-
assert_eq!(src_fields.len(), dst_fields.len());
497-
498-
let CustomCoerceUnsized::Struct(coerce_index) =
499-
custom_coerce_unsize_info(tcx, src_inner_ty, dst_inner_ty);
500-
assert!(coerce_index < src_fields.len());
501-
502-
src_inner_ty = src_fields[coerce_index].ty(tcx, src_substs);
503-
dst_inner_ty = dst_fields[coerce_index].ty(tcx, dst_substs);
504-
}
505-
_ => {
506-
// Base case is always a pointer (Box, raw_pointer or reference).
507-
let src_pointee = extract_pointee(src_inner_ty).expect(&format!(
508-
"Expected source to be a pointer. Found {:?} instead",
509-
src_inner_ty
510-
));
511-
let dst_pointee = extract_pointee(dst_inner_ty).expect(&format!(
512-
"Expected destination to be a pointer. Found {:?} instead",
513-
dst_inner_ty
514-
));
515-
break (src_pointee, dst_pointee);
516-
}
517-
}
518-
};
519-
520-
tcx.struct_lockstep_tails_erasing_lifetimes(src_inner_ty, dst_inner_ty, ParamEnv::reveal_all())
521-
}
522-
523-
/// Extract pointee type from builtin pointer types.
524-
fn extract_pointee(typ: Ty) -> Option<Ty> {
525-
typ.builtin_deref(true).map(|TypeAndMut { ty, .. }| ty)
526-
}
527-
528-
/// Get information about an unsized coercion.
529-
/// This code was extracted from `rustc_monomorphize` crate.
530-
/// <https://github.com/rust-lang/rust/blob/4891d57f7aab37b5d6a84f2901c0bb8903111d53/compiler/rustc_monomorphize/src/lib.rs#L25-L46>
531-
fn custom_coerce_unsize_info<'tcx>(
532-
tcx: TyCtxt<'tcx>,
533-
source_ty: Ty<'tcx>,
534-
target_ty: Ty<'tcx>,
535-
) -> CustomCoerceUnsized {
536-
let def_id = tcx.require_lang_item(LangItem::CoerceUnsized, None);
537-
538-
let trait_ref = ty::Binder::dummy(TraitRef {
539-
def_id,
540-
substs: tcx.mk_substs_trait(source_ty, &[target_ty.into()]),
541-
});
542-
543-
match tcx.codegen_select_candidate((ParamEnv::reveal_all(), trait_ref)) {
544-
Ok(ImplSource::UserDefined(ImplSourceUserDefinedData { impl_def_id, .. })) => {
545-
tcx.coerce_unsized_info(impl_def_id).custom_kind.unwrap()
546-
}
547-
impl_source => {
548-
unreachable!("invalid `CoerceUnsized` impl_source: {:?}", impl_source);
549-
}
550-
}
551-
}
552-
553452
/// Scans the allocation type and collect static objects.
554453
fn collect_alloc_items<'tcx>(tcx: TyCtxt<'tcx>, alloc_id: AllocId) -> Vec<MonoItem> {
555454
trace!(alloc=?tcx.global_alloc(alloc_id), ?alloc_id, "collect_alloc_items");

0 commit comments

Comments
 (0)