|
12 | 12 | #![allow(rustc::usage_of_ty_tykind)]
|
13 | 13 | #![allow(unused_imports)]
|
14 | 14 |
|
15 |
| -use rustc_ast::expand::typetree::{Type, Kind, TypeTree, FncTree}; |
16 | 15 | use rustc_target::abi::FieldsShape;
|
17 | 16 |
|
18 | 17 | pub use self::fold::{FallibleTypeFolder, TypeFoldable, TypeFolder, TypeSuperFoldable};
|
@@ -75,6 +74,7 @@ pub use rustc_type_ir::ConstKind::{
|
75 | 74 | };
|
76 | 75 | pub use rustc_type_ir::*;
|
77 | 76 |
|
| 77 | +pub use self::typetree::*; |
78 | 78 | pub use self::binding::BindingMode;
|
79 | 79 | pub use self::binding::BindingMode::*;
|
80 | 80 | pub use self::closure::{
|
@@ -127,6 +127,7 @@ pub mod util;
|
127 | 127 | pub mod visit;
|
128 | 128 | pub mod vtable;
|
129 | 129 | pub mod walk;
|
| 130 | +pub mod typetree; |
130 | 131 |
|
131 | 132 | mod adt;
|
132 | 133 | mod assoc;
|
@@ -2721,306 +2722,3 @@ mod size_asserts {
|
2721 | 2722 | // tidy-alphabetical-end
|
2722 | 2723 | }
|
2723 | 2724 |
|
2724 |
| -pub fn typetree_from<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree { |
2725 |
| - let mut visited = vec![]; |
2726 |
| - let ty = typetree_from_ty(ty, tcx, 0, false, &mut visited, None); |
2727 |
| - let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child: ty }; |
2728 |
| - return TypeTree(vec![tt]); |
2729 |
| -} |
2730 |
| - |
2731 |
| -use rustc_ast::expand::autodiff_attrs::DiffActivity; |
2732 |
| - |
2733 |
| -// This function combines three tasks. To avoid traversing each type 3x, we combine them. |
2734 |
| -// 1. Create a TypeTree from a Ty. This is the main task. |
2735 |
| -// 2. IFF da is not empty, we also want to adjust DiffActivity to account for future MIR->LLVM |
2736 |
| -// lowering. E.g. fat ptr are going to introduce an extra int. |
2737 |
| -// 3. IFF da is not empty, we are creating TT for a function directly differentiated (has an |
2738 |
| -// autodiff macro on top). Here we want to make sure that shadows are mutable internally. |
2739 |
| -// We know the outermost ref/ptr indirection is mutability - we generate it like that. |
2740 |
| -// We now have to make sure that inner ptr/ref are mutable too, or issue a warning. |
2741 |
| -// Not an error, becaues it only causes issues if they are actually read, which we don't check |
2742 |
| -// yet. We should add such analysis to relibably either issue an error or accept without warning. |
2743 |
| -// If there only were some reasearch to do that... |
2744 |
| -pub fn fnc_typetrees<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<DiffActivity>, span: Option<Span>) -> FncTree { |
2745 |
| - if !fn_ty.is_fn() { |
2746 |
| - return FncTree { args: vec![], ret: TypeTree::new() }; |
2747 |
| - } |
2748 |
| - let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx); |
2749 |
| - |
2750 |
| - // If rustc compiles the unmodified primal, we know that this copy of the function |
2751 |
| - // also has correct lifetimes. We know that Enzyme won't free the shadow too early |
2752 |
| - // (or actually at all), so let's strip lifetimes when computing the layout. |
2753 |
| - // Recommended by compiler-errors: |
2754 |
| - // https://discord.com/channels/273534239310479360/957720175619215380/1223454360676208751 |
2755 |
| - let x = tcx.instantiate_bound_regions_with_erased(fnc_binder); |
2756 |
| - |
2757 |
| - let mut new_activities = vec![]; |
2758 |
| - let mut new_positions = vec![]; |
2759 |
| - let mut visited = vec![]; |
2760 |
| - let mut args = vec![]; |
2761 |
| - for (i, ty) in x.inputs().iter().enumerate() { |
2762 |
| - // We care about safety checks, if an argument get's duplicated and we write into the |
2763 |
| - // shadow. That's equivalent to Duplicated or DuplicatedOnly. |
2764 |
| - let safety = if !da.is_empty() { |
2765 |
| - assert!(da.len() == x.inputs().len(), "{:?} != {:?}", da.len(), x.inputs().len()); |
2766 |
| - // If we have Activities, we also have spans |
2767 |
| - assert!(span.is_some()); |
2768 |
| - match da[i] { |
2769 |
| - DiffActivity::DuplicatedOnly | DiffActivity::Duplicated => true, |
2770 |
| - _ => false, |
2771 |
| - } |
2772 |
| - } else { |
2773 |
| - false |
2774 |
| - }; |
2775 |
| - |
2776 |
| - visited.clear(); |
2777 |
| - if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() { |
2778 |
| - if ty.is_fn_ptr() { |
2779 |
| - unimplemented!("what to do whith fn ptr?"); |
2780 |
| - } |
2781 |
| - let inner_ty = ty.builtin_deref(true).unwrap().ty; |
2782 |
| - if inner_ty.is_slice() { |
2783 |
| - // We know that the lenght will be passed as extra arg. |
2784 |
| - let child = typetree_from_ty(inner_ty, tcx, 1, safety, &mut visited, span); |
2785 |
| - let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child }; |
2786 |
| - args.push(TypeTree(vec![tt])); |
2787 |
| - let i64_tt = Type { offset: -1, kind: Kind::Integer, size: 8, child: TypeTree::new() }; |
2788 |
| - args.push(TypeTree(vec![i64_tt])); |
2789 |
| - if !da.is_empty() { |
2790 |
| - // We are looking at a slice. The length of that slice will become an |
2791 |
| - // extra integer on llvm level. Integers are always const. |
2792 |
| - // However, if the slice get's duplicated, we want to know to later check the |
2793 |
| - // size. So we mark the new size argument as FakeActivitySize. |
2794 |
| - let activity = match da[i] { |
2795 |
| - DiffActivity::DualOnly | DiffActivity::Dual | |
2796 |
| - DiffActivity::DuplicatedOnly | DiffActivity::Duplicated |
2797 |
| - => DiffActivity::FakeActivitySize, |
2798 |
| - DiffActivity::Const => DiffActivity::Const, |
2799 |
| - _ => panic!("unexpected activity for ptr/ref"), |
2800 |
| - }; |
2801 |
| - new_activities.push(activity); |
2802 |
| - new_positions.push(i + 1); |
2803 |
| - } |
2804 |
| - trace!("ABI MATCHING!"); |
2805 |
| - continue; |
2806 |
| - } |
2807 |
| - } |
2808 |
| - let arg_tt = typetree_from_ty(*ty, tcx, 0, safety, &mut visited, span); |
2809 |
| - args.push(arg_tt); |
2810 |
| - } |
2811 |
| - |
2812 |
| - // now add the extra activities coming from slices |
2813 |
| - // Reverse order to not invalidate the indices |
2814 |
| - for _ in 0..new_activities.len() { |
2815 |
| - let pos = new_positions.pop().unwrap(); |
2816 |
| - let activity = new_activities.pop().unwrap(); |
2817 |
| - da.insert(pos, activity); |
2818 |
| - } |
2819 |
| - |
2820 |
| - visited.clear(); |
2821 |
| - let ret = typetree_from_ty(x.output(), tcx, 0, false, &mut visited, span); |
2822 |
| - |
2823 |
| - FncTree { args, ret } |
2824 |
| -} |
2825 |
| - |
2826 |
| -fn typetree_from_ty<'a>(ty: Ty<'a>, tcx: TyCtxt<'a>, depth: usize, safety: bool, visited: &mut Vec<Ty<'a>>, span: Option<Span>) -> TypeTree { |
2827 |
| - if depth > 20 { |
2828 |
| - trace!("depth > 20 for ty: {}", &ty); |
2829 |
| - } |
2830 |
| - if visited.contains(&ty) { |
2831 |
| - // recursive type |
2832 |
| - trace!("recursive type: {}", &ty); |
2833 |
| - return TypeTree::new(); |
2834 |
| - } |
2835 |
| - visited.push(ty); |
2836 |
| - |
2837 |
| - if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() { |
2838 |
| - if ty.is_fn_ptr() { |
2839 |
| - unimplemented!("what to do whith fn ptr?"); |
2840 |
| - } |
2841 |
| - |
2842 |
| - let inner_ty_and_mut = ty.builtin_deref(true).unwrap(); |
2843 |
| - let is_mut = inner_ty_and_mut.mutbl == hir::Mutability::Mut; |
2844 |
| - let inner_ty = inner_ty_and_mut.ty; |
2845 |
| - |
2846 |
| - // Now account for inner mutability. |
2847 |
| - if !is_mut && depth > 0 && safety { |
2848 |
| - let ptr_ty: String = if ty.is_ref() { |
2849 |
| - "ref" |
2850 |
| - } else if ty.is_unsafe_ptr() { |
2851 |
| - "ptr" |
2852 |
| - } else { |
2853 |
| - assert!(ty.is_box()); |
2854 |
| - "box" |
2855 |
| - }.to_string(); |
2856 |
| - |
2857 |
| - // If we have mutability, we also have a span |
2858 |
| - assert!(span.is_some()); |
2859 |
| - let span = span.unwrap(); |
2860 |
| - |
2861 |
| - tcx.sess |
2862 |
| - .dcx() |
2863 |
| - .emit_warning(AutodiffUnsafeInnerConstRef{span, ty: ptr_ty}); |
2864 |
| - } |
2865 |
| - |
2866 |
| - //visited.push(inner_ty); |
2867 |
| - let child = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited, span); |
2868 |
| - let tt = Type { offset: -1, kind: Kind::Pointer, size: 8, child }; |
2869 |
| - visited.pop(); |
2870 |
| - return TypeTree(vec![tt]); |
2871 |
| - } |
2872 |
| - |
2873 |
| - |
2874 |
| - if ty.is_closure() || ty.is_coroutine() || ty.is_fresh() || ty.is_fn() { |
2875 |
| - visited.pop(); |
2876 |
| - return TypeTree::new(); |
2877 |
| - } |
2878 |
| - |
2879 |
| - if ty.is_scalar() { |
2880 |
| - let (kind, size) = if ty.is_integral() || ty.is_char() || ty.is_bool() { |
2881 |
| - (Kind::Integer, ty.primitive_size(tcx).bytes_usize()) |
2882 |
| - } else if ty.is_floating_point() { |
2883 |
| - match ty { |
2884 |
| - x if x == tcx.types.f32 => (Kind::Float, 4), |
2885 |
| - x if x == tcx.types.f64 => (Kind::Double, 8), |
2886 |
| - _ => panic!("floatTy scalar that is neither f32 nor f64"), |
2887 |
| - } |
2888 |
| - } else { |
2889 |
| - panic!("scalar that is neither integral nor floating point"); |
2890 |
| - }; |
2891 |
| - visited.pop(); |
2892 |
| - return TypeTree(vec![Type { offset: -1, child: TypeTree::new(), kind, size }]); |
2893 |
| - } |
2894 |
| - |
2895 |
| - let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: ty }; |
2896 |
| - |
2897 |
| - let layout = tcx.layout_of(param_env_and); |
2898 |
| - assert!(layout.is_ok()); |
2899 |
| - |
2900 |
| - let layout = layout.unwrap().layout; |
2901 |
| - let fields = layout.fields(); |
2902 |
| - let max_size = layout.size(); |
2903 |
| - |
2904 |
| - |
2905 |
| - |
2906 |
| - if ty.is_adt() && !ty.is_simd() { |
2907 |
| - let adt_def = ty.ty_adt_def().unwrap(); |
2908 |
| - |
2909 |
| - if adt_def.is_struct() { |
2910 |
| - let (offsets, _memory_index) = match fields { |
2911 |
| - // Manuel TODO: |
2912 |
| - FieldsShape::Arbitrary { offsets: o, memory_index: m } => (o, m), |
2913 |
| - FieldsShape::Array { .. } => {return TypeTree::new();}, //e.g. core::arch::x86_64::__m128i, TODO: later |
2914 |
| - FieldsShape::Union(_) => {return TypeTree::new();}, |
2915 |
| - FieldsShape::Primitive => {return TypeTree::new();}, |
2916 |
| - }; |
2917 |
| - |
2918 |
| - let substs = match ty.kind() { |
2919 |
| - Adt(_, subst_ref) => subst_ref, |
2920 |
| - _ => panic!(""), |
2921 |
| - }; |
2922 |
| - |
2923 |
| - let fields = adt_def.all_fields(); |
2924 |
| - let fields = fields |
2925 |
| - .into_iter() |
2926 |
| - .zip(offsets.into_iter()) |
2927 |
| - .filter_map(|(field, offset)| { |
2928 |
| - let field_ty: Ty<'_> = field.ty(tcx, substs); |
2929 |
| - let field_ty: Ty<'_> = |
2930 |
| - tcx.normalize_erasing_regions(ParamEnv::empty(), field_ty); |
2931 |
| - |
2932 |
| - if field_ty.is_phantom_data() { |
2933 |
| - return None; |
2934 |
| - } |
2935 |
| - |
2936 |
| - //visited.push(field_ty); |
2937 |
| - let mut child = typetree_from_ty(field_ty, tcx, depth + 1, safety, visited, span).0; |
2938 |
| - |
2939 |
| - for c in &mut child { |
2940 |
| - if c.offset == -1 { |
2941 |
| - c.offset = offset.bytes() as isize |
2942 |
| - } else { |
2943 |
| - c.offset += offset.bytes() as isize; |
2944 |
| - } |
2945 |
| - } |
2946 |
| - |
2947 |
| - Some(child) |
2948 |
| - }) |
2949 |
| - .flatten() |
2950 |
| - .collect::<Vec<Type>>(); |
2951 |
| - |
2952 |
| - visited.pop(); |
2953 |
| - let ret_tt = TypeTree(fields); |
2954 |
| - return ret_tt; |
2955 |
| - } else if adt_def.is_enum() { |
2956 |
| - // Enzyme can't represent enums, so let it figure it out itself, without seeeding |
2957 |
| - // typetree |
2958 |
| - //unimplemented!("adt that is an enum"); |
2959 |
| - } else { |
2960 |
| - //let ty_name = tcx.def_path_debug_str(adt_def.did()); |
2961 |
| - //tcx.sess.emit_fatal(UnsupportedUnion { ty_name }); |
2962 |
| - } |
2963 |
| - } |
2964 |
| - |
2965 |
| - if ty.is_simd() { |
2966 |
| - trace!("simd"); |
2967 |
| - let (_size, inner_ty) = ty.simd_size_and_type(tcx); |
2968 |
| - //visited.push(inner_ty); |
2969 |
| - let _sub_tt = typetree_from_ty(inner_ty, tcx, depth + 1, safety, visited, span); |
2970 |
| - //let tt = TypeTree( |
2971 |
| - // std::iter::repeat(subtt) |
2972 |
| - // .take(*count as usize) |
2973 |
| - // .enumerate() |
2974 |
| - // .map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize))) |
2975 |
| - // .flatten() |
2976 |
| - // .collect(), |
2977 |
| - //); |
2978 |
| - // TODO |
2979 |
| - visited.pop(); |
2980 |
| - return TypeTree::new(); |
2981 |
| - } |
2982 |
| - |
2983 |
| - if ty.is_array() { |
2984 |
| - let (stride, count) = match fields { |
2985 |
| - FieldsShape::Array { stride: s, count: c } => (s, c), |
2986 |
| - _ => panic!(""), |
2987 |
| - }; |
2988 |
| - let byte_stride = stride.bytes_usize(); |
2989 |
| - let byte_max_size = max_size.bytes_usize(); |
2990 |
| - |
2991 |
| - assert!(byte_stride * *count as usize == byte_max_size); |
2992 |
| - if (*count as usize) == 0 { |
2993 |
| - return TypeTree::new(); |
2994 |
| - } |
2995 |
| - let sub_ty = ty.builtin_index().unwrap(); |
2996 |
| - //visited.push(sub_ty); |
2997 |
| - let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited, span); |
2998 |
| - |
2999 |
| - // calculate size of subtree |
3000 |
| - let param_env_and = ParamEnvAnd { param_env: ParamEnv::empty(), value: sub_ty }; |
3001 |
| - let size = tcx.layout_of(param_env_and).unwrap().size.bytes() as usize; |
3002 |
| - let tt = TypeTree( |
3003 |
| - std::iter::repeat(subtt) |
3004 |
| - .take(*count as usize) |
3005 |
| - .enumerate() |
3006 |
| - .map(|(idx, x)| x.0.into_iter().map(move |x| x.add_offset((idx * size) as isize))) |
3007 |
| - .flatten() |
3008 |
| - .collect(), |
3009 |
| - ); |
3010 |
| - |
3011 |
| - visited.pop(); |
3012 |
| - return tt; |
3013 |
| - } |
3014 |
| - |
3015 |
| - if ty.is_slice() { |
3016 |
| - let sub_ty = ty.builtin_index().unwrap(); |
3017 |
| - //visited.push(sub_ty); |
3018 |
| - let subtt = typetree_from_ty(sub_ty, tcx, depth + 1, safety, visited, span); |
3019 |
| - |
3020 |
| - visited.pop(); |
3021 |
| - return subtt; |
3022 |
| - } |
3023 |
| - |
3024 |
| - visited.pop(); |
3025 |
| - TypeTree::new() |
3026 |
| -} |
0 commit comments