Skip to content

Commit b9568c9

Browse files
committed
Add linking support for x86amx and bf16 (only scalar and vector)
1 parent af06ce3 commit b9568c9

File tree

2 files changed

+107
-12
lines changed

2 files changed

+107
-12
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::iter::zip;
55
use libc::c_uint;
66
use rustc_abi::{BackendRepr, HasDataLayout, Primitive, Reg, RegKind, Size};
77
use rustc_codegen_ssa::MemFlags;
8+
use rustc_codegen_ssa::common::TypeKind;
89
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
910
use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue};
1011
use rustc_codegen_ssa::traits::*;
@@ -327,6 +328,47 @@ pub(crate) trait FnAbiLlvmExt<'ll, 'tcx> {
327328
fn apply_attrs_callsite(&self, bx: &mut Builder<'_, 'll, 'tcx>, callsite: &'ll Value);
328329
}
329330

331+
fn equate_ty<'ll>(cx: &CodegenCx<'ll, '_>, rust_ty: &'ll Type, llvm_ty: &'ll Type) -> bool {
332+
if rust_ty == llvm_ty {
333+
return true;
334+
}
335+
match cx.type_kind(llvm_ty) {
336+
TypeKind::X86_AMX => {
337+
// we will insert casts from/to x86amx in callsite, so this is fine
338+
if cx.type_kind(rust_ty) == TypeKind::Vector {
339+
let element_count = cx.vector_length(rust_ty);
340+
let element_ty = cx.element_type(rust_ty);
341+
let element_size_bits = match cx.type_kind(element_ty) {
342+
TypeKind::Half => 16,
343+
TypeKind::Float => 32,
344+
TypeKind::Double => 64,
345+
TypeKind::FP128 => 128,
346+
TypeKind::Integer => cx.int_width(element_ty),
347+
TypeKind::Pointer => cx.int_width(cx.isize_ty),
348+
_ => bug!(
349+
"Vector element type `{element_ty:?}` not one of integer, float or pointer"
350+
),
351+
};
352+
element_size_bits * element_count as u64 == 8192
353+
} else {
354+
false
355+
}
356+
}
357+
TypeKind::BFloat => rust_ty == cx.type_i16(),
358+
TypeKind::Vector => {
359+
let element_count = cx.vector_length(llvm_ty);
360+
let element_ty = cx.element_type(llvm_ty);
361+
362+
if element_ty == cx.type_bf16() {
363+
rust_ty == cx.type_vector(cx.type_i16(), element_count as u64)
364+
} else {
365+
false
366+
}
367+
}
368+
_ => false,
369+
}
370+
}
371+
330372
impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
331373
fn llvm_return_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type {
332374
match &self.ret.mode {
@@ -419,7 +461,7 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
419461
));
420462
}
421463

