Skip to content

Change codegen of LLVM intrinsics to be name-based, and add llvm linkage support for x86amx #140763

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/rustc_codegen_gcc/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -897,7 +897,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
fn checked_binop(
&mut self,
oop: OverflowOp,
typ: Ty<'_>,
typ: Ty<'tcx>,
lhs: Self::Value,
rhs: Self::Value,
) -> (Self::Value, Self::Value) {
Expand Down
8 changes: 6 additions & 2 deletions compiler/rustc_codegen_gcc/src/type_of.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::fmt::Write;

use gccjit::{Struct, Type};
use gccjit::{RValue, Struct, Type};
use rustc_abi as abi;
use rustc_abi::Primitive::*;
use rustc_abi::{
Expand Down Expand Up @@ -373,7 +373,11 @@ impl<'gcc, 'tcx> LayoutTypeCodegenMethods<'tcx> for CodegenCx<'gcc, 'tcx> {
unimplemented!();
}

fn fn_decl_backend_type(&self, fn_abi: &FnAbi<'tcx, Ty<'tcx>>) -> Type<'gcc> {
fn fn_decl_backend_type(
&self,
fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
_fn_ptr: RValue<'gcc>,
) -> Type<'gcc> {
// FIXME(antoyo): Should we do something with `FnAbiGcc::fn_attributes`?
let FnAbiGcc { return_type, arguments_type, is_c_variadic, .. } = fn_abi.gcc_type(self);
self.context.new_function_pointer_type(None, return_type, &arguments_type, is_c_variadic)
Expand Down
180 changes: 161 additions & 19 deletions compiler/rustc_codegen_llvm/src/abi.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::borrow::Borrow;
use std::cmp;
use std::{cmp, iter};

use libc::c_uint;
use rustc_abi::{BackendRepr, HasDataLayout, Primitive, Reg, RegKind, Size};
use rustc_codegen_ssa::MemFlags;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::mir::place::{PlaceRef, PlaceValue};
use rustc_codegen_ssa::traits::*;
Expand All @@ -19,7 +20,7 @@ use smallvec::SmallVec;

use crate::attributes::{self, llfn_attrs_from_instance};
use crate::builder::Builder;
use crate::context::CodegenCx;
use crate::context::{CodegenCx, GenericCx, SCx};
use crate::llvm::{self, Attribute, AttributePlace};
use crate::type_::Type;
use crate::type_of::LayoutLlvmExt;
Expand Down Expand Up @@ -308,7 +309,9 @@ impl<'ll, 'tcx> ArgAbiBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
}

pub(crate) trait FnAbiLlvmExt<'ll, 'tcx> {
fn llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type;
fn llvm_return_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type;
fn llvm_argument_types(&self, cx: &CodegenCx<'ll, 'tcx>) -> Vec<&'ll Type>;
fn llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>, name: &[u8]) -> &'ll Type;
fn ptr_to_llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type;
fn llvm_cconv(&self, cx: &CodegenCx<'ll, 'tcx>) -> llvm::CallConv;

Expand All @@ -324,27 +327,132 @@ pub(crate) trait FnAbiLlvmExt<'ll, 'tcx> {
fn apply_attrs_callsite(&self, bx: &mut Builder<'_, 'll, 'tcx>, callsite: &'ll Value);
}

/// checks that the Rust signature of a **non-overloaded** llvm intrinsic is correct
fn match_intrinsic_signature<'ll>(
cx: &CodegenCx<'ll, '_>,
intrinsic: llvm::Intrinsic,
rust_return_ty: &'ll Type,
rust_argument_tys: &[&'ll Type],
name: &[u8],
) -> &'ll Type {
macro_rules! error {
($($t:tt)*) => {
cx.tcx.dcx().fatal(format!($($t)*, fn_name=str::from_utf8(name).unwrap()))
};
}

let base_name = intrinsic.base_name();
if name != base_name {
error!(
"Unsupported overload `{fn_name}` for non-overloaded intrinsic `{}`",
str::from_utf8(base_name).unwrap()
);
}

let fn_ty = cx.intrinsic_type(intrinsic, &[]);

let llvm_return_ty = cx.get_return_type(fn_ty);
let llvm_argument_tys = cx.func_params_types(fn_ty);

if rust_argument_tys.len() != llvm_argument_tys.len() {
error!(
"Intrinsic signature mismatch: expected {} arguments for `{fn_name}`, found {} arguments",
llvm_argument_tys.len(),
rust_argument_tys.len()
);
}

if !cx.equate_ty(rust_return_ty, llvm_return_ty) {
error!(
"Intrinsic signature mismatch: could not match `{rust_return_ty:?}` (found) with {llvm_return_ty:?} (expected) as return type for `{fn_name}`"
);
}
for (idx, (&rust_argument_ty, llvm_argument_ty)) in
iter::zip(rust_argument_tys, llvm_argument_tys).enumerate()
{
if !cx.equate_ty(rust_argument_ty, llvm_argument_ty) {
error!(
"Intrinsic signature mismatch: could not match `{rust_return_ty:?}` (found) with {llvm_return_ty:?} (expected) as argument {idx} for `{fn_name}`"
);
}
}

