Skip to content

compiler: zero struct padding during map operations #3437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/compiler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ func TestCompiler(t *testing.T) {
{"goroutine.go", "cortex-m-qemu", "tasks"},
{"channel.go", "", ""},
{"gc.go", "", ""},
{"zeromap.go", "", ""},
}
if goMinor >= 20 {
tests = append(tests, testCase{"go1.20.go", "", ""})
Expand Down
79 changes: 78 additions & 1 deletion compiler/map.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ func (b *builder) createMapLookup(keyType, valueType types.Type, m, key llvm.Val
// growth.
mapKeyAlloca, mapKeyPtr, mapKeySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
b.CreateStore(key, mapKeyAlloca)
b.zeroUndefBytes(b.getLLVMType(keyType), mapKeyAlloca)
// Fetch the value from the hashmap.
params := []llvm.Value{m, mapKeyPtr, mapValuePtr, mapValueSize}
commaOkValue = b.createRuntimeCall("hashmapBinaryGet", params, "")
Expand Down Expand Up @@ -133,6 +134,7 @@ func (b *builder) createMapUpdate(keyType types.Type, m, key, value llvm.Value,
// key can be compared with runtime.memequal
keyAlloca, keyPtr, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
b.CreateStore(key, keyAlloca)
b.zeroUndefBytes(b.getLLVMType(keyType), keyAlloca)
params := []llvm.Value{m, keyPtr, valuePtr}
b.createRuntimeCall("hashmapBinarySet", params, "")
b.emitLifetimeEnd(keyPtr, keySize)
Expand Down Expand Up @@ -161,6 +163,7 @@ func (b *builder) createMapDelete(keyType types.Type, m, key llvm.Value, pos tok
} else if hashmapIsBinaryKey(keyType) {
keyAlloca, keyPtr, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
b.CreateStore(key, keyAlloca)
b.zeroUndefBytes(b.getLLVMType(keyType), keyAlloca)
params := []llvm.Value{m, keyPtr}
b.createRuntimeCall("hashmapBinaryDelete", params, "")
b.emitLifetimeEnd(keyPtr, keySize)
Expand Down Expand Up @@ -240,7 +243,8 @@ func (b *builder) createMapIteratorNext(rangeVal ssa.Value, llvmRangeVal, it llv
}

// Returns true if this key type does not contain strings, interfaces etc., so
// can be compared with runtime.memequal.
// can be compared with runtime.memequal. Note that padding bytes are undef
// and can alter two "equal" structs being equal when compared with memequal.
func hashmapIsBinaryKey(keyType types.Type) bool {
switch keyType := keyType.(type) {
case *types.Basic:
Expand All @@ -263,3 +267,76 @@ func hashmapIsBinaryKey(keyType types.Type) bool {
return false
}
}

func (b *builder) zeroUndefBytes(llvmType llvm.Type, ptr llvm.Value) error {
// We know that hashmapIsBinaryKey is true, so we only have to handle those types that can show up there.
// To zero all undefined bytes, we iterate over all the fields in the type. For each element, compute the
// offset of that element. If it's Basic type, there are no internal padding bytes. For compound types, we recurse to ensure
// we handle nested types. Next, we determine if there are any padding bytes before the next
// element and zero those as well.

zero := llvm.ConstInt(b.ctx.Int32Type(), 0, false)

switch llvmType.TypeKind() {
case llvm.IntegerTypeKind:
// no padding bytes
return nil
case llvm.PointerTypeKind:
// mo padding bytes
return nil
case llvm.ArrayTypeKind:
llvmArrayType := llvmType
llvmElemType := llvmType.ElementType()

for i := 0; i < llvmArrayType.ArrayLength(); i++ {
idx := llvm.ConstInt(b.uintptrType, uint64(i), false)
elemPtr := b.CreateInBoundsGEP(llvmArrayType, ptr, []llvm.Value{zero, idx}, "")

// zero any padding bytes in this element
b.zeroUndefBytes(llvmElemType, elemPtr)
}

case llvm.StructTypeKind:
llvmStructType := llvmType
numFields := llvmStructType.StructElementTypesCount()
llvmElementTypes := llvmStructType.StructElementTypes()

for i := 0; i < numFields; i++ {
idx := llvm.ConstInt(b.ctx.Int32Type(), uint64(i), false)
elemPtr := b.CreateInBoundsGEP(llvmStructType, ptr, []llvm.Value{zero, idx}, "")

// zero any padding bytes in this field
llvmElemType := llvmElementTypes[i]
b.zeroUndefBytes(llvmElemType, elemPtr)

// zero any padding bytes before the next field, if any
offset := b.targetData.ElementOffset(llvmStructType, i)
storeSize := b.targetData.TypeStoreSize(llvmElemType)
fieldEndOffset := offset + storeSize

var nextOffset uint64
if i < numFields-1 {
nextOffset = b.targetData.ElementOffset(llvmStructType, i+1)
} else {
// Last field? Next offset is the total size of the allcoate struct.
nextOffset = b.targetData.TypeAllocSize(llvmStructType)
}

if fieldEndOffset != nextOffset {
n := llvm.ConstInt(b.uintptrType, nextOffset-fieldEndOffset, false)
llvmStoreSize := llvm.ConstInt(b.uintptrType, storeSize, false)
gepPtr := elemPtr
if gepPtr.Type() != b.i8ptrType {
gepPtr = b.CreateBitCast(gepPtr, b.i8ptrType, "") // LLVM 14
}
paddingStart := b.CreateInBoundsGEP(b.ctx.Int8Type(), gepPtr, []llvm.Value{llvmStoreSize}, "")
if paddingStart.Type() != b.i8ptrType {
paddingStart = b.CreateBitCast(paddingStart, b.i8ptrType, "") // LLVM 14
}
b.createRuntimeCall("memzero", []llvm.Value{paddingStart, n}, "")
}
}
}

return nil
}
37 changes: 37 additions & 0 deletions compiler/testdata/zeromap.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package main

type hasPadding struct {
b1 bool
i int
b2 bool
}

type nestedPadding struct {
b bool
hasPadding
i int
}

//go:noinline
func testZeroGet(m map[hasPadding]int, s hasPadding) int {
return m[s]
}

//go:noinline
func testZeroSet(m map[hasPadding]int, s hasPadding) {
m[s] = 5
}

//go:noinline
func testZeroArrayGet(m map[[2]hasPadding]int, s [2]hasPadding) int {
return m[s]
}

//go:noinline
func testZeroArraySet(m map[[2]hasPadding]int, s [2]hasPadding) {
m[s] = 5
}

func main() {

}
170 changes: 170 additions & 0 deletions compiler/testdata/zeromap.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
; ModuleID = 'zeromap.go'
source_filename = "zeromap.go"
target datalayout = "e-m:e-p:32:32-p10:8:8-p20:8:8-i64:64-n32:64-S128-ni:1:10:20"
target triple = "wasm32-unknown-wasi"

%main.hasPadding = type { i1, i32, i1 }

declare noalias nonnull ptr @runtime.alloc(i32, ptr, ptr) #0

declare void @runtime.trackPointer(ptr nocapture readonly, ptr, ptr) #0

; Function Attrs: nounwind
define hidden void @main.init(ptr %context) unnamed_addr #1 {
entry:
ret void
}

; Function Attrs: noinline nounwind
define hidden i32 @main.testZeroGet(ptr dereferenceable_or_null(40) %m, i1 %s.b1, i32 %s.i, i1 %s.b2, ptr %context) unnamed_addr #2 {
entry:
%hashmap.key = alloca %main.hasPadding, align 8
%hashmap.value = alloca i32, align 4
%s = alloca %main.hasPadding, align 8
%0 = insertvalue %main.hasPadding zeroinitializer, i1 %s.b1, 0
%1 = insertvalue %main.hasPadding %0, i32 %s.i, 1
%2 = insertvalue %main.hasPadding %1, i1 %s.b2, 2
%stackalloc = alloca i8, align 1
store %main.hasPadding zeroinitializer, ptr %s, align 8
call void @runtime.trackPointer(ptr nonnull %s, ptr nonnull %stackalloc, ptr undef) #4
store %main.hasPadding %2, ptr %s, align 8
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value)
call void @llvm.lifetime.start.p0(i64 12, ptr nonnull %hashmap.key)
store %main.hasPadding %2, ptr %hashmap.key, align 8
%3 = getelementptr inbounds i8, ptr %hashmap.key, i32 1
call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4
%4 = getelementptr inbounds i8, ptr %hashmap.key, i32 9
call void @runtime.memzero(ptr nonnull %4, i32 3, ptr undef) #4
%5 = call i1 @runtime.hashmapBinaryGet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, i32 4, ptr undef) #4
call void @llvm.lifetime.end.p0(i64 12, ptr nonnull %hashmap.key)
%6 = load i32, ptr %hashmap.value, align 4
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value)
ret i32 %6
}

