Skip to content

Commit a56c9dc

Browse files
committed
Add linking support for x86amx
1 parent 56772c6 commit a56c9dc

File tree

2 files changed

+111
-39
lines changed

2 files changed

+111
-39
lines changed

compiler/rustc_codegen_llvm/src/abi.rs

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
use std::borrow::Borrow;
2-
use std::cmp;
2+
use std::{cmp, iter};
33

4-
use itertools::zip_eq;
54
use libc::c_uint;
65
use rustc_abi::{BackendRepr, HasDataLayout, Primitive, Reg, RegKind, Size};
76
use rustc_codegen_ssa::MemFlags;
7+
use rustc_codegen_ssa::common::TypeKind;
88
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
99
use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue};
1010
use rustc_codegen_ssa::traits::*;
@@ -362,15 +362,15 @@ fn match_intrinsic_signature<'ll>(
362362
);
363363
}
364364

365-
if rust_return_ty != llvm_return_ty {
365+
if !equate_ty(cx, rust_return_ty, llvm_return_ty) {
366366
error!(
367367
"Intrinsic signature mismatch: could not match `{rust_return_ty:?}` (found) with {llvm_return_ty:?} (expected) as return type for `{fn_name}`"
368368
);
369369
}
370370
for (idx, (&rust_argument_ty, llvm_argument_ty)) in
371-
zip_eq(rust_argument_tys, llvm_argument_tys).enumerate()
371+
iter::zip(rust_argument_tys, llvm_argument_tys).enumerate()
372372
{
373-
if rust_argument_ty != llvm_argument_ty {
373+
if !equate_ty(cx, rust_argument_ty, llvm_argument_ty) {
374374
error!(
375375
"Intrinsic signature mismatch: could not match `{rust_return_ty:?}` (found) with {llvm_return_ty:?} (expected) as argument {idx} for `{fn_name}`"
376376
);
@@ -380,6 +380,30 @@ fn match_intrinsic_signature<'ll>(
380380
fn_ty
381381
}
382382

383+
fn equate_ty<'ll>(cx: &CodegenCx<'ll, '_>, rust_ty: &'ll Type, llvm_ty: &'ll Type) -> bool {
384+
if rust_ty == llvm_ty {
385+
return true;
386+
}
387+
if cx.type_kind(llvm_ty) == TypeKind::X86_AMX && cx.type_kind(rust_ty) == TypeKind::Vector {
388+
let element_count = cx.vector_length(rust_ty);
389+
let element_ty = cx.element_type(rust_ty);
390+
391+
let element_size_bits = match cx.type_kind(element_ty) {
392+
TypeKind::Half => 16,
393+
TypeKind::Float => 32,
394+
TypeKind::Double => 64,
395+
TypeKind::FP128 => 128,
396+
TypeKind::Integer => cx.int_width(element_ty),
397+
TypeKind::Pointer => cx.int_width(cx.isize_ty),
398+
_ => bug!("Vector element type `{element_ty:?}` not one of integer, float or pointer"),
399+
};
400+
let vector_size_bits = element_size_bits * element_count as u64;
401+
402+
return vector_size_bits == 8192;
403+
}
404+
return false;
405+
}
406+
383407
impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
384408
fn llvm_return_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type {
385409
match &self.ret.mode {
@@ -477,7 +501,7 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
477501
// - it has been superceded by something else, so the intrinsic was removed entirely
478502
//
479503
// anyway, let's log it
480-
tracing::warn!(
504+
tracing::debug!(
481505
"Couldn't find intrinsic `{}`, either invalid or deprecated",
482506
str::from_utf8(name).unwrap()
483507
);

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 81 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use rustc_span::Span;
2727
use rustc_target::callconv::FnAbi;
2828
use rustc_target::spec::{HasTargetSpec, SanitizerSet, Target};
2929
use smallvec::SmallVec;
30-
use tracing::{debug, instrument};
30+
use tracing::{debug, instrument, warn};
3131

3232
use crate::abi::FnAbiLlvmExt;
3333
use crate::attributes;
@@ -67,7 +67,7 @@ impl<'a, 'll> SBuilder<'a, 'll> {
6767
) -> &'ll Value {
6868
debug!("call {:?} with args ({:?})", llfn, args);
6969

70-
let args = self.check_call("call", llty, llfn, args);
70+
let args = self.cast_arguments("call", llty, llfn, args, false);
7171
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
7272
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
7373
if let Some(funclet_bundle) = funclet_bundle {
@@ -79,7 +79,7 @@ impl<'a, 'll> SBuilder<'a, 'll> {
7979
self.llbuilder,
8080
llty,
8181
llfn,
82-
args.as_ptr() as *const &llvm::Value,
82+
args.as_ptr(),
8383
args.len() as c_uint,
8484
bundles.as_ptr(),
8585
bundles.len() as c_uint,
@@ -101,6 +101,22 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
101101
unsafe { llvm::LLVMBuildBitCast(self.llbuilder, val, dest_ty, UNNAMED) }
102102
}
103103

104+
pub(crate) fn cast_vector_to_tile(&mut self, val: &'ll Value) -> &'ll Value {
105+
let vector_type = self.cx.val_ty(val);
106+
107+
assert!(self.cx.type_kind(vector_type) == TypeKind::Vector);
108+
self.call_intrinsic("llvm.x86.cast.vector.to.tile", &[vector_type], &[val])
109+
}
110+
111+
pub(crate) fn cast_tile_to_vector(
112+
&mut self,
113+
val: &'ll Value,
114+
vector_type: &'ll Type,
115+
) -> &'ll Value {
116+
assert!(self.cx.type_kind(vector_type) == TypeKind::Vector);
117+
self.call_intrinsic("llvm.x86.cast.tile.to.vector", &[vector_type], &[val])
118+
}
119+
104120
pub(crate) fn ret_void(&mut self) {
105121
llvm::LLVMBuildRetVoid(self.llbuilder);
106122
}
@@ -349,7 +365,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
349365
) -> &'ll Value {
350366
debug!("invoke {:?} with args ({:?})", llfn, args);
351367

352-
let args = self.check_call("invoke", llty, llfn, args);
368+
let args = self.cast_arguments("invoke", llty, llfn, args, fn_abi.is_some());
353369
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
354370
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
355371
if let Some(funclet_bundle) = funclet_bundle {
@@ -381,8 +397,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
381397
};
382398
if let Some(fn_abi) = fn_abi {
383399
fn_abi.apply_attrs_callsite(self, invoke);
400+
self.cast_return(fn_abi, llfn, invoke)
401+
} else {
402+
invoke
384403
}
385-
invoke
386404
}
387405

388406
fn unreachable(&mut self) {
@@ -1348,7 +1366,7 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
13481366
) -> &'ll Value {
13491367
debug!("call {:?} with args ({:?})", llfn, args);
13501368

1351-
let args = self.check_call("call", llty, llfn, args);
1369+
let args = self.cast_arguments("call", llty, llfn, args, fn_abi.is_some());
13521370
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
13531371
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
13541372
if let Some(funclet_bundle) = funclet_bundle {
@@ -1378,8 +1396,10 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
13781396
};
13791397
if let Some(fn_abi) = fn_abi {
13801398
fn_abi.apply_attrs_callsite(self, call);
1399+
self.cast_return(fn_abi, llfn, call)
1400+
} else {
1401+
call
13811402
}
1382-
call
13831403
}
13841404

13851405
fn zext(&mut self, val: &'ll Value, dest_ty: &'ll Type) -> &'ll Value {
@@ -1540,45 +1560,47 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
15401560
ret.expect("LLVM does not have support for catchret")
15411561
}
15421562

1543-
fn check_call<'b>(
1563+
fn cast_arguments<'b>(
15441564
&mut self,
15451565
typ: &str,
15461566
fn_ty: &'ll Type,
15471567
llfn: &'ll Value,
15481568
args: &'b [&'ll Value],
1569+
has_fnabi: bool,
15491570
) -> Cow<'b, [&'ll Value]> {
1550-
assert!(
1551-
self.cx.type_kind(fn_ty) == TypeKind::Function,
1552-
"builder::{typ} not passed a function, but {fn_ty:?}"
1571+
assert_eq!(
1572+
self.cx.type_kind(fn_ty),
1573+
TypeKind::Function,
1574+
"{typ} not passed a function, but {fn_ty:?}"
15531575
);
15541576

15551577
let param_tys = self.cx.func_params_types(fn_ty);
15561578

1557-
let all_args_match = iter::zip(&param_tys, args.iter().map(|&v| self.cx.val_ty(v)))
1558-
.all(|(expected_ty, actual_ty)| *expected_ty == actual_ty);
1579+
let mut casted_args = Cow::Borrowed(args);
15591580

1560-
if all_args_match {
1561-
return Cow::Borrowed(args);
1562-
}
1581+
for (i, (expected_ty, &actual_val)) in iter::zip(param_tys, args).enumerate() {
1582+
let actual_ty = self.cx.val_ty(actual_val);
15631583

1564-
let casted_args: Vec<_> = iter::zip(param_tys, args)
1565-
.enumerate()
1566-
.map(|(i, (expected_ty, &actual_val))| {
1567-
let actual_ty = self.cx.val_ty(actual_val);
1568-
if expected_ty != actual_ty {
1569-
debug!(
1570-
"type mismatch in function call of {:?}. \
1571-
Expected {:?} for param {}, got {:?}; injecting bitcast",
1572-
llfn, expected_ty, i, actual_ty
1584+
if expected_ty != actual_ty {
1585+
warn!(
1586+
"type mismatch in function call of {llfn:?}. \
1587+
Expected {expected_ty:?} for param {i}, got {actual_ty:?}; injecting bitcast",
1588+
);
1589+
1590+
casted_args.to_mut()[i] = if self.cx.type_kind(expected_ty) == TypeKind::X86_AMX {
1591+
// we can't do `cast_return` in without `FnAbi`
1592+
assert!(
1593+
has_fnabi,
1594+
"Found `x86amx` for parameter {i} in function call of {llfn:?}, but not able to get Rust return type"
15731595
);
1574-
self.bitcast(actual_val, expected_ty)
1596+
self.cast_vector_to_tile(actual_val)
15751597
} else {
1576-
actual_val
1598+
self.bitcast(actual_val, expected_ty)
15771599
}
1578-
})
1579-
.collect();
1600+
}
1601+
}
15801602

1581-
Cow::Owned(casted_args)
1603+
casted_args
15821604
}
15831605

15841606
pub(crate) fn va_arg(&mut self, list: &'ll Value, ty: &'ll Type) -> &'ll Value {
@@ -1591,7 +1613,7 @@ impl<'a, 'll, CX: Borrow<SCx<'ll>>> GenericBuilder<'a, 'll, CX> {
15911613
llfn: &'ll Value,
15921614
args: &[&'ll Value],
15931615
) -> &'ll Value {
1594-
let args = self.check_call("simple call", fn_ty, llfn, args);
1616+
let args = self.cast_arguments("simple call", fn_ty, llfn, args, false);
15951617

15961618
unsafe {
15971619
llvm::LLVMBuildCall2(
@@ -1692,6 +1714,30 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
16921714
self.simple_call(fn_ty, llfn, &[ptr1, ptr2, num])
16931715
}
16941716

1717+
fn cast_return(
1718+
&mut self,
1719+
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
1720+
llfn: &'ll Value,
1721+
ret: &'ll Value,
1722+
) -> &'ll Value {
1723+
let expected_ty = fn_abi.llvm_return_type(self.cx);
1724+
let actual_ty = self.cx.val_ty(ret);
1725+
1726+
if expected_ty != actual_ty {
1727+
warn!(
1728+
"Type mismatch in function call of {llfn:?}. \
1729+
Expected {expected_ty:?} for return type, got {actual_ty:?}"
1730+
);
1731+
if self.cx.type_kind(actual_ty) == TypeKind::X86_AMX {
1732+
self.cast_tile_to_vector(ret, expected_ty)
1733+
} else {
1734+
self.bitcast(ret, expected_ty)
1735+
}
1736+
} else {
1737+
ret
1738+
}
1739+
}
1740+
16951741
pub(crate) fn landing_pad(
16961742
&mut self,
16971743
ty: &'ll Type,
@@ -1721,7 +1767,7 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17211767
) -> &'ll Value {
17221768
debug!("invoke {:?} with args ({:?})", llfn, args);
17231769

1724-
let args = self.check_call("callbr", llty, llfn, args);
1770+
let args = self.cast_arguments("callbr", llty, llfn, args, fn_abi.is_some());
17251771
let funclet_bundle = funclet.map(|funclet| funclet.bundle());
17261772
let mut bundles: SmallVec<[_; 2]> = SmallVec::new();
17271773
if let Some(funclet_bundle) = funclet_bundle {
@@ -1754,8 +1800,10 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
17541800
};
17551801
if let Some(fn_abi) = fn_abi {
17561802
fn_abi.apply_attrs_callsite(self, callbr);
1803+
self.cast_return(fn_abi, llfn, callbr)
1804+
} else {
1805+
callbr
17571806
}
1758-
callbr
17591807
}
17601808

17611809
// Emits CFI pointer type membership tests.

0 commit comments

Comments
 (0)