Commit ca79220

Add printf support

1 parent e58d244 commit ca79220

5 files changed: +140 -4 lines changed

Project.toml
+2

@@ -8,7 +8,9 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
+LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
+Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
 UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

src/KernelAbstractions.jl
+106 -1

@@ -1,12 +1,13 @@
 module KernelAbstractions

 export @kernel
-export @Const, @localmem, @private, @uniform, @synchronize, @index, groupsize, @print
+export @Const, @localmem, @private, @uniform, @synchronize, @index, groupsize, @print, @printf
 export Device, GPU, CPU, CUDADevice, Event, MultiEvent, NoneEvent
 export async_copy!


 using MacroTools
+using Printf
 using StaticArrays
 using Cassette
 using Adapt
@@ -28,6 +29,7 @@ and then invoked on the arguments.
 - [`@uniform`](@ref)
 - [`@synchronize`](@ref)
 - [`@print`](@ref)
+- [`@printf`](@ref)

 # Example:
@@ -236,6 +238,37 @@ macro print(items...)
     end
 end

+@generated function promote_c_argument(arg)
+    # > When a function with a variable-length argument list is called, the variable
+    # > arguments are passed using C's old ``default argument promotions.'' These say that
+    # > types char and short int are automatically promoted to int, and type float is
+    # > automatically promoted to double. Therefore, varargs functions will never receive
+    # > arguments of type char, short int, or float.
+
+    if arg == Cchar || arg == Cshort
+        return :(Cint(arg))
+    elseif arg == Cfloat
+        return :(Cdouble(arg))
+    else
+        return :(arg)
+    end
+end
+
+"""
+    @printf(fmt::String, args...)
+
+This is a unified formatted print statement.
+
+# Platform differences
+    - `GPU`: This will forward the format string and arguments to `CUDA._cuprintf`.
+    - `CPU`: This will format the arguments with `@sprintf` and print the result.
+"""
+macro printf(fmt::String, args...)
+    fmt_val = Val(Symbol(fmt))
+
+    return :(__printf($fmt_val, $(map(arg -> :(promote_c_argument($arg)), esc.(args))...)))
+end
+
 """
     @index
@@ -452,6 +485,78 @@ end
     end
 end

+# Results in "Conversion of boxed type String is not allowed"
+# @generated function __printf(::Val{fmt}, argspec...) where {fmt}
+#     arg_exprs = [:( argspec[$i] ) for i in 1:length(argspec)]
+#     arg_types = [argspec...]
+
+#     T_void = LLVM.VoidType(LLVM.Interop.JuliaContext())
+#     T_int32 = LLVM.Int32Type(LLVM.Interop.JuliaContext())
+#     T_pint8 = LLVM.PointerType(LLVM.Int8Type(LLVM.Interop.JuliaContext()))
+
+#     # create functions
+#     param_types = LLVMType[convert.(LLVMType, arg_types)...]
+#     llvm_f, _ = create_function(T_int32, param_types)
+#     mod = LLVM.parent(llvm_f)
+#     sfmt = String(fmt)
+#     # generate IR
+#     Builder(LLVM.Interop.JuliaContext()) do builder
+#         entry = BasicBlock(llvm_f, "entry", LLVM.Interop.JuliaContext())
+#         position!(builder, entry)
+
+#         str = globalstring_ptr!(builder, sfmt)
+
+#         # construct and fill args buffer
+#         if isempty(argspec)
+#             buffer = LLVM.PointerNull(T_pint8)
+#         else
+#             argtypes = LLVM.StructType("printf_args", LLVM.Interop.JuliaContext())
+#             elements!(argtypes, param_types)
+
+#             args = alloca!(builder, argtypes)
+#             for (i, param) in enumerate(parameters(llvm_f))
+#                 p = struct_gep!(builder, args, i-1)
+#                 store!(builder, param, p)
+#             end
+
+#             buffer = bitcast!(builder, args, T_pint8)
+#         end
+
+#         # invoke vprintf and return
+#         vprintf_typ = LLVM.FunctionType(T_int32, [T_pint8, T_pint8])
+#         vprintf = LLVM.Function(mod, "vprintf", vprintf_typ)
+#         chars = call!(builder, vprintf, [str, buffer])
+
+#         ret!(builder, chars)
+#     end
+
+#     arg_tuple = Expr(:tuple, arg_exprs...)
+#     call_function(llvm_f, Int32, Tuple{arg_types...}, arg_tuple)
+# end
+
+# Results in "InvalidIRError: compiling kernel
+# gpu_kernel_printf(... Reason: unsupported dynamic
+# function invocation"
+@generated function __printf(::Val{fmt}, items...) where {fmt}
+    str = ""
+    args = []
+
+    for i in 1:length(items)
+        item = :(items[$i])
+        T = items[i]
+        if T <: Val
+            item = QuoteNode(T.parameters[1])
+        end
+        push!(args, item)
+    end
+    sfmt = String(fmt)
+    quote
+        # @sprintf($sfmt, $(args...))
+        @print(@sprintf($sfmt, $(args...)))
+        # @print("test")
+    end
+end
+
 ###
 # Backends/Implementation
 ###
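
For orientation, a minimal usage sketch (not part of this commit; the kernel and array names are illustrative): `@printf` lowers the format string to a `Val{Symbol(fmt)}` so `__printf` can recover it at compile time, and wraps each argument in `promote_c_argument` to mimic C's default argument promotions.

    using KernelAbstractions

    # Hypothetical kernel that scales an array and reports each element.
    @kernel function kernel_scale!(a, b)
        I = @index(Global)
        a[I] = 2f0 * b[I]
        @printf("a[%d] = %g\n", I, a[I])
    end

    # Launch on the CPU backend with a workgroup size of 4 and wait on the
    # returned event, mirroring the pattern used in test/print_test.jl.
    function run_scale!(a, b)
        kernel = kernel_scale!(CPU(), 4)
        event = kernel(a, b, ndrange=size(a))
        wait(event)
    end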

src/backends/cpu.jl
+4

@@ -208,6 +208,10 @@ end
     __print(items...)
 end

+@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__printf), fmt, items...)
+    __printf(fmt, items...)
+end
+
 generate_overdubs(CPUCtx)

 # Don't recurse into these functions
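
For readers unfamiliar with Cassette, the method above works by contextual dispatch: while a kernel runs under `CPUCtx`, calls to `__printf` are rewritten into calls to this `overdub` method. A standalone sketch of that pattern, using an illustrative context and function name rather than the package's own:

    using Cassette

    Cassette.@context SketchCtx

    # Placeholder the "kernel" code calls; does nothing on its own.
    logmsg(x) = nothing

    # Under SketchCtx, calls to logmsg are intercepted and given real behavior.
    @inline function Cassette.overdub(::SketchCtx, ::typeof(logmsg), x)
        println("logmsg intercepted: ", x)
    end

    # Running code through overdub rewrites its inner calls, so this prints.
    Cassette.overdub(SketchCtx(), () -> logmsg(42))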

src/backends/cuda.jl
+4

@@ -319,6 +319,10 @@ end
     CUDA._cuprint(args...)
 end

+@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__printf), fmt, args...)
+    CUDA._cuprintf(fmt, args...)
+end
+
 ###
 # GPU implementation of const memory
 ###
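
A note on the `Val{Symbol(fmt)}` indirection used by `__printf`: a runtime `String` cannot be passed into GPU code (see the "Conversion of boxed type String is not allowed" comment above), so the format string is lifted into the type domain and recovered inside a generated function. A backend-independent illustration of that trick (names are illustrative, not from the package):

    # Encode the format string as a type parameter...
    fmtval = Val(Symbol("x = %d\n"))

    # ...and recover it while the method body is being generated.
    @generated function fmt_length(::Val{fmt}) where {fmt}
        s = String(fmt)           # the format string, available at compile time
        return :( $(length(s)) )  # spliced back in as a literal constant
    end

    fmt_length(fmtval)  # returns 7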

test/print_test.jl
+24 -3

@@ -5,25 +5,46 @@ if has_cuda_gpu()
     CUDA.allowscalar(false)
 end

+struct Foo{A,B} end
+
 @kernel function kernel_print()
     I = @index(Global)
     @print("Hello from thread ", I, "!\n")
 end

+@kernel function kernel_printf()
+    I = @index(Global)
+    @printf("Hello printf %s thread %d! type = %s.\n", "from", I, nameof(Foo))
+end
+
 function test_print(backend)
     kernel = kernel_print(backend, 4)
-    kernel(ndrange=(4,))
+    kernel(ndrange=(4,))
+end
+
+function test_printf(backend)
+    kernel = kernel_printf(backend, 4)
+    kernel(ndrange=(4,))
 end

 @testset "print test" begin
+    wait(test_print(CPU()))
+    @test true
+
+    wait(test_printf(CPU()))
+    @test true
+
     if has_cuda_gpu()
         wait(test_print(CUDADevice()))
         @test true
+        wait(test_printf(CUDADevice()))
+        @test true
     end

-    wait(test_print(CPU()))
+    @print("Why this should work")
     @test true

-    @print("Why this should work")
+    @printf("Why this should work")
     @test true
 end
+
