Skip to content

Commit

Permalink
llvm: Control debug output via PNL_LLVM_DEBUG env var (#3092)
Browse files Browse the repository at this point in the history
Use per jit-engine printf pointer to not crash printf calls on GPU.
Replace 'overrride_debug' with tags and allow dynamic control via PNL_LLVM_DEBUG env var.
  • Loading branch information
jvesely authored Oct 31, 2024
2 parents bfe4461 + 2998389 commit 817ca23
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 112 deletions.
6 changes: 3 additions & 3 deletions psyneulink/core/llvm/builder_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ def module_count():


_BUILTIN_PREFIX = "__pnl_builtin_"
_builtin_intrinsics = frozenset(('pow', 'log', 'exp', 'tanh', 'coth', 'csch',
'sin', 'cos',
_builtin_intrinsics = frozenset(('pow', 'log', 'exp', 'tanh', 'coth', 'csch', 'sin', 'cos',
'is_close_float', 'is_close_double',
'mt_rand_init', 'philox_rand_init'))
'mt_rand_init', 'philox_rand_init',
'get_printf_address'))


class _node_assembly():
Expand Down
39 changes: 31 additions & 8 deletions psyneulink/core/llvm/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@

# ********************************************* PNL LLVM builtins **************************************************************

from llvmlite import ir

from ctypes import util
from llvmlite import ir, binding
import sys

from . import helpers
from .builder_context import LLVMBuilderContext, _BUILTIN_PREFIX
Expand Down Expand Up @@ -469,18 +470,39 @@ def setup_pnl_intrinsics(ctx):
ir.Function(ctx.module, single_intr_ty, name=_BUILTIN_PREFIX + "log")
ir.Function(ctx.module, double_intr_ty, name=_BUILTIN_PREFIX + "pow")

# printf address
ir.Function(ctx.module, ir.FunctionType(ir.IntType(64), []), name=_BUILTIN_PREFIX + "get_printf_address")


def _generate_intrinsic_wrapper(module, name, ret, args):
intrinsic = module.declare_intrinsic("llvm." + name, list(set(args)))

def _generate_new_function(module, name, ret, args):
func_ty = ir.FunctionType(ret, args)
function = ir.Function(module, func_ty, name=_BUILTIN_PREFIX + name)
function = ir.Function(module, func_ty, name=name)
function.attributes.add('alwaysinline')
block = function.append_basic_block(name="entry")
builder = ir.IRBuilder(block)
builder.debug_metadata = LLVMBuilderContext.get_debug_location(function, None)
builder.ret(builder.call(intrinsic, function.args))

return builder

def _generate_intrinsic_wrapper(module, name, ret, args):
intrinsic = module.declare_intrinsic("llvm." + name, list(set(args)))

builder = _generate_new_function(module, _BUILTIN_PREFIX + name, ret, args)
intrinsic_result = builder.call(intrinsic, builder.block.function.args)
builder.ret(intrinsic_result)

def _generate_get_printf_address(module):
builder = _generate_new_function(module, _BUILTIN_PREFIX + "get_printf_address", ir.IntType(64), [])

libc_name = "msvcrt" if sys.platform == "win32" else "c"
libc = util.find_library(libc_name)
assert libc is not None, "Standard libc library not found"

binding.load_library_permanently(libc)
# Address will be none if the symbol is not found
printf_address = binding.address_of_symbol("printf")
assert printf_address is not None, "'printf' symbol not found in {}".format(libc)

builder.ret(ir.IntType(64)(printf_address))

def _generate_cpu_builtins_module(_float_ty):
"""Generate function wrappers for log, exp, and pow intrinsics."""
Expand All @@ -489,6 +511,7 @@ def _generate_cpu_builtins_module(_float_ty):
_generate_intrinsic_wrapper(module, intrinsic, _float_ty, [_float_ty])

_generate_intrinsic_wrapper(module, "pow", _float_ty, [_float_ty, _float_ty])
_generate_get_printf_address(module)
return module


Expand Down
2 changes: 1 addition & 1 deletion psyneulink/core/llvm/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
* "stat" -- prints code generation and compilation statistics
* "time_stat" -- print compilation and code generation times
* "comp_node_debug" -- print intermediate results after execution composition node wrapper.
* "print_values" -- Enabled printfs in llvm code (from ctx printf helper)
* "printf_tags" -- Enabledprintfs in compiled caode with the specififed tags
Compilation modifiers:
* "const_data" -- hardcode initial output values into generated code,
Expand Down
66 changes: 32 additions & 34 deletions psyneulink/core/llvm/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@

# ********************************************* PNL LLVM helpers **************************************************************

import ast
from contextlib import contextmanager
from ctypes import util
import warnings
import sys

from llvmlite import ir
import llvmlite.binding as llvm


from .debug import debug_env
Expand Down Expand Up @@ -401,55 +399,55 @@ def call_elementwise_operation(ctx, builder, x, operation, output_ptr):
for (inp_ptr, out_ptr) in recursive_iterate_arrays(ctx, builder, x, output_ptr):
builder.store(operation(ctx, builder, builder.load(inp_ptr)), out_ptr)

def printf(builder, fmt, *args, override_debug=False):
if "print_values" not in debug_env and not override_debug:
return

#FIXME: Fix builtin printf and use that instead of this
libc_name = "msvcrt" if sys.platform == "win32" else "c"
libc = util.find_library(libc_name)
assert libc is not None, "Standard libc library not found"

llvm.load_library_permanently(libc)
# Address will be none if the symbol is not found
printf_address = llvm.address_of_symbol("printf")
assert printf_address is not None, "'printf' symbol not found in {}".format(libc)
def printf(ctx, builder, fmt, *args, tags:set):

# Direct pointer constants don't work
printf_ty = ir.FunctionType(ir.IntType(32), [ir.IntType(8).as_pointer()], var_arg=True)
printf = builder.inttoptr(ir.IntType(64)(printf_address), printf_ty.as_pointer())
ir_module = builder.function.module
fmt += "\0"
tags = frozenset(tags)
user_tags = frozenset(ast.literal_eval(debug_env.get("printf_tags", "[]")))
if "all" not in user_tags and "always" not in tags and not tags.intersection(user_tags):
return

# Set up the formatting string as global symbol
int8 = ir.IntType(8)
fmt_data = bytearray(fmt.encode("utf8"))
fmt_data = bytearray((fmt + "\0").encode("utf8"))
fmt_ty = ir.ArrayType(int8, len(fmt_data))
global_fmt = ir.GlobalVariable(ir_module, fmt_ty,

ir_module = builder.function.module
global_fmt = ir.GlobalVariable(ir_module,
fmt_ty,
name="printf_fmt_" + str(len(ir_module.globals)))
global_fmt.linkage = "internal"
global_fmt.global_constant = True
global_fmt.initializer = fmt_ty(fmt_data)

fmt_ptr = builder.gep(global_fmt, [ir.IntType(32)(0), ir.IntType(32)(0)])
conv_args = [builder.fpext(a, ir.DoubleType()) if is_floating_point(a) else a for a in args]
builder.call(printf, [fmt_ptr] + conv_args)
printf_ty = ir.FunctionType(ir.IntType(32), [ir.IntType(8).as_pointer()], var_arg=True)
get_printf_addr_f = ctx.get_builtin("get_printf_address", [])
printf_address = builder.call(get_printf_addr_f, [])

printf_is_not_null = builder.icmp_unsigned("!=", printf_address, printf_address.type(0))
with builder.if_then(printf_is_not_null, likely=True):
printf_f = builder.inttoptr(printf_address, printf_ty.as_pointer())

fmt_ptr = builder.gep(global_fmt, [ir.IntType(32)(0), ir.IntType(32)(0)])
conv_args = [builder.fpext(a, ir.DoubleType()) if is_floating_point(a) else a for a in args]
builder.call(printf_f, [fmt_ptr] + conv_args)


def printf_float_array(builder, array, prefix="", suffix="\n", override_debug=False):
printf(builder, prefix, override_debug=override_debug)
def printf_float_array(ctx, builder, array, prefix="", suffix="\n", *, tags:set):
printf(ctx, builder, prefix, tags=tags)

with array_ptr_loop(builder, array, "print_array_loop") as (b1, i):
printf(b1, "%lf ", b1.load(b1.gep(array, [i.type(0), i])), override_debug=override_debug)
printf(ctx, b1, "%lf ", b1.load(b1.gep(array, [i.type(0), i])), tags=tags)

printf(builder, suffix, override_debug=override_debug)
printf(ctx, builder, suffix, tags=tags)


def printf_float_matrix(builder, matrix, prefix="", suffix="\n", override_debug=False):
printf(builder, prefix, override_debug=override_debug)
def printf_float_matrix(ctx, builder, matrix, prefix="", suffix="\n", *, tags:set):
printf(ctx, builder, prefix, tags=tags)
with array_ptr_loop(builder, matrix, "print_row_loop") as (b1, i):
row = b1.gep(matrix, [i.type(0), i])
printf_float_array(b1, row, suffix="\n", override_debug=override_debug)
printf(builder, suffix, override_debug=override_debug)
printf_float_array(ctx, b1, row, suffix="\n", tags=tags)

printf(ctx, builder, suffix, tags=tags)


class ConditionGenerator:
Expand Down
2 changes: 2 additions & 0 deletions psyneulink/core/llvm/jit_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,12 +280,14 @@ def _init(self):
self._jit_engine.set_object_cache(self._object_cache)


# FIXME: Get device side printf pointer
_ptx_builtin_source = """
__device__ {type} __pnl_builtin_sin({type} a) {{ return sin(a); }}
__device__ {type} __pnl_builtin_cos({type} a) {{ return cos(a); }}
__device__ {type} __pnl_builtin_log({type} a) {{ return log(a); }}
__device__ {type} __pnl_builtin_exp({type} a) {{ return exp(a); }}
__device__ {type} __pnl_builtin_pow({type} a, {type} b) {{ return pow(a, b); }}
__device__ int64_t __pnl_builtin_get_printf_address() {{ return 0; }}
"""


Expand Down
36 changes: 22 additions & 14 deletions psyneulink/library/compositions/compiledoptimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def step(self, ctx):
t = builder.gep(optim_struct, [zero, ctx.int32_ty(self._T_NUM)])

# get methods needed
pow = ctx.import_llvm_function("__pnl_builtin_pow")
sqrt = ctx.get_builtin("sqrt", [ctx.float_ty])
pow_f = ctx.import_llvm_function("__pnl_builtin_pow")
sqrt_f = ctx.get_builtin("sqrt", [ctx.float_ty])

lr = ctx.float_ty(self.lr)
eps = ctx.float_ty(self.eps)
Expand All @@ -120,13 +120,12 @@ def step(self, ctx):
builder.store(builder.fadd(builder.load(t), one_float), t)
t_val = builder.load(t)
# 1.5) calculate values to be used later (based on incremented t)
b1_pow = builder.call(pow, [b1, t_val])
b2_pow = builder.call(pow, [b2, t_val])
b1_pow = builder.call(pow_f, [b1, t_val])
b2_pow = builder.call(pow_f, [b2, t_val])
one_minus_b1_pow = builder.fsub(one_float, b1_pow)
one_minus_b2_pow = builder.fsub(one_float, b2_pow)

pnlvm.helpers.printf(
builder, f"%f b1_pow_sub %f\nb2 pow sub %f\n",t_val, one_minus_b1_pow, one_minus_b2_pow)
pnlvm.helpers.printf(ctx, builder, f"%f b1_pow_sub %f\nb2 pow sub %f\n",t_val, one_minus_b1_pow, one_minus_b2_pow, tags={"torch"})

# 2) update first moments
for idx, proj in enumerate(self._pytorch_model.projection_wrappers):
Expand All @@ -144,7 +143,11 @@ def step(self, ctx):
# m_t = m_t + (1-b1)*g_t
gen_inject_mat_add(ctx, builder, m_t_ptr, tmp_val, m_t_ptr)

pnlvm.helpers.printf_float_matrix(builder, m_t_ptr, prefix=f"mt val: {proj.sender._mechanism} -> {proj.receiver._mechanism}\n", override_debug=False)
pnlvm.helpers.printf_float_matrix(ctx,
builder,
m_t_ptr,
prefix=f"mt val: {proj.sender._mechanism} -> {proj.receiver._mechanism}\n",
tags={"torch"})
# 3) update second moments
for idx, proj in enumerate(self._pytorch_model.projection_wrappers):
proj_idx_ir = ctx.int32_ty(idx)
Expand Down Expand Up @@ -180,24 +183,29 @@ def step(self, ctx):
delta_w_ptr = builder.gep(
delta_w, [zero, proj_idx_ir])

pnlvm.helpers.printf_float_matrix(builder, delta_w_ptr, prefix=f"grad val: {proj.sender._mechanism} -> {proj.receiver._mechanism}\n", override_debug=False)
pnlvm.helpers.printf_float_matrix(ctx,
builder,
delta_w_ptr,
prefix=f"grad val: {proj.sender._mechanism} -> {proj.receiver._mechanism}\n",
tags={"torch"})

# this is messy - #TODO - cleanup this
weights_llvmlite = proj._extract_llvm_matrix(ctx, builder, state, params)
dim_x, dim_y = proj.matrix.shape

weight_row = None
pnlvm.helpers.printf(builder, "biascorr2 %.20f\n", one_minus_b2_pow, override_debug=False)
pnlvm.helpers.printf(ctx, builder, "biascorr2 %.20f\n", one_minus_b2_pow, tags={"torch"})
with pnlvm.helpers.for_loop_zero_inc(builder, ctx.int32_ty(dim_x), "optimizer_w_upd_outer") as (b1, weight_row):
weight_column = None
with pnlvm.helpers.for_loop_zero_inc(b1, ctx.int32_ty(dim_y), "optimizer_w_upd_inner") as (b2, weight_column):
# sqrt(v_t) + eps
v_t_value = b2.load(b2.gep(v_t_ptr, [zero, weight_row, weight_column]))
value = b2.call(sqrt, [v_t_value])
denom = b2.call(sqrt, [one_minus_b2_pow])
value = b2.call(sqrt_f, [v_t_value])
denom = b2.call(sqrt_f, [one_minus_b2_pow])
value = b2.fdiv(value, denom)
value = b2.fadd(value, eps)
pnlvm.helpers.printf(builder, "val %.20f\n", value, override_debug=False)
pnlvm.helpers.printf(ctx, builder, "val %.20f\n", value, tags={"torch"})

# alpha_t * m_t
m_t_value = b2.load(b2.gep(
m_t_ptr, [zero, weight_row, weight_column]))
Expand All @@ -213,9 +221,9 @@ def step(self, ctx):
value = b2.fadd(b2.load(old_weight_ptr), value)
b2.store(value, old_weight_ptr)

pnlvm.helpers.printf(b1, "\n", override_debug=False)
pnlvm.helpers.printf(ctx, b1, "\n", tags={"torch"})

pnlvm.helpers.printf(builder, f"\t\t\tOPTIM DONE UPDATE\n",override_debug=False)
pnlvm.helpers.printf(ctx, builder, f"\t\t\tOPTIM DONE UPDATE\n", tags={"torch"})

builder.ret_void()

Expand Down
Loading

0 comments on commit 817ca23

Please sign in to comment.