Skip to content

Commit b50a471

Browse files
committed
compiler: zero struct padding during map operations
Fixes #3358
1 parent 1f0bf9b commit b50a471

File tree

1 file changed

+62
-1
lines changed

1 file changed

+62
-1
lines changed

compiler/map.go

+62-1
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ func (b *builder) createMapLookup(keyType, valueType types.Type, m, key llvm.Val
8989
// growth.
9090
mapKeyAlloca, mapKeyPtr, mapKeySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
9191
b.CreateStore(key, mapKeyAlloca)
92+
b.zeroUndefBytes(b.getLLVMType(keyType), mapKeyAlloca)
9293
// Fetch the value from the hashmap.
9394
params := []llvm.Value{m, mapKeyPtr, mapValuePtr, mapValueSize}
9495
commaOkValue = b.createRuntimeCall("hashmapBinaryGet", params, "")
@@ -133,6 +134,7 @@ func (b *builder) createMapUpdate(keyType types.Type, m, key, value llvm.Value,
133134
// key can be compared with runtime.memequal
134135
keyAlloca, keyPtr, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
135136
b.CreateStore(key, keyAlloca)
137+
b.zeroUndefBytes(b.getLLVMType(keyType), keyAlloca)
136138
params := []llvm.Value{m, keyPtr, valuePtr}
137139
b.createRuntimeCall("hashmapBinarySet", params, "")
138140
b.emitLifetimeEnd(keyPtr, keySize)
@@ -161,6 +163,7 @@ func (b *builder) createMapDelete(keyType types.Type, m, key llvm.Value, pos tok
161163
} else if hashmapIsBinaryKey(keyType) {
162164
keyAlloca, keyPtr, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
163165
b.CreateStore(key, keyAlloca)
166+
b.zeroUndefBytes(b.getLLVMType(keyType), keyAlloca)
164167
params := []llvm.Value{m, keyPtr}
165168
b.createRuntimeCall("hashmapBinaryDelete", params, "")
166169
b.emitLifetimeEnd(keyPtr, keySize)
@@ -240,7 +243,8 @@ func (b *builder) createMapIteratorNext(rangeVal ssa.Value, llvmRangeVal, it llv
240243
}
241244

242245
// Returns true if this key type does not contain strings, interfaces etc., so
243-
// can be compared with runtime.memequal.
246+
// can be compared with runtime.memequal. Note that padding bytes are undef
247+
// and can alter two "equal" structs being equal when compared with memequal.
244248
func hashmapIsBinaryKey(keyType types.Type) bool {
245249
switch keyType := keyType.(type) {
246250
case *types.Basic:
@@ -263,3 +267,60 @@ func hashmapIsBinaryKey(keyType types.Type) bool {
263267
return false
264268
}
265269
}
270+
271+
func (b *builder) zeroUndefBytes(llvmType llvm.Type, ptr llvm.Value) error {
272+
// We know that hashmapIsBinaryKey is true, so we only have to handle those types that can show up there.
273+
// To zero all undefined bytes, we iterate over all the fields in the type. For each element, compute the
274+
// offset of that element. If it's Basic type, there are no internal padding bytes. For compound types, we recurse to ensure
275+
// we handle nested types. Next, we determine if there are any padding bytes before the next
276+
// element and zero those as well.
277+
278+
switch llvmType.TypeKind() {
279+
case llvm.IntegerTypeKind:
280+
// no padding bytes
281+
return nil
282+
case llvm.PointerTypeKind:
283+
// mo padding bytes
284+
return nil
285+
case llvm.ArrayTypeKind:
286+
llvmArrayType := llvmType
287+
llvmElemType := llvmType.ElementType()
288+
289+
for i := 0; i < llvmArrayType.ArrayLength(); i++ {
290+
idx := llvm.ConstInt(b.uintptrType, uint64(i), false)
291+
elemPtr := b.CreateGEP(llvmArrayType, ptr, []llvm.Value{idx}, "")
292+
293+
// recursively zero any padding bytes in this element
294+
b.zeroUndefBytes(llvmElemType, elemPtr)
295+
}
296+
297+
case llvm.StructTypeKind:
298+
llvmStructType := llvmType
299+
numFields := llvmStructType.StructElementTypesCount()
300+
llvmElementTypes := llvmStructType.StructElementTypes()
301+
302+
for i := 0; i < llvmStructType.StructElementTypesCount(); i++ {
303+
offset := b.targetData.ElementOffset(llvmStructType, i)
304+
llvmOffset := llvm.ConstInt(b.uintptrType, offset, false)
305+
elemPtr := b.CreateGEP(b.ctx.Int8Type(), ptr, []llvm.Value{llvmOffset}, "")
306+
307+
// zero any undef bytes in this field
308+
llvmElemType := llvmElementTypes[i]
309+
b.zeroUndefBytes(llvmElemType, elemPtr)
310+
311+
// zero any undef bytes before the next field, if any
312+
if i < numFields-1 {
313+
nextOffset := b.targetData.ElementOffset(llvmStructType, i+1)
314+
315+
if storeSize := b.targetData.TypeStoreSize(llvmElemType); (nextOffset - offset) != storeSize {
316+
n := llvm.ConstInt(b.uintptrType, (nextOffset-offset)-storeSize, false)
317+
llvmSize := llvm.ConstInt(b.uintptrType, storeSize, false)
318+
paddingStart := b.CreateGEP(b.ctx.Int8Type(), ptr, []llvm.Value{llvmSize}, "")
319+
b.createRuntimeCall("memzero", []llvm.Value{paddingStart, n}, "")
320+
}
321+
}
322+
}
323+
}
324+
325+
return nil
326+
}

0 commit comments

Comments
 (0)