Skip to content

Commit 1e40948

Browse files
committed
compiler: zero struct padding during map operations
Fixes #3358
1 parent cecb80b commit 1e40948

File tree

4 files changed

+219
-1
lines changed

4 files changed

+219
-1
lines changed

compiler/compiler_test.go

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ func TestCompiler(t *testing.T) {
4949
{"goroutine.go", "cortex-m-qemu", "tasks"},
5050
{"channel.go", "", ""},
5151
{"gc.go", "", ""},
52+
{"zeromap.go", "", ""},
5253
}
5354
if goMinor >= 20 {
5455
tests = append(tests, testCase{"go1.20.go", "", ""})

compiler/map.go

+73-1
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ func (b *builder) createMapLookup(keyType, valueType types.Type, m, key llvm.Val
8989
// growth.
9090
mapKeyAlloca, mapKeyPtr, mapKeySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
9191
b.CreateStore(key, mapKeyAlloca)
92+
b.zeroUndefBytes(b.getLLVMType(keyType), mapKeyAlloca)
9293
// Fetch the value from the hashmap.
9394
params := []llvm.Value{m, mapKeyPtr, mapValuePtr, mapValueSize}
9495
commaOkValue = b.createRuntimeCall("hashmapBinaryGet", params, "")
@@ -133,6 +134,7 @@ func (b *builder) createMapUpdate(keyType types.Type, m, key, value llvm.Value,
133134
// key can be compared with runtime.memequal
134135
keyAlloca, keyPtr, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
135136
b.CreateStore(key, keyAlloca)
137+
b.zeroUndefBytes(b.getLLVMType(keyType), keyAlloca)
136138
params := []llvm.Value{m, keyPtr, valuePtr}
137139
b.createRuntimeCall("hashmapBinarySet", params, "")
138140
b.emitLifetimeEnd(keyPtr, keySize)
@@ -161,6 +163,7 @@ func (b *builder) createMapDelete(keyType types.Type, m, key llvm.Value, pos tok
161163
} else if hashmapIsBinaryKey(keyType) {
162164
keyAlloca, keyPtr, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
163165
b.CreateStore(key, keyAlloca)
166+
b.zeroUndefBytes(b.getLLVMType(keyType), keyAlloca)
164167
params := []llvm.Value{m, keyPtr}
165168
b.createRuntimeCall("hashmapBinaryDelete", params, "")
166169
b.emitLifetimeEnd(keyPtr, keySize)
@@ -240,7 +243,8 @@ func (b *builder) createMapIteratorNext(rangeVal ssa.Value, llvmRangeVal, it llv
240243
}
241244

242245
// Returns true if this key type does not contain strings, interfaces etc., so
243-
// can be compared with runtime.memequal.
246+
// can be compared with runtime.memequal. Note that padding bytes are undef
247+
// and can alter two "equal" structs being equal when compared with memequal.
244248
func hashmapIsBinaryKey(keyType types.Type) bool {
245249
switch keyType := keyType.(type) {
246250
case *types.Basic:
@@ -263,3 +267,71 @@ func hashmapIsBinaryKey(keyType types.Type) bool {
263267
return false
264268
}
265269
}
270+
271+
func (b *builder) zeroUndefBytes(llvmType llvm.Type, ptr llvm.Value) error {
272+
// We know that hashmapIsBinaryKey is true, so we only have to handle those types that can show up there.
273+
// To zero all undefined bytes, we iterate over all the fields in the type. For each element, compute the
274+
// offset of that element. If it's Basic type, there are no internal padding bytes. For compound types, we recurse to ensure
275+
// we handle nested types. Next, we determine if there are any padding bytes before the next
276+
// element and zero those as well.
277+
278+
switch llvmType.TypeKind() {
279+
case llvm.IntegerTypeKind:
280+
// no padding bytes
281+
return nil
282+
case llvm.PointerTypeKind:
283+
// mo padding bytes
284+
return nil
285+
case llvm.ArrayTypeKind:
286+
llvmArrayType := llvmType
287+
llvmElemType := llvmType.ElementType()
288+
289+
for i := 0; i < llvmArrayType.ArrayLength(); i++ {
290+
idx := llvm.ConstInt(b.uintptrType, uint64(i), false)
291+
elemPtr := b.CreateGEP(llvmArrayType, ptr, []llvm.Value{idx}, "")
292+
293+
// zero any padding bytes in this element
294+
b.zeroUndefBytes(llvmElemType, elemPtr)
295+
}
296+
297+
case llvm.StructTypeKind:
298+
llvmStructType := llvmType
299+
numFields := llvmStructType.StructElementTypesCount()
300+
llvmElementTypes := llvmStructType.StructElementTypes()
301+
302+
for i := 0; i < llvmStructType.StructElementTypesCount(); i++ {
303+
offset := b.targetData.ElementOffset(llvmStructType, i)
304+
llvmOffset := llvm.ConstInt(b.uintptrType, offset, false)
305+
elemPtr := b.CreateGEP(b.ctx.Int8Type(), ptr, []llvm.Value{llvmOffset}, "")
306+
307+
// zero any padding bytes in this field
308+
llvmElemType := llvmElementTypes[i]
309+
b.zeroUndefBytes(llvmElemType, elemPtr)
310+
311+
// zero any padding bytes before the next field, if any
312+
if i < numFields-1 {
313+
nextOffset := b.targetData.ElementOffset(llvmStructType, i+1)
314+
if storeSize := b.targetData.TypeStoreSize(llvmElemType); offset+storeSize != nextOffset {
315+
n := llvm.ConstInt(b.uintptrType, nextOffset-(offset+storeSize), false)
316+
llvmSize := llvm.ConstInt(b.uintptrType, storeSize, false)
317+
paddingStart := b.CreateGEP(b.ctx.Int8Type(), elemPtr, []llvm.Value{llvmSize}, "")
318+
b.createRuntimeCall("memzero", []llvm.Value{paddingStart, n}, "")
319+
}
320+
}
321+
}
322+
323+
// zero any padding bytes between end of last field and end of struct
324+
allocSize := b.targetData.TypeAllocSize(llvmStructType)
325+
lastElemOffset := b.targetData.ElementOffset(llvmStructType, numFields-1)
326+
lastElemStore := b.targetData.TypeStoreSize(llvmElementTypes[numFields-1])
327+
328+
if allocSize != lastElemOffset+lastElemStore {
329+
n := llvm.ConstInt(b.uintptrType, allocSize-(lastElemOffset+lastElemStore), false)
330+
llvmSize := llvm.ConstInt(b.uintptrType, lastElemOffset+lastElemStore, false)
331+
paddingStart := b.CreateGEP(b.ctx.Int8Type(), ptr, []llvm.Value{llvmSize}, "")
332+
b.createRuntimeCall("memzero", []llvm.Value{paddingStart, n}, "")
333+
}
334+
}
335+
336+
return nil
337+
}

compiler/testdata/zeromap.go

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package main
2+
3+
type hasPadding struct {
4+
b1 bool
5+
i int
6+
b2 bool
7+
}
8+
9+
func main() {
10+
m := make(map[hasPadding]int)
11+
var s hasPadding
12+
13+
for i := 0; i < 10; i++ {
14+
s.b1 = i&1 == 0
15+
s.i = i
16+
s.b2 = i&1 == 1
17+
m[s]++
18+
}
19+
println(len(m))
20+
}

compiler/testdata/zeromap.ll

+125
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
; ModuleID = 'zeromap.go'
2+
source_filename = "zeromap.go"
3+
target datalayout = "e-m:e-p:32:32-p10:8:8-p20:8:8-i64:64-n32:64-S128-ni:1:10:20"
4+
target triple = "wasm32-unknown-wasi"
5+
6+
%main.hasPadding = type { i1, i32, i1 }
7+
8+
declare noalias nonnull ptr @runtime.alloc(i32, ptr, ptr) #0
9+
10+
declare void @runtime.trackPointer(ptr nocapture readonly, ptr, ptr) #0
11+
12+
; Function Attrs: nounwind
13+
define hidden void @main.init(ptr %context) unnamed_addr #1 {
14+
entry:
15+
ret void
16+
}
17+
18+
; Function Attrs: nounwind
19+
define hidden void @main.main(ptr %context) unnamed_addr #1 {
20+
entry:
21+
%hashmap.key6 = alloca %main.hasPadding, align 8
22+
%hashmap.value5 = alloca i32, align 4
23+
%hashmap.key = alloca %main.hasPadding, align 8
24+
%hashmap.value = alloca i32, align 4
25+
%s = alloca %main.hasPadding, align 8
26+
%stackalloc = alloca i8, align 1
27+
%0 = call ptr @runtime.hashmapMake(i32 12, i32 4, i32 8, i8 0, ptr undef) #3
28+
call void @runtime.trackPointer(ptr %0, ptr nonnull %stackalloc, ptr undef) #3
29+
store %main.hasPadding zeroinitializer, ptr %s, align 8
30+
call void @runtime.trackPointer(ptr nonnull %s, ptr nonnull %stackalloc, ptr undef) #3
31+
br label %for.loop
32+
33+
for.loop: ; preds = %store.next4, %entry
34+
%1 = phi i32 [ 0, %entry ], [ %17, %store.next4 ]
35+
%2 = icmp slt i32 %1, 10
36+
br i1 %2, label %for.body, label %for.done
37+
38+
for.body: ; preds = %for.loop
39+
br i1 false, label %store.throw, label %store.next
40+
41+
store.next: ; preds = %for.body
42+
%3 = and i32 %1, 1
43+
%4 = icmp eq i32 %3, 0
44+
store i1 %4, ptr %s, align 8
45+
br i1 false, label %store.throw1, label %store.next2
46+
47+
store.next2: ; preds = %store.next
48+
%5 = getelementptr inbounds %main.hasPadding, ptr %s, i32 0, i32 1
49+
store i32 %1, ptr %5, align 4
50+
br i1 false, label %store.throw3, label %store.next4
51+
52+
store.next4: ; preds = %store.next2
53+
%6 = getelementptr inbounds %main.hasPadding, ptr %s, i32 0, i32 2
54+
%7 = and i32 %1, 1
55+
%8 = icmp ne i32 %7, 0
56+
store i1 %8, ptr %6, align 8
57+
%9 = load %main.hasPadding, ptr %s, align 8
58+
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value)
59+
call void @llvm.lifetime.start.p0(i64 12, ptr nonnull %hashmap.key)
60+
store %main.hasPadding %9, ptr %hashmap.key, align 8
61+
%10 = getelementptr inbounds i8, ptr %hashmap.key, i32 1
62+
call void @runtime.memzero(ptr nonnull %10, i32 3, ptr undef) #3
63+
%11 = getelementptr inbounds i8, ptr %hashmap.key, i32 9
64+
call void @runtime.memzero(ptr nonnull %11, i32 3, ptr undef) #3
65+
%12 = call i1 @runtime.hashmapBinaryGet(ptr %0, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, i32 4, ptr undef) #3
66+
call void @llvm.lifetime.end.p0(i64 12, ptr nonnull %hashmap.key)
67+
%13 = load i32, ptr %hashmap.value, align 4
68+
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value)
69+
%14 = add i32 %13, 1
70+
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value5)
71+
store i32 %14, ptr %hashmap.value5, align 4
72+
call void @llvm.lifetime.start.p0(i64 12, ptr nonnull %hashmap.key6)
73+
store %main.hasPadding %9, ptr %hashmap.key6, align 8
74+
%15 = getelementptr inbounds i8, ptr %hashmap.key6, i32 1
75+
call void @runtime.memzero(ptr nonnull %15, i32 3, ptr undef) #3
76+
%16 = getelementptr inbounds i8, ptr %hashmap.key6, i32 9
77+
call void @runtime.memzero(ptr nonnull %16, i32 3, ptr undef) #3
78+
call void @runtime.hashmapBinarySet(ptr %0, ptr nonnull %hashmap.key6, ptr nonnull %hashmap.value5, ptr undef) #3
79+
call void @llvm.lifetime.end.p0(i64 12, ptr nonnull %hashmap.key6)
80+
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value5)
81+
%17 = add i32 %1, 1
82+
br label %for.loop
83+
84+
for.done: ; preds = %for.loop
85+
%len = call i32 @runtime.hashmapLen(ptr %0, ptr undef) #3
86+
call void @runtime.printint32(i32 %len, ptr undef) #3
87+
call void @runtime.printnl(ptr undef) #3
88+
ret void
89+
90+
store.throw: ; preds = %for.body
91+
unreachable
92+
93+
store.throw1: ; preds = %store.next
94+
unreachable
95+
96+
store.throw3: ; preds = %store.next2
97+
unreachable
98+
}
99+
100+
declare ptr @runtime.hashmapMake(i32, i32, i32, i8, ptr) #0
101+
102+
declare void @runtime.nilPanic(ptr) #0
103+
104+
; Function Attrs: argmemonly nocallback nofree nosync nounwind willreturn
105+
declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture) #2
106+
107+
declare void @runtime.memzero(ptr, i32, ptr) #0
108+
109+
declare i1 @runtime.hashmapBinaryGet(ptr dereferenceable_or_null(40), ptr, ptr, i32, ptr) #0
110+
111+
; Function Attrs: argmemonly nocallback nofree nosync nounwind willreturn
112+
declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture) #2
113+
114+
declare void @runtime.hashmapBinarySet(ptr dereferenceable_or_null(40), ptr, ptr, ptr) #0
115+
116+
declare i32 @runtime.hashmapLen(ptr dereferenceable_or_null(40), ptr) #0
117+
118+
declare void @runtime.printint32(i32, ptr) #0
119+
120+
declare void @runtime.printnl(ptr) #0
121+
122+
attributes #0 = { "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" }
123+
attributes #1 = { nounwind "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" }
124+
attributes #2 = { argmemonly nocallback nofree nosync nounwind willreturn }
125+
attributes #3 = { nounwind }

0 commit comments

Comments
 (0)