; Function Attrs: argmemonly nocallback nofree nosync nounwind willreturn
declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture) #3

declare void @runtime.memzero(ptr, i32, ptr) #0

declare i1 @runtime.hashmapBinaryGet(ptr dereferenceable_or_null(40), ptr, ptr, i32, ptr) #0

; Function Attrs: argmemonly nocallback nofree nosync nounwind willreturn
declare void @llvm.lifetime.end.p0(i64 immarg, ptr nocapture) #3

; Function Attrs: noinline nounwind
define hidden void @main.testZeroSet(ptr dereferenceable_or_null(40) %m, i1 %s.b1, i32 %s.i, i1 %s.b2, ptr %context) unnamed_addr #2 {
entry:
%hashmap.key = alloca %main.hasPadding, align 8
%hashmap.value = alloca i32, align 4
%s = alloca %main.hasPadding, align 8
%0 = insertvalue %main.hasPadding zeroinitializer, i1 %s.b1, 0
%1 = insertvalue %main.hasPadding %0, i32 %s.i, 1
%2 = insertvalue %main.hasPadding %1, i1 %s.b2, 2
%stackalloc = alloca i8, align 1
store %main.hasPadding zeroinitializer, ptr %s, align 8
call void @runtime.trackPointer(ptr nonnull %s, ptr nonnull %stackalloc, ptr undef) #4
store %main.hasPadding %2, ptr %s, align 8
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value)
store i32 5, ptr %hashmap.value, align 4
call void @llvm.lifetime.start.p0(i64 12, ptr nonnull %hashmap.key)
store %main.hasPadding %2, ptr %hashmap.key, align 8
%3 = getelementptr inbounds i8, ptr %hashmap.key, i32 1
call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4
%4 = getelementptr inbounds i8, ptr %hashmap.key, i32 9
call void @runtime.memzero(ptr nonnull %4, i32 3, ptr undef) #4
call void @runtime.hashmapBinarySet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, ptr undef) #4
call void @llvm.lifetime.end.p0(i64 12, ptr nonnull %hashmap.key)
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value)
ret void
}