422-
if actual_return_ty != expected_return_ty {
464+
if !equate_ty(cx, actual_return_ty, expected_return_ty) {
423465
cx.tcx.dcx().fatal(format!(
424466
"Intrinsic signature mismatch: expected {expected_return_ty:?} as return type for `{}`, found {actual_return_ty:?}",
425467
str::from_utf8(name).unwrap()
@@ -428,7 +470,7 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
428470
for (idx, (actual_argument_ty, expected_argument_ty)) in
429471
zip(actual_argument_tys, expected_argument_tys).enumerate()
430472
{
431-
if actual_argument_ty != expected_argument_ty {
473+
if !equate_ty(cx, actual_argument_ty, expected_argument_ty) {
432474
cx.tcx.dcx().fatal(format!(
433475
"Intrinsic signature mismatch: expected {expected_argument_ty:?} as argument {idx} for `{}`, found {actual_argument_ty:?}",
434476
str::from_utf8(name).unwrap()

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 63 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ impl<'a, 'll> SBuilder<'a, 'll> {
6767
) -> &'ll Value {
6868
debug!("call {:?} with args ({:?})", llfn, args);
6969

70-
let args = self.check_call("call", llty, llfn, args);
70+
let args = self.cast_arguments("call", llty, llfn, args);
7171
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
7272
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
7373
if let Some(funclet_bundle) = funclet_bundle {
@@ -101,6 +101,24 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
101101
unsafe { llvm::LLVMBuildBitCast(self.llbuilder, val, dest_ty, UNNAMED) }
102102
}
103103

104+
/// Turns the vector value `val` into a tile value by calling the
/// `llvm.x86.cast.vector.to.tile` intrinsic, overloaded on the type of `val`.
///
/// # Panics
///
/// Panics if the type of `val` is not a vector type.
pub(crate) fn cast_vector_to_tile(&mut self, val: &'ll Value) -> &'ll Value {
    let vector_type = self.cx.val_ty(val);
    assert!(self.cx.type_kind(vector_type) == TypeKind::Vector);

    // The intrinsic is overloaded on the source vector type.
    let overload_tys = [vector_type];
    let operands = [val];
    self.call_intrinsic("llvm.x86.cast.vector.to.tile", &overload_tys, &operands)
}
110+
111+
pub(crate) fn cast_tile_to_vector(
112+
&mut self,
113+
val: &'ll Value,
114+
vector_type: &'ll Type,
115+
) -> &'ll Value {
116+
assert!(self.cx.val_ty(val) == self.cx.type_x86amx());
117+
assert!(self.cx.type_kind(vector_type) == TypeKind::Vector);
118+
119+
self.call_intrinsic("llvm.x86.cast.tile.to.vector", &[vector_type], &[val])
120+
}
121+
104122
pub(crate) fn ret_void(&mut self) {
105123
llvm::LLVMBuildRetVoid(self.llbuilder);
106124
}
@@ -349,7 +367,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
349367
) -> &'ll Value {
350368
debug!("invoke {:?} with args ({:?})", llfn, args);
351369

352-
let args = self.check_call("invoke", llty, llfn, args);
370+
let args = self.cast_arguments("invoke", llty, llfn, args);
353371
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
354372
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
355373
if let Some(funclet_bundle) = funclet_bundle {
@@ -381,8 +399,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
381399
};
382400
if let Some(fn_abi) = fn_abi {
383401
fn_abi.apply_attrs_callsite(self, invoke);
402+
self.cast_return(fn_abi, llfn, invoke)
403+
} else {
404+
invoke
384405
}
385-
invoke
386406
}
387407

388408
fn unreachable(&mut self) {
@@ -1348,7 +1368,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
13481368
) -> &'ll Value {
13491369
debug!("call {:?} with args ({:?})", llfn, args);
13501370

1351-
let args = self.check_call("call", llty, llfn, args);
1371+
let args = self.cast_arguments("call", llty, llfn, args);
13521372
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
13531373
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
13541374
if let Some(funclet_bundle) = funclet_bundle {
@@ -1378,8 +1398,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
13781398
};
13791399
if let Some(fn_abi) = fn_abi {
13801400
fn_abi.apply_attrs_callsite(self, call);
1401+
self.cast_return(fn_abi, llfn, call)
1402+
} else {
1403+
call
13811404
}
1382-
call
13831405
}
13841406

13851407
fn zext(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
@@ -1540,7 +1562,7 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
15401562
ret.expect("LLVM does not have support for catchret")
15411563
}
15421564

1543-
fn check_call<'b>(
1565+
fn cast_arguments<'b>(
15441566
&mut self,
15451567
typ: &str,
15461568
fn_ty: &'ll Type,
@@ -1571,7 +1593,11 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
15711593
Expected {:?} for param {}, got {:?}; injecting bitcast",
15721594
llfn, expected_ty, i, actual_ty
15731595
);
1574-
self.bitcast(actual_val, expected_ty)
1596+
if self.cx.type_kind(expected_ty) == TypeKind::X86_AMX {
1597+
self.cast_vector_to_tile(actual_val)
1598+
} else {
1599+
self.bitcast(actual_val, expected_ty)
1600+
}
15751601
} else {
15761602
actual_val
15771603
}
@@ -1591,7 +1617,7 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
15911617
llfn: &'ll Value,
15921618
args: &[&'ll Value],
15931619
) -> &'ll Value {
1594-
let args = self.check_call("simple call", fn_ty, llfn, args);
1620+
let args = self.cast_arguments("simple call", fn_ty, llfn, args);
15951621

15961622
unsafe {
15971623
llvm::LLVMBuildCall2(
@@ -1692,6 +1718,31 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
16921718
self.simple_call(fn_ty, llfn, &[ptr1, ptr2, num])
16931719
}
16941720

1721+
fn cast_return(
1722+
&mut self,
1723+
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
1724+
llfn: &'ll Value,
1725+
ret: &'ll Value,
1726+
) -> &'ll Value {
1727+
let expected_ty = fn_abi.llvm_return_type(self.cx);
1728+
let actual_ty = self.cx.val_ty(ret);
1729+
1730+
if expected_ty != actual_ty {
1731+
debug!(
1732+
"type mismatch in function call of {:?}. \
1733+
Expected {:?} for return value, got {:?}; injecting bitcast",
1734+
llfn, expected_ty, actual_ty
1735+
);
1736+
if self.cx.type_kind(actual_ty) == TypeKind::X86_AMX {
1737+
self.cast_tile_to_vector(ret, expected_ty)
1738+
} else {
1739+
self.bitcast(ret, expected_ty)
1740+
}
1741+
} else {
1742+
ret
1743+
}
1744+
}
1745+
16951746
pub(crate) fn landing_pad(
16961747
&mut self,
16971748
ty: &'ll Type,
@@ -1721,7 +1772,7 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17211772
) -> &'ll Value {
17221773
debug!("invoke {:?} with args ({:?})", llfn, args);
17231774

1724-
let args = self.check_call("callbr", llty, llfn, args);
1775+
let args = self.cast_arguments("callbr", llty, llfn, args);
17251776
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
17261777
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
17271778
if let Some(funclet_bundle) = funclet_bundle {
@@ -1754,8 +1805,10 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17541805
};
17551806
if let Some(fn_abi) = fn_abi {
17561807
fn_abi.apply_attrs_callsite(self, callbr);
1808+
self.cast_return(fn_abi, llfn, callbr)
1809+
} else {
1810+
callbr
17571811
}
1758-
callbr
17591812
}
17601813

17611814
// Emits CFI pointer type membership tests.

0 commit comments

Comments
 (0)