Skip to content

Commit 3784d67

Browse files
authored
Try #163:
2 parents 4ec8066 + 4a3acad commit 3784d67

File tree

5 files changed

+142
-4
lines changed

5 files changed

+142
-4
lines changed

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
99
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
1010
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
11+
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
1112
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
13+
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1214
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1315
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
1416
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

src/KernelAbstractions.jl

+99-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
module KernelAbstractions
22

33
export @kernel
4-
export @Const, @localmem, @private, @uniform, @synchronize, @index, groupsize, @print
4+
export @Const, @localmem, @private, @uniform, @synchronize, @index, groupsize, @print, @printf
55
export Device, GPU, CPU, CUDADevice, Event, MultiEvent, NoneEvent
66
export async_copy!
77

88

99
using MacroTools
10+
using Printf
1011
using StaticArrays
1112
using Cassette
1213
using Adapt
@@ -28,6 +29,7 @@ and then invoked on the arguments.
2829
- [`@uniform`](@ref)
2930
- [`@synchronize`](@ref)
3031
- [`@print`](@ref)
32+
- [`@printf`](@ref)
3133
3234
# Example:
3335
@@ -236,6 +238,32 @@ macro print(items...)
236238
end
237239
end
238240

241+
# When a function with a variable-length argument list is called, the variable
242+
# arguments are passed using C's old ``default argument promotions.'' These say that
243+
# types char and short int are automatically promoted to int, and type float is
244+
# automatically promoted to double. Therefore, varargs functions will never receive
245+
# arguments of type char, short int, or float.
246+
247+
promote_c_argument(arg) = arg
248+
promote_c_argument(arg::Cfloat) = Cdouble(arg)
249+
promote_c_argument(arg::Cchar) = Cint(arg)
250+
promote_c_argument(arg::Cshort) = Cint(arg)
251+
252+
"""
253+
@printf(fmt::String, args...)
254+
255+
This is a unified formatted printf statement.
256+
257+
# Platform differences
258+
- `GPU`: This will reorganize the items to print via @cuprintf
259+
- `CPU`: This will call `sprintf(fmt, items...)`
260+
"""
261+
macro printf(fmt::String, args...)
262+
fmt_val = Val(Symbol(fmt))
263+
264+
return :(__printf($fmt_val, $(map(arg -> :(promote_c_argument($arg)), esc.(args))...)))
265+
end
266+
239267
"""
240268
@index
241269
@@ -452,6 +480,76 @@ end
452480
end
453481
end
454482

483+
# Results in "Conversion of boxed type String is not allowed"
484+
# @generated function __printf(::Val{fmt}, argspec...) where {fmt}
485+
# arg_exprs = [:( argspec[$i] ) for i in 1:length(argspec)]
486+
# arg_types = [argspec...]
487+
488+
# T_void = LLVM.VoidType(LLVM.Interop.JuliaContext())
489+
# T_int32 = LLVM.Int32Type(LLVM.Interop.JuliaContext())
490+
# T_pint8 = LLVM.PointerType(LLVM.Int8Type(LLVM.Interop.JuliaContext()))
491+
492+
# # create functions
493+
# param_types = LLVMType[convert.(LLVMType, arg_types)...]
494+
# llvm_f, _ = create_function(T_int32, param_types)
495+
# mod = LLVM.parent(llvm_f)
496+
# sfmt = String(fmt)
497+
# # generate IR
498+
# Builder(LLVM.Interop.JuliaContext()) do builder
499+
# entry = BasicBlock(llvm_f, "entry", LLVM.Interop.JuliaContext())
500+
# position!(builder, entry)
501+
502+
# str = globalstring_ptr!(builder, sfmt)
503+
504+
# # construct and fill args buffer
505+
# if isempty(argspec)
506+
# buffer = LLVM.PointerNull(T_pint8)
507+
# else
508+
# argtypes = LLVM.StructType("printf_args", LLVM.Interop.JuliaContext())
509+
# elements!(argtypes, param_types)
510+
511+
# args = alloca!(builder, argtypes)
512+
# for (i, param) in enumerate(parameters(llvm_f))
513+
# p = struct_gep!(builder, args, i-1)
514+
# store!(builder, param, p)
515+
# end
516+
517+
# buffer = bitcast!(builder, args, T_pint8)
518+
# end
519+
520+
# # invoke vprintf and return
521+
# vprintf_typ = LLVM.FunctionType(T_int32, [T_pint8, T_pint8])
522+
# vprintf = LLVM.Function(mod, "vprintf", vprintf_typ)
523+
# chars = call!(builder, vprintf, [str, buffer])
524+
525+
# ret!(builder, chars)
526+
# end
527+
528+
# arg_tuple = Expr(:tuple, arg_exprs...)
529+
# call_function(llvm_f, Int32, Tuple{arg_types...}, arg_tuple)
530+
# end
531+
532+
# Results in "InvalidIRError: compiling kernel
533+
# gpu_kernel_printf(... Reason: unsupported dynamic
534+
# function invocation"
535+
@generated function __printf(::Val{fmt}, items...) where {fmt}
536+
str = ""
537+
args = []
538+
539+
for i in 1:length(items)
540+
item = :(items[$i])
541+
T = items[i]
542+
if T <: Val
543+
item = QuoteNode(T.parameters[1])
544+
end
545+
push!(args, item)
546+
end
547+
sfmt = String(fmt)
548+
quote
549+
Printf.@printf($sfmt, $(args...))
550+
end
551+
end
552+
455553
###
456554
# Backends/Implementation
457555
###

src/backends/cpu.jl

+4
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,10 @@ end
208208
__print(items...)
209209
end
210210

211+
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(__printf), fmt, items...)
212+
__printf(fmt, items...)
213+
end
214+
211215
generate_overdubs(CPUCtx)
212216

213217
# Don't recurse into these functions

src/backends/cuda.jl

+8
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,14 @@ end
319319
CUDA._cuprint(args...)
320320
end
321321

322+
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__printf), fmt, args...)
323+
CUDA._cuprintf(Val(fmt), args...)
324+
end
325+
326+
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(__printf), ::Val{fmt}, args...) where fmt
327+
CUDA._cuprintf(Val(fmt), args...)
328+
end
329+
322330
###
323331
# GPU implementation of const memory
324332
###

test/print_test.jl

+29-3
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,51 @@ if has_cuda_gpu()
55
CUDA.allowscalar(false)
66
end
77

8+
struct Foo{A,B} end
9+
get_name(::Type{T}) where T<:Foo = "Foo"
10+
811
@kernel function kernel_print()
912
I = @index(Global)
1013
@print("Hello from thread ", I, "!\n")
1114
end
1215

16+
@kernel function kernel_printf()
17+
I = @index(Global)
18+
# @printf("Hello printf %s thread %d! type = %s.\n", "from", I, nameof(Foo))
19+
# @print("Hello printf from thread ", I, "!\n")
20+
# @printf("Hello printf %s thread %d! type = %s.\n", "from", I, string(nameof(Foo)))
21+
@printf("Hello printf %s thread %d! type = %s.\n", "from", I, "Foo")
22+
@printf("Hello printf %s thread %d! type = %s.\n", "from", I, get_name(Foo))
23+
end
24+
1325
function test_print(backend)
1426
kernel = kernel_print(backend, 4)
15-
kernel(ndrange=(4,))
27+
kernel(ndrange=(4,))
28+
end
29+
30+
function test_printf(backend)
31+
kernel = kernel_printf(backend, 4)
32+
kernel(ndrange=(4,))
1633
end
1734

1835
@testset "print test" begin
36+
wait(test_print(CPU()))
37+
@test true
38+
39+
wait(test_printf(CPU()))
40+
@test true
41+
1942
if has_cuda_gpu()
2043
wait(test_print(CUDADevice()))
2144
@test true
45+
wait(test_printf(CUDADevice()))
46+
@test true
2247
end
2348

24-
wait(test_print(CPU()))
49+
@print("Why this should work")
2550
@test true
2651

27-
@print("Why this should work")
52+
@printf("Why this should work")
2853
@test true
2954
end
55+

0 commit comments

Comments
 (0)