declare void @runtime.hashmapBinarySet(ptr dereferenceable_or_null(40), ptr, ptr, ptr) #0

; Function Attrs: noinline nounwind
define hidden i32 @main.testZeroArrayGet(ptr dereferenceable_or_null(40) %m, [2 x %main.hasPadding] %s, ptr %context) unnamed_addr #2 {
entry:
%hashmap.key = alloca [2 x %main.hasPadding], align 8
%hashmap.value = alloca i32, align 4
%s1 = alloca [2 x %main.hasPadding], align 8
%stackalloc = alloca i8, align 1
store %main.hasPadding zeroinitializer, ptr %s1, align 8
%s1.repack2 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1
store %main.hasPadding zeroinitializer, ptr %s1.repack2, align 4
call void @runtime.trackPointer(ptr nonnull %s1, ptr nonnull %stackalloc, ptr undef) #4
%s.elt = extractvalue [2 x %main.hasPadding] %s, 0
store %main.hasPadding %s.elt, ptr %s1, align 8
%s1.repack3 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1
%s.elt4 = extractvalue [2 x %main.hasPadding] %s, 1
store %main.hasPadding %s.elt4, ptr %s1.repack3, align 4
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value)
call void @llvm.lifetime.start.p0(i64 24, ptr nonnull %hashmap.key)
%s.elt7 = extractvalue [2 x %main.hasPadding] %s, 0
store %main.hasPadding %s.elt7, ptr %hashmap.key, align 8
%hashmap.key.repack8 = getelementptr inbounds [2 x %main.hasPadding], ptr %hashmap.key, i32 0, i32 1
%s.elt9 = extractvalue [2 x %main.hasPadding] %s, 1
store %main.hasPadding %s.elt9, ptr %hashmap.key.repack8, align 4
%0 = getelementptr inbounds i8, ptr %hashmap.key, i32 1
call void @runtime.memzero(ptr nonnull %0, i32 3, ptr undef) #4
%1 = getelementptr inbounds i8, ptr %hashmap.key, i32 9
call void @runtime.memzero(ptr nonnull %1, i32 3, ptr undef) #4
%2 = getelementptr inbounds i8, ptr %hashmap.key, i32 13
call void @runtime.memzero(ptr nonnull %2, i32 3, ptr undef) #4
%3 = getelementptr inbounds i8, ptr %hashmap.key, i32 21
call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4
%4 = call i1 @runtime.hashmapBinaryGet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, i32 4, ptr undef) #4
call void @llvm.lifetime.end.p0(i64 24, ptr nonnull %hashmap.key)
%5 = load i32, ptr %hashmap.value, align 4
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value)
ret i32 %5
}

