Skip to content

Commit d5d7204

Browse files
authored
refactor for easier maintenance (#160)
1 parent 1f83693 commit d5d7204

File tree

2 files changed

+317
-304
lines changed

2 files changed

+317
-304
lines changed

compiler/rustc_middle/src/ty/mod.rs

+2-304
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#![allow(rustc::usage_of_ty_tykind)]
1313
#![allow(unused_imports)]
1414

15-
use rustc_ast::expand::typetree::{Type, Kind, TypeTree, FncTree};
1615
use rustc_target::abi::FieldsShape;
1716

1817
pub use self::fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable};
@@ -75,6 +74,7 @@ pub use rustc_type_ir::ConstKind::{
7574
};
7675
pub use rustc_type_ir::*;
7776

77+
pub use self::typetree::*;
7878
pub use self::binding::BindingMode;
7979
pub use self::binding::BindingMode::*;
8080
pub use self::closure::{
@@ -127,6 +127,7 @@ pub mod util;
127127
pub mod visit;
128128
pub mod vtable;
129129
pub mod walk;
130+
pub mod typetree;
130131

131132
mod adt;
132133
mod assoc;
@@ -2721,306 +2722,3 @@ mod size_asserts {
27212722
// tidy-alphabetical-end
27222723
}
27232724

2724-
pub fn typetree_from<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
2725-
let mut visited = vec![];
2726-
let ty = typetree_from_ty(ty, tcx, 0, false, &mut visited, None);
2727-
let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child: ty };
2728-
return TypeTree(vec![tt]);
2729-
}
2730-
2731-
use rustc_ast::expand::autodiff_attrs::DiffActivity;
2732-
2733-
// This function combines three tasks. To avoid traversing each type 3x, we combine them.
2734-
// 1. Create a TypeTree from a Ty. This is the main task.
2735-
// 2. IFF da is not empty, we also want to adjust DiffActivity to account for future MIR->LLVM
2736-
// lowering. E.g. fat ptr are going to introduce an extra int.
2737-
// 3. IFF da is not empty, we are creating TT for a function directly differentiated (has an
2738-
// autodiff macro on top). Here we want to make sure that shadows are mutable internally.
2739-
// We know the outermost ref/ptr indirection is mutability - we generate it like that.
2740-
// We now have to make sure that inner ptr/ref are mutable too, or issue a warning.
2741-
// Not an error, because it only causes issues if they are actually read, which we don't check
2742-
// yet. We should add such analysis to reliably either issue an error or accept without warning.
2743-
// If only there were some research to do that...
2744-
pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<DiffActivity>, span: Option<Span>) -> FncTree {
2745-
if !fn_ty.is_fn() {
2746-
return FncTree { args: vec![], ret: TypeTree::new() };
2747-
}
2748-
let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx);
2749-
2750-
// If rustc compiles the unmodified primal, we know that this copy of the function
2751-
// also has correct lifetimes. We know that Enzyme won't free the shadow too early
2752-
// (or actually at all), so let's strip lifetimes when computing the layout.
2753-
// Recommended by compiler-errors:
2754-
// https://discord.com/channels/273534239310479360/957720175619215380/1223454360676208751
2755-
let x = tcx.instantiate_bound_regions_with_erased(fnc_binder);
2756-
2757-
let mut new_activities = vec![];
2758-
let mut new_positions = vec![];
2759-
let mut visited = vec![];
2760-
let mut args = vec![];
2761-
for (i, ty) in x.inputs().iter().enumerate() {
2762-
// We care about safety checks, if an argument gets duplicated and we write into the
2763-
// shadow. That's equivalent to Duplicated or DuplicatedOnly.
2764-
let safety = if !da.is_empty() {
2765-
assert!(da.len() == x.inputs().len(), "{:?} != {:?}", da.len(), x.inputs().len());
2766-
// If we have Activities, we also have spans
2767-
assert!(span.is_some());
2768-
match da[i] {
2769-
DiffActivity::DuplicatedOnly | DiffActivity::Duplicated => true,
2770-
_ => false,
2771-
}
2772-
} else {
2773-
false
2774-
};
2775-
2776-
visited.clear();
2777-
if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() {
2778-
if ty.is_fn_ptr() {
2779-
unimplemented!("what to do whith fn ptr?");
2780-
}
2781-
let inner_ty = ty.builtin_deref(true).unwrap().ty;
2782-
if inner_ty.is_slice() {
2783-
// We know that the length will be passed as extra arg.
2784-
let child = typetree_from_ty(inner_ty, tcx, 1, safety, &mut visited, span);
2785-
let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child };
2786-
args.push(TypeTree(vec![tt]));
2787-
let i64_tt = Type { offset: -1, kind: Kind::Integer, size: 8, child: TypeTree::new() };
2788-
args.push(TypeTree(vec![i64_tt]));
2789-
if !da.is_empty() {
2790-
// We are looking at a slice. The length of that slice will become an
2791-
// extra integer on llvm level. Integers are always const.
2792-
// However, if the slice get's duplicated, we want to know to later check the
2793-
// size. So we mark the new size argument as FakeActivitySize.
2794-
let activity = match da[i] {
2795-
DiffActivity::DualOnly | DiffActivity::Dual |
2796-
DiffActivity::DuplicatedOnly | DiffActivity::Duplicated
2797-
=> DiffActivity::FakeActivitySize,
2798-
DiffActivity::Const => DiffActivity::Const,
2799-
_ => panic!("unexpected activity for ptr/ref"),
2800-
};
2801-
new_activities.push(activity);
2802-
new_positions.push(i + 1);
2803-
}
2804-
trace!("ABI MATCHING!");
2805-
continue;
2806-
}
2807-
}
2808-
let arg_tt = typetree_from_ty(*ty, tcx, 0, safety, &mut visited, span);
2809-
args.push(arg_tt);
2810-
}
2811-
2812-
// now add the extra activities coming from slices
2813-
// Reverse order to not invalidate the indices
2814-
for _ in 0..new_activities.len() {
2815-
let pos = new_positions.pop().unwrap();
2816-
let activity = new_activities.pop().unwrap();
2817-
da.insert(pos, activity);
2818-
}
2819-
2820-
visited.clear();
2821-
let ret = typetree_from_ty(x.output(), tcx, 0, false, &mut visited, span);
2822-
2823-
FncTree { args, ret }
2824-
}
2825-
2826-
fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize, safety: bool, visited: &mut Vec<Ty<'a>>, span: Option<Span>) -> TypeTree {
2827-
if depth > 20 {
2828-
trace!("depth > 20 for ty: {}", &ty);
2829-
}
2830-
if visited.contains(&ty) {
2831-
// recursive type
2832-
trace!("recursive type: {}", &ty);
2833-
return TypeTree::new();
2834-
}
2835-
visited.push(ty);
2836-
2837-
if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() {
2838-
if ty.is_fn_ptr() {
2839-
unimplemented!("what to do whith fn ptr?");
2840-
}
2841-
2842-
let inner_ty_and_mut = ty.builtin_deref(true).unwrap();
2843-
let is_mut = inner_ty_and_mut.mutbl == hir::Mutability::Mut;
2844-
let inner_ty = inner_ty_and_mut.ty;
2845-
2846-
// Now account for inner mutability.
2847-
if !is_mut && depth > 0 && safety {
2848-
let ptr_ty: String = if ty.is_ref() {
2849-
"ref"
2850-
} else if ty.is_unsafe_ptr() {
2851-
"ptr"
2852-
} else {
2853-
assert!(ty.is_box());
2854-
"box"
2855-
}.to_string();
2856-
2857-
// If we have mutability, we also have a span
2858-
assert!(span.is_some());
2859-
let span = span.unwrap();
2860-
2861-
tcx.sess
2862-
.dcx()
2863-
.emit_warning(AutodiffUnsafeInnerConstRef{span, ty: ptr_ty});
2864-
}
2865-
2866-
//visited.push(inner_ty);
2867-
let child = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited, span);
2868-
let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child };
2869-
visited.pop();
2870-
return TypeTree(vec![tt]);
2871-
}
2872-
2873-
2874-
if ty.is_closure() || ty.is_coroutine() || ty.is_fresh() || ty.is_fn() {
2875-
visited.pop();
2876-
return TypeTree::new();
2877-
}
2878-
2879-
if ty.is_scalar() {
2880-
let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() {
2881-
(Kind::Integer, ty.primitive_size(tcx).bytes_usize())
2882-
} else if ty.is_floating_point() {
2883-
match ty {
2884-
x if x == tcx.types.f32 => (Kind::Float, 4),
2885-
x if x == tcx.types.f64 => (Kind::Double, 8),
2886-
_ => panic!("floatTy scalar that is neither f32 nor f64"),
2887-
}
2888-
} else {
2889-
panic!("scalar that is neither integral nor floating point");
2890-
};
2891-
visited.pop();
2892-
return TypeTree(vec![Type { offset: -1, child: TypeTree::new(), kind, size }]);
2893-
}
2894-
2895-
let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: ty };
2896-
2897-
let layout = tcx.layout_of(param_env_and);
2898-
assert!(layout.is_ok());
2899-
2900-
let layout = layout.unwrap().layout;
2901-
let fields = layout.fields();
2902-
let max_size = layout.size();
2903-
2904-
2905-
2906-
if ty.is_adt() && !ty.is_simd() {
2907-
let adt_def = ty.ty_adt_def().unwrap();
2908-
2909-
if adt_def.is_struct() {
2910-
let (offsets, _memory_index) = match fields {
2911-
// Manuel TODO:
2912-
FieldsShape::Arbitrary { offsets: o, memory_index: m } => (o, m),
2913-
FieldsShape::Array { .. } => {return TypeTree::new();}, //e.g. core::arch::x86_64::__m128i, TODO: later
2914-
FieldsShape::Union(_) => {return TypeTree::new();},
2915-
FieldsShape::Primitive => {return TypeTree::new();},
2916-
};
2917-
2918-
let substs = match ty.kind() {
2919-
Adt(_, subst_ref) => subst_ref,
2920-
_ => panic!(""),
2921-
};
2922-
2923-
let fields = adt_def.all_fields();
2924-
let fields = fields
2925-
.into_iter()
2926-
.zip(offsets.into_iter())
2927-
.filter_map(|(field, offset)| {
2928-
let field_ty: Ty<'_> = field.ty(tcx, substs);
2929-
let field_ty: Ty<'_> =
2930-
tcx.normalize_erasing_regions(ParamEnv::empty(), field_ty);
2931-
2932-
if field_ty.is_phantom_data() {
2933-
return None;
2934-
}
2935-
2936-
//visited.push(field_ty);
2937-
let mut child = typetree_from_ty(field_ty, tcx, depth + 1, safety, visited, span).0;
2938-
2939-
for c in &mut child {
2940-
if c.offset == -1 {
2941-
c.offset = offset.bytes() as isize
2942-
} else {
2943-
c.offset += offset.bytes() as isize;
2944-
}
2945-
}
2946-
2947-
Some(child)
2948-
})
2949-
.flatten()
2950-
.collect::<Vec<Type>>();
2951-
2952-
visited.pop();
2953-
let ret_tt = TypeTree(fields);
2954-
return ret_tt;
2955-
} else if adt_def.is_enum() {
2956-
// Enzyme can't represent enums, so let it figure it out itself, without seeding
2957-
// typetree
2958-
//unimplemented!("adt that is an enum");
2959-
} else {
2960-
//let ty_name = tcx.def_path_debug_str(adt_def.did());
2961-
//tcx.sess.emit_fatal(UnsupportedUnion { ty_name });
2962-
}
2963-
}
2964-
2965-
if ty.is_simd() {
2966-
trace!("simd");
2967-
let (_size, inner_ty) = ty.simd_size_and_type(tcx);
2968-
//visited.push(inner_ty);
2969-
let _sub_tt = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited, span);
2970-
//let tt = TypeTree(
2971-
// std::iter::repeat(subtt)
2972-
// .take(*count as usize)
2973-
// .enumerate()
2974-
// .map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize)))
2975-
// .flatten()
2976-
// .collect(),
2977-
//);
2978-
// TODO
2979-
visited.pop();
2980-
return TypeTree::new();
2981-
}
2982-
2983-
if ty.is_array() {
2984-
let (stride, count) = match fields {
2985-
FieldsShape::Array { stride: s, count: c } => (s, c),
2986-
_ => panic!(""),
2987-
};
2988-
let byte_stride = stride.bytes_usize();
2989-
let byte_max_size = max_size.bytes_usize();
2990-
2991-
assert!(byte_stride * *count as usize == byte_max_size);
2992-
if (*count as usize) == 0 {
2993-
return TypeTree::new();
2994-
}
2995-
let sub_ty = ty.builtin_index().unwrap();
2996-
//visited.push(sub_ty);
2997-
let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited, span);
2998-
2999-
// calculate size of subtree
3000-
let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: sub_ty };
3001-
let size = tcx.layout_of(param_env_and).unwrap().size.bytes() as usize;
3002-
let tt = TypeTree(
3003-
std::iter::repeat(subtt)
3004-
.take(*count as usize)
3005-
.enumerate()
3006-
.map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize)))
3007-
.flatten()
3008-
.collect(),
3009-
);
3010-
3011-
visited.pop();
3012-
return tt;
3013-
}
3014-
3015-
if ty.is_slice() {
3016-
let sub_ty = ty.builtin_index().unwrap();
3017-
//visited.push(sub_ty);
3018-
let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited, span);
3019-
3020-
visited.pop();
3021-
return subtt;
3022-
}
3023-
3024-
visited.pop();
3025-
TypeTree::new()
3026-
}

0 commit comments

Comments
 (0)