fn_ty
}

impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
pub(crate) fn equate_ty(&self, rust_ty: &'ll Type, llvm_ty: &'ll Type) -> bool {
if rust_ty == llvm_ty {
return true;
}

match self.type_kind(llvm_ty) {
TypeKind::X86_AMX if self.type_kind(rust_ty) == TypeKind::Vector => {
let element_count = self.vector_length(rust_ty);
let element_ty = self.element_type(rust_ty);

let element_size_bits = match self.type_kind(element_ty) {
TypeKind::Half => 16,
TypeKind::Float => 32,
TypeKind::Double => 64,
TypeKind::FP128 => 128,
TypeKind::Integer => self.int_width(element_ty),
TypeKind::Pointer => self.int_width(self.isize_ty()),
_ => bug!(
"Vector element type `{element_ty:?}` not one of integer, float or pointer"
),
};
let vector_size_bits = element_size_bits * element_count as u64;

vector_size_bits == 8192
}
TypeKind::BFloat => rust_ty == self.type_i16(),
TypeKind::Vector if self.type_kind(rust_ty) == TypeKind::Vector => {
let llvm_element_count = self.vector_length(llvm_ty);
let rust_element_count = self.vector_length(rust_ty);

if llvm_element_count != rust_element_count {
return false;
}

let llvm_element_ty = self.element_type(llvm_ty);
let rust_element_ty = self.element_type(rust_ty);

if llvm_element_ty == self.type_bf16() {
rust_element_ty == self.type_i16()
} else {
false
}
}
_ => false,
}
}
}

impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
fn llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type {
fn llvm_return_type(&self, cx: &CodegenCx<'ll, 'tcx>) -> &'ll Type {
match &self.ret.mode {
PassMode::Ignore => cx.type_void(),
PassMode::Direct(_) | PassMode::Pair(..) => self.ret.layout.immediate_llvm_type(cx),
PassMode::Cast { cast, pad_i32: _ } => cast.llvm_type(cx),
PassMode::Indirect { .. } => cx.type_void(),
}
}

fn llvm_argument_types(&self, cx: &CodegenCx<'ll, 'tcx>) -> Vec<&'ll Type> {
let indirect_return = matches!(self.ret.mode, PassMode::Indirect { .. });

// Ignore "extra" args from the call site for C variadic functions.
// Only the "fixed" args are part of the LLVM function signature.
let args =
if self.c_variadic { &self.args[..self.fixed_count as usize] } else { &self.args };

// This capacity calculation is approximate.
let mut llargument_tys = Vec::with_capacity(
self.args.len() + if let PassMode::Indirect { .. } = self.ret.mode { 1 } else { 0 },
);
let mut llargument_tys =
Vec::with_capacity(args.len() + if indirect_return { 1 } else { 0 });

let llreturn_ty = match &self.ret.mode {
PassMode::Ignore => cx.type_void(),
PassMode::Direct(_) | PassMode::Pair(..) => self.ret.layout.immediate_llvm_type(cx),
PassMode::Cast { cast, pad_i32: _ } => cast.llvm_type(cx),
PassMode::Indirect { .. } => {
llargument_tys.push(cx.type_ptr());
cx.type_void()
}
};
if indirect_return {
llargument_tys.push(cx.type_ptr());
}

for arg in args {
// Note that the exact number of arguments pushed here is carefully synchronized with
Expand Down Expand Up @@ -391,10 +499,44 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
llargument_tys.push(llarg_ty);
}

llargument_tys
}

fn llvm_type(&self, cx: &CodegenCx<'ll, 'tcx>, name: &[u8]) -> &'ll Type {
let return_ty = self.llvm_return_type(cx);
let argument_tys = self.llvm_argument_types(cx);

if name.starts_with(b"llvm.") {
if let Some(intrinsic) = llvm::Intrinsic::lookup(name)
&& !intrinsic.is_overloaded()
{
if !intrinsic.is_overloaded() {
// FIXME: also do this for overloaded intrinsics
return match_intrinsic_signature(
cx,
intrinsic,
return_ty,
&argument_tys,
name,
);
}
} else {
// it's one of 2 cases,
// - either the base name is invalid
// - it has been superceded by something else, so the intrinsic was removed entirely
//
// anyway, let's log it
tracing::debug!(
"Couldn't find intrinsic `{}`, either invalid or deprecated",
str::from_utf8(name).unwrap()
);
}
}

if self.c_variadic {
cx.type_variadic_func(&llargument_tys, llreturn_ty)
cx.type_variadic_func(&argument_tys, return_ty)
} else {
cx.type_func(&llargument_tys, llreturn_ty)
cx.type_func(&argument_tys, return_ty)
}
}

Expand Down
Loading
Loading