Skip to content

Commit 07a3da9

Browse files
authored
Emit select for out-of-range builtin var indices (#3023)
The behaviour for out-of-range dimension arguments to work-item functions is well defined in OpenCL C. For example, `get_global_size` must return 1 if its argument is larger than `get_work_dim() - 1`. Ensure the generated `extractelement` index never exceeds the vector size and return the correct out-of-range value (which is either 0 or 1 depending on the builtin). Fixes #2638 .
1 parent 4eea290 commit 07a3da9

File tree

3 files changed

+72
-2
lines changed

3 files changed

+72
-2
lines changed

lib/SPIRV/SPIRVUtil.cpp

+32-1
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,21 @@ static Type *parsePrimitiveType(LLVMContext &Ctx, StringRef Name) {
606606

607607
} // namespace SPIRV
608608

609+
namespace {
610+
611+
// Return the value for when the dimension index of a builtin is out of range.
612+
uint64_t getBuiltinOutOfRangeValue(StringRef VarName) {
613+
assert(VarName.starts_with("__spirv_BuiltIn"));
614+
return StringSwitch<uint64_t>(VarName)
615+
.EndsWith("GlobalSize", 1)
616+
.EndsWith("NumWorkgroups", 1)
617+
.EndsWith("WorkgroupSize", 1)
618+
.EndsWith("EnqueuedWorkgroupSize", 1)
619+
.Default(0);
620+
}
621+
622+
} // anonymous namespace
623+
609624
// The demangler node hierarchy doesn't use LLVM's RTTI helper functions (as it
610625
// also needs to live in libcxxabi). By specializing this implementation here,
611626
// we can add support for these functions.
@@ -2188,7 +2203,23 @@ bool lowerBuiltinCallsToVariables(Module *M) {
21882203
Value *NewValue = Builder.CreateLoad(GVType, BV);
21892204
LLVM_DEBUG(dbgs() << "Transform: " << *CI << " => " << *NewValue << '\n');
21902205
if (IsVec) {
2191-
NewValue = Builder.CreateExtractElement(NewValue, CI->getArgOperand(0));
2206+
auto *GVVecTy = cast<FixedVectorType>(GVType);
2207+
ConstantInt *Bound = Builder.getInt32(GVVecTy->getNumElements());
2208+
// Create a select on the index first, to avoid undefined behaviour
2209+
// due to exceeding the vector size by the extractelement.
2210+
Value *IndexCmp = Builder.CreateICmpULT(CI->getArgOperand(0), Bound);
2211+
Constant *ZeroIndex =
2212+
ConstantInt::get(CI->getArgOperand(0)->getType(), 0);
2213+
Value *ExtractIndex =
2214+
Builder.CreateSelect(IndexCmp, CI->getArgOperand(0), ZeroIndex);
2215+
2216+
// Extract from builtin variable.
2217+
NewValue = Builder.CreateExtractElement(NewValue, ExtractIndex);
2218+
2219+
// Clamp to out-of-range value.
2220+
Constant *OutOfRangeVal = ConstantInt::get(
2221+
F.getReturnType(), getBuiltinOutOfRangeValue(BuiltinVarName));
2222+
NewValue = Builder.CreateSelect(IndexCmp, NewValue, OutOfRangeVal);
21922223
LLVM_DEBUG(dbgs() << *NewValue << '\n');
21932224
}
21942225
NewValue->takeName(CI);

test/DebugInfo/builtin-get-global-id.ll

+2-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ entry:
3333
; CHECK-NEXT: [[I3:%[0-9]]] = insertelement <3 x i64> [[I1]], i64 [[I2]], i32 1, !dbg [[DBG]]
3434
; CHECK-NEXT: [[I4:%[0-9]]] = call spir_func i64 @_Z13get_global_idj(i32 2) #1, !dbg [[DBG]]
3535
; CHECK-NEXT: [[I5:%[0-9]]] = insertelement <3 x i64> [[I3]], i64 [[I4]], i32 2, !dbg [[DBG]]
36-
; CHECK-NEXT: %call = extractelement <3 x i64> [[I5]], i32 0, !dbg [[DBG]]
36+
; CHECK-NEXT: [[I6:%[0-9]]] = extractelement <3 x i64> [[I5]], i32 0, !dbg [[DBG]]
37+
; CHECK-NEXT: %call = select i1 true, i64 [[I6]], i64 0, !dbg [[DBG]]
3738
store i64 %call, ptr %gid, align 8, !dbg !11
3839
ret void, !dbg !12
3940
}

test/get_global_size.cl

+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// RUN: %clang_cc1 -triple spir64 -fdeclare-opencl-builtins -finclude-default-header -emit-llvm-bc %s -o %t.bc
2+
// RUN: llvm-spirv %t.bc -o %t.spv
3+
// RUN: spirv-val %t.spv
4+
// RUN: llvm-spirv %t.bc -spirv-text -o - | FileCheck %s
5+
6+
// Check that out of range dimension index values are handled according to the
7+
// OpenCL C specification.
8+
9+
kernel void ggs(global size_t *out, uint x) {
10+
// CHECK-DAG: Constant [[#]] [[#CONST64_1:]] 1 0
11+
// CHECK-DAG: Constant [[#]] [[#CONST3:]] 3
12+
// CHECK-DAG: Constant [[#]] [[#CONST0:]] 0
13+
// CHECK-DAG: ConstantTrue [[#]] [[#CONSTTRUE:]]
14+
// CHECK-DAG: ConstantFalse [[#]] [[#CONSTFALSE:]]
15+
16+
// CHECK: FunctionParameter [[#]] [[#PARAMOUT:]]
17+
// CHECK: FunctionParameter [[#]] [[#PARAMX:]]
18+
19+
// CHECK: Load [[#]] [[#LD0:]]
20+
// CHECK: CompositeExtract [[#]] [[#SCAL0:]] [[#LD0]] 0
21+
// CHECK: Select [[#]] [[#RES0:]] [[#CONSTTRUE]] [[#SCAL0]] [[#CONST64_1]]
22+
// CHECK: Store [[#]] [[#RES0]]
23+
out[0] = get_global_size(0);
24+
25+
// CHECK: Load [[#]] [[#LD1:]]
26+
// CHECK: CompositeExtract [[#]] [[#SCAL1:]] [[#LD1]] 0
27+
// CHECK: Select [[#]] [[#RES1:]] [[#CONSTFALSE]] [[#SCAL1]] [[#CONST64_1]]
28+
// CHECK: Store [[#]] [[#RES1]]
29+
out[1] = get_global_size(3);
30+
31+
// CHECK: Load [[#]] [[#LD2:]]
32+
// CHECK: ULessThan [[#]] [[#CMP:]] [[#PARAMX]] [[#CONST3]]
33+
// CHECK: Select [[#]] [[#SEL:]] [[#CMP]] [[#PARAMX]] [[#CONST0]]
34+
// CHECK: VectorExtractDynamic 2 [[#SCAL2:]] [[#LD2:]] [[#SEL]]
35+
// CHECK: Select [[#]] [[#RES2:]] [[#CMP]] [[#SCAL2]] [[#CONST64_1]]
36+
// CHECK: Store [[#]] [[#RES2]]
37+
out[2] = get_global_size(x);
38+
}

0 commit comments

Comments
 (0)