Skip to content

Commit

Permalink
cranelift: Implement scalar FMA on x86 (bytecodealliance#4460)
Browse files Browse the repository at this point in the history
x86 does not have dedicated instructions for scalar FMA, lower
to a libcall which seems to be what llvm does.
  • Loading branch information
afonso360 authored Aug 3, 2022
1 parent ff6082c commit 709716b
Show file tree
Hide file tree
Showing 13 changed files with 167 additions and 50 deletions.
57 changes: 56 additions & 1 deletion cranelift/codegen/src/ir/libcall.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Naming well-known routines in the runtime library.
use crate::ir::{types, ExternalName, FuncRef, Function, Opcode, Type};
use crate::ir::{types, AbiParam, ExternalName, FuncRef, Function, Opcode, Signature, Type};
use crate::isa::CallConv;
use core::fmt;
use core::str::FromStr;
#[cfg(feature = "enable-serde")]
Expand Down Expand Up @@ -50,6 +51,10 @@ pub enum LibCall {
NearestF32,
/// nearest.f64
NearestF64,
/// fma.f32
FmaF32,
/// fma.f64
FmaF64,
/// libc.memcpy
Memcpy,
/// libc.memset
Expand Down Expand Up @@ -91,6 +96,8 @@ impl FromStr for LibCall {
"TruncF64" => Ok(Self::TruncF64),
"NearestF32" => Ok(Self::NearestF32),
"NearestF64" => Ok(Self::NearestF64),
"FmaF32" => Ok(Self::FmaF32),
"FmaF64" => Ok(Self::FmaF64),
"Memcpy" => Ok(Self::Memcpy),
"Memset" => Ok(Self::Memset),
"Memmove" => Ok(Self::Memmove),
Expand Down Expand Up @@ -124,13 +131,15 @@ impl LibCall {
Opcode::Floor => Self::FloorF32,
Opcode::Trunc => Self::TruncF32,
Opcode::Nearest => Self::NearestF32,
Opcode::Fma => Self::FmaF32,
_ => return None,
},
types::F64 => match opcode {
Opcode::Ceil => Self::CeilF64,
Opcode::Floor => Self::FloorF64,
Opcode::Trunc => Self::TruncF64,
Opcode::Nearest => Self::NearestF64,
Opcode::Fma => Self::FmaF64,
_ => return None,
},
_ => return None,
Expand All @@ -157,13 +166,59 @@ impl LibCall {
TruncF64,
NearestF32,
NearestF64,
FmaF32,
FmaF64,
Memcpy,
Memset,
Memmove,
Memcmp,
ElfTlsGetAddr,
]
}

/// Get a [Signature] for the function targeted by this [LibCall].
pub fn signature(&self, call_conv: CallConv) -> Signature {
use types::*;
let mut sig = Signature::new(call_conv);

match self {
LibCall::UdivI64
| LibCall::SdivI64
| LibCall::UremI64
| LibCall::SremI64
| LibCall::IshlI64
| LibCall::UshrI64
| LibCall::SshrI64 => {
sig.params.push(AbiParam::new(I64));
sig.params.push(AbiParam::new(I64));
sig.returns.push(AbiParam::new(I64));
}
LibCall::CeilF32 | LibCall::FloorF32 | LibCall::TruncF32 | LibCall::NearestF32 => {
sig.params.push(AbiParam::new(F32));
sig.returns.push(AbiParam::new(F32));
}
LibCall::TruncF64 | LibCall::FloorF64 | LibCall::CeilF64 | LibCall::NearestF64 => {
sig.params.push(AbiParam::new(F64));
sig.returns.push(AbiParam::new(F64));
}
LibCall::FmaF32 | LibCall::FmaF64 => {
let ty = if *self == LibCall::FmaF32 { F32 } else { F64 };

sig.params.push(AbiParam::new(ty));
sig.params.push(AbiParam::new(ty));
sig.params.push(AbiParam::new(ty));
sig.returns.push(AbiParam::new(ty));
}
LibCall::Probestack
| LibCall::Memcpy
| LibCall::Memset
| LibCall::Memmove
| LibCall::Memcmp
| LibCall::ElfTlsGetAddr => unimplemented!(),
}

sig
}
}

/// Get a function reference for the probestack function in `func`.
Expand Down
2 changes: 1 addition & 1 deletion cranelift/codegen/src/isa/aarch64/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1551,7 +1551,7 @@ impl LowerBackend for AArch64Backend {
type MInst = Inst;

fn lower<C: LowerCtx<I = Inst>>(&self, ctx: &mut C, ir_inst: IRInst) -> CodegenResult<()> {
lower_inst::lower_insn_to_regs(ctx, ir_inst, &self.flags, &self.isa_flags)
lower_inst::lower_insn_to_regs(ctx, ir_inst, &self.triple, &self.flags, &self.isa_flags)
}

fn lower_branch_group<C: LowerCtx<I = Inst>>(
Expand Down
14 changes: 11 additions & 3 deletions cranelift/codegen/src/isa/aarch64/lower/isle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use regalloc2::PReg;
use std::boxed::Box;
use std::convert::TryFrom;
use std::vec::Vec;
use target_lexicon::Triple;

type BoxCallInfo = Box<CallInfo>;
type BoxCallIndInfo = Box<CallIndInfo>;
Expand All @@ -40,6 +41,7 @@ type BoxExternalName = Box<ExternalName>;
/// The main entry point for lowering with ISLE.
pub(crate) fn lower<C>(
lower_ctx: &mut C,
triple: &Triple,
flags: &Flags,
isa_flags: &IsaFlags,
outputs: &[InsnOutput],
Expand All @@ -48,9 +50,15 @@ pub(crate) fn lower<C>(
where
C: LowerCtx<I = MInst>,
{
lower_common(lower_ctx, flags, isa_flags, outputs, inst, |cx, insn| {
generated_code::constructor_lower(cx, insn)
})
lower_common(
lower_ctx,
triple,
flags,
isa_flags,
outputs,
inst,
|cx, insn| generated_code::constructor_lower(cx, insn),
)
}

pub struct ExtendedValue {
Expand Down
4 changes: 3 additions & 1 deletion cranelift/codegen/src/isa/aarch64/lower_inst.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ use crate::{CodegenError, CodegenResult};
use alloc::boxed::Box;
use alloc::vec::Vec;
use core::convert::TryFrom;
use target_lexicon::Triple;

/// Actually codegen an instruction's results into registers.
pub(crate) fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
ctx: &mut C,
insn: IRInst,
triple: &Triple,
flags: &Flags,
isa_flags: &aarch64_settings::Flags,
) -> CodegenResult<()> {
Expand All @@ -33,7 +35,7 @@ pub(crate) fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
None
};

if let Ok(()) = super::lower::isle::lower(ctx, flags, isa_flags, &outputs, insn) {
if let Ok(()) = super::lower::isle::lower(ctx, triple, flags, isa_flags, &outputs, insn) {
return Ok(());
}

Expand Down
12 changes: 9 additions & 3 deletions cranelift/codegen/src/isa/s390x/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,14 @@ impl LowerBackend for S390xBackend {
None
};

if let Ok(()) =
super::lower::isle::lower(ctx, &self.flags, &self.isa_flags, &outputs, ir_inst)
{
if let Ok(()) = super::lower::isle::lower(
ctx,
&self.triple,
&self.flags,
&self.isa_flags,
&outputs,
ir_inst,
) {
return Ok(());
}

Expand Down Expand Up @@ -295,6 +300,7 @@ impl LowerBackend for S390xBackend {
// the second branch (if any) by emitting a two-way conditional branch.
if let Ok(()) = super::lower::isle::lower_branch(
ctx,
&self.triple,
&self.flags,
&self.isa_flags,
branches[0],
Expand Down
27 changes: 21 additions & 6 deletions cranelift/codegen/src/isa/s390x/lower/isle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ use std::boxed::Box;
use std::cell::Cell;
use std::convert::TryFrom;
use std::vec::Vec;
use target_lexicon::Triple;

type BoxCallInfo = Box<CallInfo>;
type BoxCallIndInfo = Box<CallIndInfo>;
Expand All @@ -37,6 +38,7 @@ type VecMInstBuilder = Cell<Vec<MInst>>;
/// The main entry point for lowering with ISLE.
pub(crate) fn lower<C>(
lower_ctx: &mut C,
triple: &Triple,
flags: &Flags,
isa_flags: &IsaFlags,
outputs: &[InsnOutput],
Expand All @@ -45,14 +47,21 @@ pub(crate) fn lower<C>(
where
C: LowerCtx<I = MInst>,
{
lower_common(lower_ctx, flags, isa_flags, outputs, inst, |cx, insn| {
generated_code::constructor_lower(cx, insn)
})
lower_common(
lower_ctx,
triple,
flags,
isa_flags,
outputs,
inst,
|cx, insn| generated_code::constructor_lower(cx, insn),
)
}

/// The main entry point for branch lowering with ISLE.
pub(crate) fn lower_branch<C>(
lower_ctx: &mut C,
triple: &Triple,
flags: &Flags,
isa_flags: &IsaFlags,
branch: Inst,
Expand All @@ -61,9 +70,15 @@ pub(crate) fn lower_branch<C>(
where
C: LowerCtx<I = MInst>,
{
lower_common(lower_ctx, flags, isa_flags, &[], branch, |cx, insn| {
generated_code::constructor_lower_branch(cx, insn, &targets.to_vec())
})
lower_common(
lower_ctx,
triple,
flags,
isa_flags,
&[],
branch,
|cx, insn| generated_code::constructor_lower_branch(cx, insn, &targets.to_vec()),
)
}

impl<C> generated_code::Context for IsleContext<'_, C, Flags, IsaFlags, 6>
Expand Down
10 changes: 10 additions & 0 deletions cranelift/codegen/src/isa/x64/inst.isle
Original file line number Diff line number Diff line change
Expand Up @@ -3354,3 +3354,13 @@
(decl x64_rsp () Reg)
(rule (x64_rsp)
(mov_preg (preg_rsp)))

;;;; Helpers for Emitting LibCalls ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(type LibCall extern
(enum
FmaF32
FmaF64))

(decl libcall_3 (LibCall Reg Reg Reg) Reg)
(extern constructor libcall_3 libcall_3)
4 changes: 4 additions & 0 deletions cranelift/codegen/src/isa/x64/lower.isle
Original file line number Diff line number Diff line change
Expand Up @@ -2491,6 +2491,10 @@

;; Rules for `fma` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;

(rule (lower (has_type $F32 (fma x y z)))
(libcall_3 (LibCall.FmaF32) x y z))
(rule (lower (has_type $F64 (fma x y z)))
(libcall_3 (LibCall.FmaF64) x y z))
(rule (lower (has_type $F32X4 (fma x y z)))
(x64_vfmadd213ps x y z))
(rule (lower (has_type $F64X2 (fma x y z)))
Expand Down
42 changes: 13 additions & 29 deletions cranelift/codegen/src/isa/x64/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ pub(super) mod isle;
use crate::data_value::DataValue;
use crate::ir::{
condcodes::{CondCode, FloatCC, IntCC},
types, AbiParam, ExternalName, Inst as IRInst, InstructionData, LibCall, Opcode, Signature,
Type,
types, ExternalName, Inst as IRInst, InstructionData, LibCall, Opcode, Type,
};
use crate::isa::x64::abi::*;
use crate::isa::x64::inst::args::*;
Expand Down Expand Up @@ -573,29 +572,13 @@ fn emit_fcmp<C: LowerCtx<I = Inst>>(
cond_result
}

fn make_libcall_sig<C: LowerCtx<I = Inst>>(
ctx: &mut C,
insn: IRInst,
call_conv: CallConv,
) -> Signature {
let mut sig = Signature::new(call_conv);
for i in 0..ctx.num_inputs(insn) {
sig.params.push(AbiParam::new(ctx.input_ty(insn, i)));
}
for i in 0..ctx.num_outputs(insn) {
sig.returns.push(AbiParam::new(ctx.output_ty(insn, i)));
}
sig
}

fn emit_vm_call<C: LowerCtx<I = Inst>>(
ctx: &mut C,
flags: &Flags,
triple: &Triple,
libcall: LibCall,
insn: IRInst,
inputs: SmallVec<[InsnInput; 4]>,
outputs: SmallVec<[InsnOutput; 2]>,
inputs: &[Reg],
outputs: &[Writable<Reg>],
) -> CodegenResult<()> {
let extname = ExternalName::LibCall(libcall);

Expand All @@ -607,7 +590,7 @@ fn emit_vm_call<C: LowerCtx<I = Inst>>(

// TODO avoid recreating signatures for every single Libcall function.
let call_conv = CallConv::for_libcall(flags, CallConv::triple_default(triple));
let sig = make_libcall_sig(ctx, insn, call_conv);
let sig = libcall.signature(call_conv);
let caller_conv = ctx.abi().call_conv();

let mut abi = X64ABICaller::from_func(&sig, &extname, dist, caller_conv, flags)?;
Expand All @@ -617,14 +600,12 @@ fn emit_vm_call<C: LowerCtx<I = Inst>>(
assert_eq!(inputs.len(), abi.num_args());

for (i, input) in inputs.iter().enumerate() {
let arg_reg = put_input_in_reg(ctx, *input);
abi.emit_copy_regs_to_arg(ctx, i, ValueRegs::one(arg_reg));
abi.emit_copy_regs_to_arg(ctx, i, ValueRegs::one(*input));
}

abi.emit_call(ctx);
for (i, output) in outputs.iter().enumerate() {
let retval_reg = get_output_reg(ctx, *output).only_reg().unwrap();
abi.emit_copy_retval_to_regs(ctx, i, ValueRegs::one(retval_reg));
abi.emit_copy_retval_to_regs(ctx, i, ValueRegs::one(*output));
}
abi.emit_stack_post_adjust(ctx);

Expand Down Expand Up @@ -810,7 +791,7 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
None
};

if let Ok(()) = isle::lower(ctx, flags, isa_flags, &outputs, insn) {
if let Ok(()) = isle::lower(ctx, triple, flags, isa_flags, &outputs, insn) {
return Ok(());
}

Expand Down Expand Up @@ -884,6 +865,7 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
| Opcode::FvpromoteLow
| Opcode::Fdemote
| Opcode::Fvdemote
| Opcode::Fma
| Opcode::Icmp
| Opcode::Fcmp
| Opcode::Load
Expand Down Expand Up @@ -1974,7 +1956,11 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
ty, op
),
};
emit_vm_call(ctx, flags, triple, libcall, insn, inputs, outputs)?;

let input = put_input_in_reg(ctx, inputs[0]);
let dst = get_output_reg(ctx, outputs[0]).only_reg().unwrap();

emit_vm_call(ctx, flags, triple, libcall, &[input], &[dst])?;
}
}

Expand Down Expand Up @@ -2726,8 +2712,6 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(

Opcode::Cls => unimplemented!("Cls not supported"),

Opcode::Fma => implemented_in_isle(ctx),

Opcode::BorNot | Opcode::BxorNot => {
unimplemented!("or-not / xor-not opcodes not implemented");
}
Expand Down
Loading

0 comments on commit 709716b

Please sign in to comment.