; Function Attrs: noinline nounwind
define hidden void @main.testZeroArraySet(ptr dereferenceable_or_null(40) %m, [2 x %main.hasPadding] %s, ptr %context) unnamed_addr #2 {
entry:
%hashmap.key = alloca [2 x %main.hasPadding], align 8
%hashmap.value = alloca i32, align 4
%s1 = alloca [2 x %main.hasPadding], align 8
%stackalloc = alloca i8, align 1
store %main.hasPadding zeroinitializer, ptr %s1, align 8
%s1.repack2 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1
store %main.hasPadding zeroinitializer, ptr %s1.repack2, align 4
call void @runtime.trackPointer(ptr nonnull %s1, ptr nonnull %stackalloc, ptr undef) #4
%s.elt = extractvalue [2 x %main.hasPadding] %s, 0
store %main.hasPadding %s.elt, ptr %s1, align 8
%s1.repack3 = getelementptr inbounds [2 x %main.hasPadding], ptr %s1, i32 0, i32 1
%s.elt4 = extractvalue [2 x %main.hasPadding] %s, 1
store %main.hasPadding %s.elt4, ptr %s1.repack3, align 4
call void @llvm.lifetime.start.p0(i64 4, ptr nonnull %hashmap.value)
store i32 5, ptr %hashmap.value, align 4
call void @llvm.lifetime.start.p0(i64 24, ptr nonnull %hashmap.key)
%s.elt7 = extractvalue [2 x %main.hasPadding] %s, 0
store %main.hasPadding %s.elt7, ptr %hashmap.key, align 8
%hashmap.key.repack8 = getelementptr inbounds [2 x %main.hasPadding], ptr %hashmap.key, i32 0, i32 1
%s.elt9 = extractvalue [2 x %main.hasPadding] %s, 1
store %main.hasPadding %s.elt9, ptr %hashmap.key.repack8, align 4
%0 = getelementptr inbounds i8, ptr %hashmap.key, i32 1
call void @runtime.memzero(ptr nonnull %0, i32 3, ptr undef) #4
%1 = getelementptr inbounds i8, ptr %hashmap.key, i32 9
call void @runtime.memzero(ptr nonnull %1, i32 3, ptr undef) #4
%2 = getelementptr inbounds i8, ptr %hashmap.key, i32 13
call void @runtime.memzero(ptr nonnull %2, i32 3, ptr undef) #4
%3 = getelementptr inbounds i8, ptr %hashmap.key, i32 21
call void @runtime.memzero(ptr nonnull %3, i32 3, ptr undef) #4
call void @runtime.hashmapBinarySet(ptr %m, ptr nonnull %hashmap.key, ptr nonnull %hashmap.value, ptr undef) #4
call void @llvm.lifetime.end.p0(i64 24, ptr nonnull %hashmap.key)
call void @llvm.lifetime.end.p0(i64 4, ptr nonnull %hashmap.value)
ret void
}

; Function Attrs: nounwind
define hidden void @main.main(ptr %context) unnamed_addr #1 {
entry:
ret void
}

attributes #0 = { "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" }
attributes #1 = { nounwind "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" }
attributes #2 = { noinline nounwind "target-features"="+bulk-memory,+nontrapping-fptoint,+sign-ext" }
attributes #3 = { argmemonly nocallback nofree nosync nounwind willreturn }
attributes #4 = { nounwind }