Skip to content

Commit 6369e7f

Browse files
dgryski authored and deadprogram committed
compiler: zero struct padding during map operations
Fixes #3358
1 parent 8b0acd9 commit 6369e7f

File tree

4 files changed

+286
-1
lines changed

4 files changed

+286
-1
lines changed

compiler/compiler_test.go

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ func TestCompiler(t *testing.T) {
4949
{"goroutine.go", "cortex-m-qemu", "tasks"},
5050
{"channel.go", "", ""},
5151
{"gc.go", "", ""},
52+
{"zeromap.go", "", ""},
5253
}
5354
if goMinor >= 20 {
5455
tests = append(tests, testCase{"go1.20.go", "", ""})

compiler/map.go

+78-1
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ func (b *builder) createMapLookup(keyType, valueType types.Type, m, key llvm.Val
8989
// growth.
9090
mapKeyAlloca, mapKeyPtr, mapKeySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
9191
b.CreateStore(key, mapKeyAlloca)
92+
b.zeroUndefBytes(b.getLLVMType(keyType), mapKeyAlloca)
9293
// Fetch the value from the hashmap.
9394
params := []llvm.Value{m, mapKeyPtr, mapValuePtr, mapValueSize}
9495
commaOkValue = b.createRuntimeCall("hashmapBinaryGet", params, "")
@@ -133,6 +134,7 @@ func (b *builder) createMapUpdate(keyType types.Type, m, key, value llvm.Value,
133134
// key can be compared with runtime.memequal
134135
keyAlloca, keyPtr, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
135136
b.CreateStore(key, keyAlloca)
137+
b.zeroUndefBytes(b.getLLVMType(keyType), keyAlloca)
136138
params := []llvm.Value{m, keyPtr, valuePtr}
137139
b.createRuntimeCall("hashmapBinarySet", params, "")
138140
b.emitLifetimeEnd(keyPtr, keySize)
@@ -161,6 +163,7 @@ func (b *builder) createMapDelete(keyType types.Type, m, key llvm.Value, pos tok
161163
} else if hashmapIsBinaryKey(keyType) {
162164
keyAlloca, keyPtr, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
163165
b.CreateStore(key, keyAlloca)
166+
b.zeroUndefBytes(b.getLLVMType(keyType), keyAlloca)
164167
params := []llvm.Value{m, keyPtr}
165168
b.createRuntimeCall("hashmapBinaryDelete", params, "")
166169
b.emitLifetimeEnd(keyPtr, keySize)
@@ -240,7 +243,8 @@ func (b *builder) createMapIteratorNext(rangeVal ssa.Value, llvmRangeVal, it llv
240243
}
241244

242245
// Returns true if this key type does not contain strings, interfaces etc., so
243-
// can be compared with runtime.memequal.
246+
// can be compared with runtime.memequal. Note that padding bytes are undef
247+
// and can make two otherwise equal structs compare as unequal with memequal.
244248
func hashmapIsBinaryKey(keyType types.Type) bool {
245249
switch keyType := keyType.(type) {
246250
case *types.Basic:
@@ -263,3 +267,76 @@ func hashmapIsBinaryKey(keyType types.Type) bool {
263267
return false
264268
}
265269
}
270+
271+
func (b *builder) zeroUndefBytes(llvmType llvm.Type, ptr llvm.Value) error {
272+
// We know that hashmapIsBinaryKey is true, so we only have to handle those types that can show up there.
273+
// To zero all undefined bytes, we iterate over all the fields in the type. For each element, compute the
274+
// offset of that element. If it's Basic type, there are no internal padding bytes. For compound types, we recurse to ensure
275+
// we handle nested types. Next, we determine if there are any padding bytes before the next
276+
// element and zero those as well.
277+
278+
zero := llvm.ConstInt(b.ctx.Int32Type(), 0, false)
279+
280+
switch llvmType.TypeKind() {
281+
case llvm.IntegerTypeKind:
282+
// no padding bytes
283+
return nil
284+
case llvm.PointerTypeKind:
285+
// mo padding bytes
286+
return nil
287+
case llvm.ArrayTypeKind:
288+
llvmArrayType := llvmType
289+
llvmElemType := llvmType.ElementType()
290+
291+
for i := 0; i < llvmArrayType.ArrayLength(); i++ {
292+
idx := llvm.ConstInt(b.uintptrType, uint64(i), false)
293+
elemPtr := b.CreateInBoundsGEP(llvmArrayType, ptr, []llvm.Value{zero, idx}, "")
294+
295+
// zero any padding bytes in this element
296+
b.zeroUndefBytes(llvmElemType, elemPtr)
297+
}
298+
299+
case llvm.StructTypeKind:
300+
llvmStructType := llvmType
301+
numFields := llvmStructType.StructElementTypesCount()
302+
llvmElementTypes := llvmStructType.StructElementTypes()
303+
304+
for i := 0; i < numFields; i++ {
305+
idx := llvm.ConstInt(b.ctx.Int32Type(), uint64(i), false)
306+
elemPtr := b.CreateInBoundsGEP(llvmStructType, ptr, []llvm.Value{zero, idx}, "")
307+
308+
// zero any padding bytes in this field
309+
llvmElemType := llvmElementTypes[i]
310+
b.zeroUndefBytes(llvmElemType, elemPtr)
311+
312+
// zero any padding bytes before the next field, if any
313+
offset := b.targetData.ElementOffset(llvmStructType, i)
314+
storeSize := b.targetData.TypeStoreSize(llvmElemType)
315+
fieldEndOffset := offset + storeSize
316+
317+
var nextOffset uint64
318+
if i < numFields-1 {
319+
nextOffset = b.targetData.ElementOffset(llvmStructType, i+1)
320+
} else {
321+
// Last field? Next offset is the total size of the allcoate struct.
322+
nextOffset = b.targetData.TypeAllocSize(llvmStructType)
323+
}
324+
325+
if fieldEndOffset != nextOffset {
326+
n := llvm.ConstInt(b.uintptrType, nextOffset-fieldEndOffset, false)
327+
llvmStoreSize := llvm.ConstInt(b.uintptrType, storeSize, false)
328+
gepPtr := elemPtr
329+
if gepPtr.Type() != b.i8ptrType {
330+
gepPtr = b.CreateBitCast(gepPtr, b.i8ptrType, "") // LLVM 14
331+
}
332+
paddingStart := b.CreateInBoundsGEP(b.ctx.Int8Type(), gepPtr, []llvm.Value{llvmStoreSize}, "")
333+
if paddingStart.Type() != b.i8ptrType {
334+
paddingStart = b.CreateBitCast(paddingStart, b.i8ptrType, "") // LLVM 14
335+
}
336+
b.createRuntimeCall("memzero", []llvm.Value{paddingStart, n}, "")
337+
}
338+
}
339+
}
340+
341+
return nil
342+
}

compiler/testdata/zeromap.go

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package main
2+
3+
type hasPadding struct {
4+
b1 bool
5+
i int
6+
b2 bool
7+
}
8+
9+
type nestedPadding struct {
10+
b bool
11+
hasPadding
12+
i int
13+
}
14+
15+
//go:noinline
16+
func testZeroGet(m map[hasPadding]int, s hasPadding) int {
17+
return m[s]
18+
}
19+
20+
//go:noinline
21+
func testZeroSet(m map[hasPadding]int, s hasPadding) {
22+
m[s] = 5
23+
}
24+
25+
//go:noinline
26+
func testZeroArrayGet(m map[[2]hasPadding]int, s [2]hasPadding) int {
27+
return m[s]
28+
}
29+
30+
//go:noinline
31+
func testZeroArraySet(m map[[2]hasPadding]int, s [2]hasPadding) {
32+
m[s] = 5
33+
}
34+
35+
func main() {
36+
37+
}

compiler/testdata/zeromap.ll

+170
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
; ModuleID = 'zeromap.go'
2+
source_filename = "zeromap.go"
3+
target datalayout = "e-m:e-p:32:32-p10:8:8-p20:8:8-i64:64-n32:64-S128-ni:1:10:20"
4+
target triple = "wasm32-unknown-wasi"
5+
6+
%main.hasPadding = type { i1, i32, i1 }
7+
8+
declare noalias nonnull ptr @runtime.alloc(i32, ptr, ptr) #0
9+
10+
declare void @runtime.trackPointer(ptr nocapture readonly, ptr, ptr) #0
11+
12+
; Function Attrs: nounwind
13+
define hidden void @main.init(ptr %context) unnamed_addr #1 {
14+
entry:
15+
ret void
16+
}
17+
18+
; Function Attrs: noinline nounwind
19+
define hidden i32 @main.testZeroGet(ptr dereferenceable_or_null(40) %m, i1 %s.b1, i32 %s.i, i1 %s.b2, ptr %context) unnamed_addr #2 {
20+
entry:
21+
%hashmap.key = alloca %main.hasPadding, align 8
22+
%hashmap.value = alloca i32, align 4
23+
%s = alloca %main.hasPadding, align 8
24+
%0 = insertvalue %main.hasPadding zeroinitializer, i1 %s.b1, 0
25+
%1 = insertvalue %main.hasPadding %0, i32 %s.i, 1
26+
%2 = insertvalue %main.hasPadding %1, i1 %s.b2, 2
27+
%stackalloc = alloca i8, align 1
28+
store %main.hasPadding zeroinitializer, ptr %s, align 8
29+
call void @runtime.trackPointer(ptr nonnull %s, ptr nonnull %stackalloc, ptr undef) #4
30+
store %main.hasPadding %2, ptr %s, align 8
31+
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value)
32+
call void @llvm.lifetime.start.p0(i64 12, ptr nonnull %hashmap.key)
33+
store %main.hasPadding %2, ptr %hashmap.key, align 8
34+
%3 = getelementptr inbounds i8, ptr %hashmap.key, i32 1
35+
call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4
36+
%4 = getelementptr inbounds i8, ptr %hashmap.key, i32 9
37+
call void @runtime.memzero(ptr nonnull %4, i32 3, ptr undef) #4
38+
%5 = call i1 @runtime.hashmapBinaryGet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, i32 4, ptr undef) #4
39+
call void @llvm.lifetime.end.p0(i64 12, ptr nonnull %hashmap.key)
40+
%6 = load i32, ptr %hashmap.value, align 4
41+
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value)
42+
ret i32 %6
43+
}
44+
45+
; Function Attrs: argmemonly nocallback nofree nosync nounwind willreturn
46+
declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture) #3
47+
48+
declare void @runtime.memzero(ptr, i32, ptr) #0
49+
50+
declare i1 @runtime.hashmapBinaryGet(ptr dereferenceable_or_null(40), ptr, ptr, i32, ptr) #0
51+
52+
; Function Attrs: argmemonly nocallback nofree nosync nounwind willreturn
53+
declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture) #3
54+
55+
; Function Attrs: noinline nounwind
56+
define hidden void @main.testZeroSet(ptr dereferenceable_or_null(40) %m, i1 %s.b1, i32 %s.i, i1 %s.b2, ptr %context) unnamed_addr #2 {
57+
entry:
58+
%hashmap.key = alloca %main.hasPadding, align 8
59+
%hashmap.value = alloca i32, align 4
60+
%s = alloca %main.hasPadding, align 8
61+
%0 = insertvalue %main.hasPadding zeroinitializer, i1 %s.b1, 0
62+
%1 = insertvalue %main.hasPadding %0, i32 %s.i, 1
63+
%2 = insertvalue %main.hasPadding %1, i1 %s.b2, 2
64+
%stackalloc = alloca i8, align 1
65+
store %main.hasPadding zeroinitializer, ptr %s, align 8
66+
call void @runtime.trackPointer(ptr nonnull %s, ptr nonnull %stackalloc, ptr undef) #4
67+
store %main.hasPadding %2, ptr %s, align 8
68+
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value)
69+
store i32 5, ptr %hashmap.value, align 4
70+
call void @llvm.lifetime.start.p0(i64 12, ptr nonnull %hashmap.key)
71+
store %main.hasPadding %2, ptr %hashmap.key, align 8
72+
%3 = getelementptr inbounds i8, ptr %hashmap.key, i32 1
73+
call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4
74+
%4 = getelementptr inbounds i8, ptr %hashmap.key, i32 9
75+
call void @runtime.memzero(ptr nonnull %4, i32 3, ptr undef) #4
76+
call void @runtime.hashmapBinarySet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, ptr undef) #4
77+
call void @llvm.lifetime.end.p0(i64 12, ptr nonnull %hashmap.key)
78+
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value)
79+
ret void
80+
}
81+
82+
declare void @runtime.hashmapBinarySet(ptr dereferenceable_or_null(40), ptr, ptr, ptr) #0
83+
84+
; Function Attrs: noinline nounwind
85+
define hidden i32 @main.testZeroArrayGet(ptr dereferenceable_or_null(40) %m, [2 x %main.hasPadding] %s, ptr %context) unnamed_addr #2 {
86+
entry:
87+
%hashmap.key = alloca [2 x %main.hasPadding], align 8
88+
%hashmap.value = alloca i32, align 4
89+
%s1 = alloca [2 x %main.hasPadding], align 8
90+
%stackalloc = alloca i8, align 1
91+
store %main.hasPadding zeroinitializer, ptr %s1, align 8
92+
%s1.repack2 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1
93+
store %main.hasPadding zeroinitializer, ptr %s1.repack2, align 4
94+
call void @runtime.trackPointer(ptr nonnull %s1, ptr nonnull %stackalloc, ptr undef) #4
95+
%s.elt = extractvalue [2 x %main.hasPadding] %s, 0
96+
store %main.hasPadding %s.elt, ptr %s1, align 8
97+
%s1.repack3 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1
98+
%s.elt4 = extractvalue [2 x %main.hasPadding] %s, 1
99+
store %main.hasPadding %s.elt4, ptr %s1.repack3, align 4
100+
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value)
101+
call void @llvm.lifetime.start.p0(i64 24, ptr nonnull %hashmap.key)
102+
%s.elt7 = extractvalue [2 x %main.hasPadding] %s, 0
103+
store %main.hasPadding %s.elt7, ptr %hashmap.key, align 8
104+
%hashmap.key.repack8 = getelementptr inbounds [2 x %main.hasPadding], ptr %hashmap.key, i32 0, i32 1
105+
%s.elt9 = extractvalue [2 x %main.hasPadding] %s, 1
106+
store %main.hasPadding %s.elt9, ptr %hashmap.key.repack8, align 4
107+
%0 = getelementptr inbounds i8, ptr %hashmap.key, i32 1
108+
call void @runtime.memzero(ptr nonnull %0, i32 3, ptr undef) #4
109+
%1 = getelementptr inbounds i8, ptr %hashmap.key, i32 9
110+
call void @runtime.memzero(ptr nonnull %1, i32 3, ptr undef) #4
111+
%2 = getelementptr inbounds i8, ptr %hashmap.key, i32 13
112+
call void @runtime.memzero(ptr nonnull %2, i32 3, ptr undef) #4
113+
%3 = getelementptr inbounds i8, ptr %hashmap.key, i32 21
114+
call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4
115+
%4 = call i1 @runtime.hashmapBinaryGet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, i32 4, ptr undef) #4
116+
call void @llvm.lifetime.end.p0(i64 24, ptr nonnull %hashmap.key)
117+
%5 = load i32, ptr %hashmap.value, align 4
118+
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value)
119+
ret i32 %5
120+
}
121+
122+
; Function Attrs: noinline nounwind
123+
define hidden void @main.testZeroArraySet(ptr dereferenceable_or_null(40) %m, [2 x %main.hasPadding] %s, ptr %context) unnamed_addr #2 {
124+
entry:
125+
%hashmap.key = alloca [2 x %main.hasPadding], align 8
126+
%hashmap.value = alloca i32, align 4
127+
%s1 = alloca [2 x %main.hasPadding], align 8
128+
%stackalloc = alloca i8, align 1
129+
store %main.hasPadding zeroinitializer, ptr %s1, align 8
130+
%s1.repack2 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1
131+
store %main.hasPadding zeroinitializer, ptr %s1.repack2, align 4
132+
call void @runtime.trackPointer(ptr nonnull %s1, ptr nonnull %stackalloc, ptr undef) #4
133+
%s.elt = extractvalue [2 x %main.hasPadding] %s, 0
134+
store %main.hasPadding %s.elt, ptr %s1, align 8
135+
%s1.repack3 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1
136+
%s.elt4 = extractvalue [2 x %main.hasPadding] %s, 1
137+
store %main.hasPadding %s.elt4, ptr %s1.repack3, align 4
138+
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value)
139+
store i32 5, ptr %hashmap.value, align 4
140+
call void @llvm.lifetime.start.p0(i64 24, ptr nonnull %hashmap.key)
141+
%s.elt7 = extractvalue [2 x %main.hasPadding] %s, 0
142+
store %main.hasPadding %s.elt7, ptr %hashmap.key, align 8
143+
%hashmap.key.repack8 = getelementptr inbounds [2 x %main.hasPadding], ptr %hashmap.key, i32 0, i32 1
144+
%s.elt9 = extractvalue [2 x %main.hasPadding] %s, 1
145+
store %main.hasPadding %s.elt9, ptr %hashmap.key.repack8, align 4
146+
%0 = getelementptr inbounds i8, ptr %hashmap.key, i32 1
147+
call void @runtime.memzero(ptr nonnull %0, i32 3, ptr undef) #4
148+
%1 = getelementptr inbounds i8, ptr %hashmap.key, i32 9
149+
call void @runtime.memzero(ptr nonnull %1, i32 3, ptr undef) #4
150+
%2 = getelementptr inbounds i8, ptr %hashmap.key, i32 13
151+
call void @runtime.memzero(ptr nonnull %2, i32 3, ptr undef) #4
152+
%3 = getelementptr inbounds i8, ptr %hashmap.key, i32 21
153+
call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4
154+
call void @runtime.hashmapBinarySet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, ptr undef) #4
155+
call void @llvm.lifetime.end.p0(i64 24, ptr nonnull %hashmap.key)
156+
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value)
157+
ret void
158+
}
159+
160+
; Function Attrs: nounwind
161+
define hidden void @main.main(ptr %context) unnamed_addr #1 {
162+
entry:
163+
ret void
164+
}
165+
166+
attributes #0 = { "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" }
167+
attributes #1 = { nounwind "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" }
168+
attributes #2 = { noinline nounwind "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" }
169+
attributes #3 = { argmemonly nocallback nofree nosync nounwind willreturn }
170+
attributes #4 = { nounwind }

0 commit comments

Comments
 (0)