Skip to content

Commit 073d3e5

Browse files
committed
compiler: zero struct padding during map operations
Fixes #3358
1 parent 1f0bf9b commit 073d3e5

File tree

1 file changed

+89
-1
lines changed

1 file changed

+89
-1
lines changed

compiler/map.go

+89-1
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ func (b *builder) createMapLookup(keyType, valueType types.Type, m, key llvm.Val
8989
// growth.
9090
mapKeyAlloca, mapKeyPtr, mapKeySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
9191
b.CreateStore(key, mapKeyAlloca)
92+
b.zeroUndefBytes(keyType, mapKeyAlloca)
9293
// Fetch the value from the hashmap.
9394
params := []llvm.Value{m, mapKeyPtr, mapValuePtr, mapValueSize}
9495
commaOkValue = b.createRuntimeCall("hashmapBinaryGet", params, "")
@@ -133,6 +134,7 @@ func (b *builder) createMapUpdate(keyType types.Type, m, key, value llvm.Value,
133134
// key can be compared with runtime.memequal
134135
keyAlloca, keyPtr, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
135136
b.CreateStore(key, keyAlloca)
137+
b.zeroUndefBytes(keyType, keyAlloca)
136138
params := []llvm.Value{m, keyPtr, valuePtr}
137139
b.createRuntimeCall("hashmapBinarySet", params, "")
138140
b.emitLifetimeEnd(keyPtr, keySize)
@@ -161,6 +163,7 @@ func (b *builder) createMapDelete(keyType types.Type, m, key llvm.Value, pos tok
161163
} else if hashmapIsBinaryKey(keyType) {
162164
keyAlloca, keyPtr, keySize := b.createTemporaryAlloca(key.Type(), "hashmap.key")
163165
b.CreateStore(key, keyAlloca)
166+
b.zeroUndefBytes(keyType, keyAlloca)
164167
params := []llvm.Value{m, keyPtr}
165168
b.createRuntimeCall("hashmapBinaryDelete", params, "")
166169
b.emitLifetimeEnd(keyPtr, keySize)
@@ -240,7 +243,8 @@ func (b *builder) createMapIteratorNext(rangeVal ssa.Value, llvmRangeVal, it llv
240243
}
241244

242245
// Returns true if this key type does not contain strings, interfaces etc., so
243-
// can be compared with runtime.memequal.
246+
// can be compared with runtime.memequal. Note that padding bytes are udnef
247+
// and can alter two "equal" structs being equal when compared with memequal.
244248
func hashmapIsBinaryKey(keyType types.Type) bool {
245249
switch keyType := keyType.(type) {
246250
case *types.Basic:
@@ -263,3 +267,87 @@ func hashmapIsBinaryKey(keyType types.Type) bool {
263267
return false
264268
}
265269
}
270+
271+
func (b *builder) zeroUndefBytes(typ types.Type, ptr llvm.Value) error {
272+
273+
// We know that hashmapIsBinaryKey is true, so we only have to handle those types
274+
275+
// assert that we're a struct type
276+
277+
// iterate over all the fields
278+
// zero bytes if there is padding for this field
279+
// if fieled type is a compound type then recurse => need to support arrays
280+
281+
switch typ := typ.Underlying().(type) {
282+
283+
case *types.Basic:
284+
// no padding bytes
285+
return nil
286+
case *types.Pointer:
287+
// mo padding bytes
288+
return nil
289+
290+
case *types.Named:
291+
// zero underlying type
292+
return b.zeroUndefBytes(typ.Underlying(), ptr)
293+
294+
case *types.Array:
295+
llvmArrayType := b.getLLVMType(typ)
296+
llvmElemType := b.getLLVMType(typ.Elem())
297+
298+
base := ptr
299+
300+
// two kinds of padding to deal with in an array:
301+
// - padding within each element
302+
// - padding between elements
303+
304+
for i := uint64(0); i < uint64(typ.Len()); i++ {
305+
// for each element, first clear any undef bytes in the element itself
306+
offset := b.targetData.ElementOffset(llvmArrayType, int(i))
307+
llvmOffset := llvm.ConstInt(b.uintptrType, offset, false)
308+
ptr := b.CreateGEP(b.ctx.Int8Type(), base, []llvm.Value{llvmOffset}, "")
309+
310+
// recursively zero any padding bytes in this element
311+
b.zeroUndefBytes(typ.Elem(), ptr)
312+
313+
// check for padding between elements
314+
// TODO(dgryski): typeSizeEqualStoreSize ?
315+
if allocSize, storeSize := b.targetData.TypeAllocSize(llvmElemType), b.targetData.TypeStoreSize(llvmElemType); allocSize != storeSize {
316+
n := llvm.ConstInt(b.uintptrType, allocSize-storeSize, false)
317+
llvmSize := llvm.ConstInt(b.uintptrType, storeSize, false)
318+
paddingStart := b.CreateGEP(b.ctx.Int8Type(), ptr, []llvm.Value{llvmSize}, "")
319+
b.createRuntimeCall("memzero", []llvm.Value{paddingStart, n}, "")
320+
}
321+
}
322+
323+
case *types.Struct:
324+
llvmStructType := b.getLLVMType(typ)
325+
326+
base := ptr
327+
328+
for i := 0; i < typ.NumFields(); i++ {
329+
offset := b.targetData.ElementOffset(llvmStructType, int(i))
330+
331+
llvmOffset := llvm.ConstInt(b.uintptrType, offset, false)
332+
ptr := b.CreateGEP(b.ctx.Int8Type(), base, []llvm.Value{llvmOffset}, "")
333+
334+
// zero any undef bytes in this field
335+
fieldType := typ.Field(i).Type()
336+
b.zeroUndefBytes(fieldType, ptr)
337+
338+
// zero any undef bytes before the next field, if any
339+
if i < typ.NumFields()-1 {
340+
nextOffset := b.targetData.ElementOffset(llvmStructType, i+1)
341+
llvmElemType := b.getLLVMType(fieldType)
342+
if storeSize := b.targetData.TypeStoreSize(llvmElemType); (nextOffset - offset) != storeSize {
343+
n := llvm.ConstInt(b.uintptrType, (nextOffset-offset)-storeSize, false)
344+
llvmSize := llvm.ConstInt(b.uintptrType, storeSize, false)
345+
paddingStart := b.CreateGEP(b.ctx.Int8Type(), ptr, []llvm.Value{llvmSize}, "")
346+
b.createRuntimeCall("memzero", []llvm.Value{paddingStart, n}, "")
347+
}
348+
}
349+
}
350+
}
351+
352+
return nil
353+
}

0 commit comments

Comments
 (0)