@@ -89,6 +89,7 @@ func (b *builder) createMapLookup(keyType, valueType types.Type, m, key llvm.Val
8989 // growth.
9090 mapKeyAlloca , mapKeyPtr , mapKeySize := b .createTemporaryAlloca (key .Type (), "hashmap.key" )
9191 b .CreateStore (key , mapKeyAlloca )
92+ b .zeroUndefBytes (keyType , mapKeyAlloca )
9293 // Fetch the value from the hashmap.
9394 params := []llvm.Value {m , mapKeyPtr , mapValuePtr , mapValueSize }
9495 commaOkValue = b .createRuntimeCall ("hashmapBinaryGet" , params , "" )
@@ -133,6 +134,7 @@ func (b *builder) createMapUpdate(keyType types.Type, m, key, value llvm.Value,
133134 // key can be compared with runtime.memequal
134135 keyAlloca , keyPtr , keySize := b .createTemporaryAlloca (key .Type (), "hashmap.key" )
135136 b .CreateStore (key , keyAlloca )
137+ b .zeroUndefBytes (keyType , keyAlloca )
136138 params := []llvm.Value {m , keyPtr , valuePtr }
137139 b .createRuntimeCall ("hashmapBinarySet" , params , "" )
138140 b .emitLifetimeEnd (keyPtr , keySize )
@@ -161,6 +163,7 @@ func (b *builder) createMapDelete(keyType types.Type, m, key llvm.Value, pos tok
161163 } else if hashmapIsBinaryKey (keyType ) {
162164 keyAlloca , keyPtr , keySize := b .createTemporaryAlloca (key .Type (), "hashmap.key" )
163165 b .CreateStore (key , keyAlloca )
166+ b .zeroUndefBytes (keyType , keyAlloca )
164167 params := []llvm.Value {m , keyPtr }
165168 b .createRuntimeCall ("hashmapBinaryDelete" , params , "" )
166169 b .emitLifetimeEnd (keyPtr , keySize )
@@ -240,7 +243,8 @@ func (b *builder) createMapIteratorNext(rangeVal ssa.Value, llvmRangeVal, it llv
240243}
241244
242245// Returns true if this key type does not contain strings, interfaces etc., so
243- // can be compared with runtime.memequal.
246+ // can be compared with runtime.memequal. Note that padding bytes are udnef
247+ // and can alter two "equal" structs being equal when compared with memequal.
244248func hashmapIsBinaryKey (keyType types.Type ) bool {
245249 switch keyType := keyType .(type ) {
246250 case * types.Basic :
@@ -263,3 +267,87 @@ func hashmapIsBinaryKey(keyType types.Type) bool {
263267 return false
264268 }
265269}
270+
271+ func (b * builder ) zeroUndefBytes (typ types.Type , ptr llvm.Value ) error {
272+
273+ // We know that hashmapIsBinaryKey is true, so we only have to handle those types
274+
275+ // assert that we're a struct type
276+
277+ // iterate over all the fields
278+ // zero bytes if there is padding for this field
279+ // if fieled type is a compound type then recurse => need to support arrays
280+
281+ switch typ := typ .Underlying ().(type ) {
282+
283+ case * types.Basic :
284+ // no padding bytes
285+ return nil
286+ case * types.Pointer :
287+ // mo padding bytes
288+ return nil
289+
290+ case * types.Named :
291+ // zero underlying type
292+ return b .zeroUndefBytes (typ .Underlying (), ptr )
293+
294+ case * types.Array :
295+ llvmArrayType := b .getLLVMType (typ )
296+ llvmElemType := b .getLLVMType (typ .Elem ())
297+
298+ base := ptr
299+
300+ // two kinds of padding to deal with in an array:
301+ // - padding within each element
302+ // - padding between elements
303+
304+ for i := uint64 (0 ); i < uint64 (typ .Len ()); i ++ {
305+ // for each element, first clear any undef bytes in the element itself
306+ offset := b .targetData .ElementOffset (llvmArrayType , int (i ))
307+ llvmOffset := llvm .ConstInt (b .uintptrType , offset , false )
308+ ptr := b .CreateGEP (b .ctx .Int8Type (), base , []llvm.Value {llvmOffset }, "" )
309+
310+ // recursively zero any padding bytes in this element
311+ b .zeroUndefBytes (typ .Elem (), ptr )
312+
313+ // check for padding between elements
314+ // TODO(dgryski): typeSizeEqualStoreSize ?
315+ if allocSize , storeSize := b .targetData .TypeAllocSize (llvmElemType ), b .targetData .TypeStoreSize (llvmElemType ); allocSize != storeSize {
316+ n := llvm .ConstInt (b .uintptrType , allocSize - storeSize , false )
317+ llvmSize := llvm .ConstInt (b .uintptrType , storeSize , false )
318+ paddingStart := b .CreateGEP (b .ctx .Int8Type (), ptr , []llvm.Value {llvmSize }, "" )
319+ b .createRuntimeCall ("memzero" , []llvm.Value {paddingStart , n }, "" )
320+ }
321+ }
322+
323+ case * types.Struct :
324+ llvmStructType := b .getLLVMType (typ )
325+
326+ base := ptr
327+
328+ for i := 0 ; i < typ .NumFields (); i ++ {
329+ offset := b .targetData .ElementOffset (llvmStructType , int (i ))
330+
331+ llvmOffset := llvm .ConstInt (b .uintptrType , offset , false )
332+ ptr := b .CreateGEP (b .ctx .Int8Type (), base , []llvm.Value {llvmOffset }, "" )
333+
334+ // zero any undef bytes in this field
335+ fieldType := typ .Field (i ).Type ()
336+ b .zeroUndefBytes (fieldType , ptr )
337+
338+ // zero any undef bytes before the next field, if any
339+ if i < typ .NumFields ()- 1 {
340+ nextOffset := b .targetData .ElementOffset (llvmStructType , i + 1 )
341+ llvmElemType := b .getLLVMType (fieldType )
342+ if storeSize := b .targetData .TypeStoreSize (llvmElemType ); (nextOffset - offset ) != storeSize {
343+ n := llvm .ConstInt (b .uintptrType , (nextOffset - offset )- storeSize , false )
344+ llvmSize := llvm .ConstInt (b .uintptrType , storeSize , false )
345+ paddingStart := b .CreateGEP (b .ctx .Int8Type (), ptr , []llvm.Value {llvmSize }, "" )
346+ b .createRuntimeCall ("memzero" , []llvm.Value {paddingStart , n }, "" )
347+ }
348+ }
349+ }
350+ }
351+
352+ return nil
353+ }
0 commit comments