Skip to content

Commit a06f7df

Browse files
fix[codegen]: fix bytes copying routines for 0-length case (#4649)
Per title. As in the referenced pull requests and security advisories, it is difficult, but not impossible, to construct an empty bytestring which has side effects. In this commit, don't fast-path out the side effects for zero-length bytestrings.

References:
- #4644
- GHSA-qhr6-mgqr-mchm
- #4645
- GHSA-3vcg-j39x-cwfm

---------

Co-authored-by: cyberthirst <[email protected]>
1 parent 1fdfd69 commit a06f7df

File tree

6 files changed

+103
-25
lines changed

6 files changed

+103
-25
lines changed

tests/functional/builtins/codegen/test_empty.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22

3-
from vyper.exceptions import InstantiationException, TypeMismatch
3+
from vyper.compiler import compile_code
4+
from vyper.exceptions import ArrayIndexException, InstantiationException, TypeMismatch
45

56

67
@pytest.mark.parametrize(
@@ -692,14 +693,39 @@ def foo():
692693

693694

694695
@pytest.mark.parametrize(
695-
"contract",
696+
"code, exc",
696697
[
697-
"""
698+
(
699+
"""
698700
@external
699701
def test():
700702
a: uint256 = empty(HashMap[uint256, uint256])[0]
701-
"""
703+
""",
704+
InstantiationException,
705+
),
706+
(
707+
"""
708+
@external
709+
def test():
710+
a: Bytes[32] = empty(Bytes[0])
711+
""",
712+
ArrayIndexException,
713+
),
702714
],
703715
)
704-
def test_invalid_types(contract, get_contract, assert_compile_failed):
705-
assert_compile_failed(lambda: get_contract(contract), InstantiationException)
716+
def test_invalid_types(code, exc):
717+
with pytest.raises(exc):
718+
compile_code(code)
719+
720+
721+
@pytest.mark.parametrize("empty_bytes", ["x''", "b''"])
722+
@pytest.mark.parametrize("size", [1] + [i for i in range(1 * 32, 5 * 32, 32)])
723+
def test_empty_Bytes(get_contract, size, empty_bytes):
724+
code = f"""
725+
@external
726+
def foo() -> bool:
727+
b: Bytes[{size}] = empty(Bytes[{size}])
728+
return b == {empty_bytes}
729+
"""
730+
c = get_contract(code)
731+
assert c.foo() is True

tests/functional/builtins/codegen/test_slice.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,3 +633,37 @@ def foo() -> Bytes[96]:
633633

634634
c = get_contract(slice_code)
635635
assert c.foo() == b"defghijklmnopqrstuvwxyz123456789"
636+
637+
638+
def test_slice_empty_bytes32(get_contract):
639+
code = """
640+
@external
641+
def bar() -> Bytes[32]:
642+
return slice(empty(bytes32), 0, 32)
643+
"""
644+
c = get_contract(code)
645+
assert c.bar() == b"\x00" * 32
646+
647+
648+
def test_slice_empty_Bytes32_0(get_contract, tx_failed):
649+
code = """
650+
@external
651+
def bar(length: uint256) -> Bytes[32]:
652+
# use variable length otherwise it gets optimized to
653+
# StaticAssertionException
654+
return slice(empty(Bytes[32]), 0, length)
655+
"""
656+
c = get_contract(code)
657+
with tx_failed():
658+
_ = c.bar(1)
659+
660+
661+
def test_slice_empty_Bytes32_1(get_contract):
662+
code = """
663+
@external
664+
def bar() -> Bytes[32]:
665+
length: uint256 = 0
666+
return slice(empty(Bytes[32]), 0, length)
667+
"""
668+
c = get_contract(code)
669+
assert c.bar() == b""

vyper/builtins/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ def build_IR(self, expr, args, kwargs, context):
356356
if src.location is None:
357357
# it's not a pointer; force it to be one since
358358
# copy_bytes works on pointers.
359-
assert is_bytes32, src
359+
assert is_bytes32 or src.is_empty_intrinsic, src
360360
src = ensure_in_memory(src, context)
361361
elif potential_overlap(src, start) or potential_overlap(src, length):
362362
src = create_memory_copy(src, context)

vyper/codegen/core.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -185,12 +185,12 @@ def make_byte_array_copier(dst, src):
185185

186186
_check_assign_bytes(dst, src)
187187

188-
# TODO: remove this branch, copy_bytes and get_bytearray_length should handle
189-
if src.value == "~empty" or src.typ.maxlen == 0:
190-
# set length word to 0.
191-
return STORE(dst, 0)
192-
193188
with src.cache_when_complex("src") as (b1, src):
189+
if src.typ.maxlen == 0 or src.is_empty_intrinsic:
190+
# set dst length to zero, preserving side effects of `src`.
191+
ret = STORE(dst, 0)
192+
return b1.resolve(ret)
193+
194194
if src.typ.maxlen <= 32 and not copy_opcode_available(dst, src):
195195
# if there is no batch copy opcode available,
196196
# it's cheaper to run two load/stores instead of copy_bytes
@@ -281,7 +281,7 @@ def _dynarray_make_setter(dst, src, hi=None):
281281
assert isinstance(src.typ, DArrayT)
282282
assert isinstance(dst.typ, DArrayT)
283283

284-
if src.value == "~empty":
284+
if src.is_empty_intrinsic:
285285
return IRnode.from_list(STORE(dst, 0))
286286

287287
# copy contents of src dynarray to dst.
@@ -390,10 +390,12 @@ def copy_bytes(dst, src, length, length_bound):
390390

391391
# correctness: do not clobber dst
392392
if length_bound == 0:
393-
return IRnode.from_list(["seq"], annotation=annotation)
393+
ret = IRnode.from_list(["seq"], annotation=annotation)
394+
return b1.resolve(b2.resolve(b3.resolve(ret)))
394395
# performance: if we know that length is 0, do not copy anything
395396
if length.value == 0:
396-
return IRnode.from_list(["seq"], annotation=annotation)
397+
ret = IRnode.from_list(["seq"], annotation=annotation)
398+
return b1.resolve(b2.resolve(b3.resolve(ret)))
397399

398400
assert src.is_pointer and dst.is_pointer
399401

@@ -457,7 +459,7 @@ def get_bytearray_length(arg):
457459

458460
# TODO: it would be nice to merge the implementations of get_bytearray_length and
459461
# get_dynarray_count
460-
if arg.value == "~empty":
462+
if arg.is_empty_intrinsic:
461463
return IRnode.from_list(0, typ=typ)
462464

463465
return IRnode.from_list(LOAD(arg), typ=typ)
@@ -472,7 +474,7 @@ def get_dyn_array_count(arg):
472474
if arg.value == "multi":
473475
return IRnode.from_list(len(arg.args), typ=typ)
474476

475-
if arg.value == "~empty":
477+
if arg.is_empty_intrinsic:
476478
# empty(DynArray[...])
477479
return IRnode.from_list(0, typ=typ)
478480

@@ -586,7 +588,7 @@ def _get_element_ptr_tuplelike(parent, key, hi=None):
586588
annotation = None
587589

588590
# generated by empty() + make_setter
589-
if parent.value == "~empty":
591+
if parent.is_empty_intrinsic:
590592
return IRnode.from_list("~empty", typ=subtype)
591593

592594
if parent.value == "multi":
@@ -635,7 +637,7 @@ def _get_element_ptr_array(parent, key, array_bounds_check):
635637

636638
subtype = parent.typ.value_type
637639

638-
if parent.value == "~empty":
640+
if parent.is_empty_intrinsic:
639641
if array_bounds_check:
640642
# this case was previously missing a bounds check. codegen
641643
# is a bit complicated when bounds check is required, so
@@ -765,7 +767,7 @@ def unwrap_location(orig):
765767
return IRnode.from_list(LOAD(orig), typ=orig.typ)
766768
else:
767769
# CMC 2022-03-24 TODO refactor so this branch can be removed
768-
if orig.value == "~empty":
770+
if orig.is_empty_intrinsic:
769771
# must be word type
770772
return IRnode.from_list(0, typ=orig.typ)
771773
return orig
@@ -821,13 +823,14 @@ def dummy_node_for_type(typ):
821823
return IRnode("fake_node", typ=typ)
822824

823825

824-
def _check_assign_bytes(left, right):
825-
if right.typ.maxlen > left.typ.maxlen: # pragma: nocover
826+
def _check_assign_bytes(left, right): # pragma: nocover
827+
if right.typ.maxlen > left.typ.maxlen:
826828
raise TypeMismatch(f"Cannot cast from {right.typ} to {left.typ}")
827829

828830
# stricter check for zeroing a byte array.
829831
# TODO: these should be TypeCheckFailure instead of TypeMismatch
830-
if right.value == "~empty" and right.typ.maxlen != left.typ.maxlen: # pragma: nocover
832+
rlen = right.typ.maxlen
833+
if right.is_empty_intrinsic and rlen != 0 and rlen != left.typ.maxlen:
831834
raise TypeMismatch(f"Cannot cast from empty({right.typ}) to {left.typ}")
832835

833836

@@ -858,7 +861,7 @@ def FAIL(): # pragma: no cover
858861
FAIL()
859862

860863
# stricter check for zeroing
861-
if right.value == "~empty" and right.typ.count != left.typ.count: # pragma: nocover
864+
if right.is_empty_intrinsic and right.typ.count != left.typ.count: # pragma: nocover
862865
raise TypeCheckFailure(
863866
f"Bad type for clearing bytes: expected {left.typ} but got {right.typ}"
864867
)
@@ -1109,7 +1112,7 @@ def copy_opcode_available(left, right):
11091112

11101113

11111114
def _complex_make_setter(left, right, hi=None):
1112-
if right.value == "~empty" and left.location == MEMORY:
1115+
if right.is_empty_intrinsic and left.location == MEMORY:
11131116
# optimized memzero
11141117
return mzero(left, left.typ.memory_bytes_required)
11151118

vyper/codegen/expr.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,11 @@ def parse_HexBytes(self):
146146
def _make_bytelike(cls, context, typeclass, bytez):
147147
bytez_length = len(bytez)
148148
btype = typeclass(bytez_length)
149+
150+
if bytez_length == 0:
151+
# optimization: handled specially by make_byte_array_copier
152+
return IRnode.from_list("~empty", typ=btype, annotation=f"empty {btype}")
153+
149154
placeholder = context.new_internal_variable(btype)
150155
seq = []
151156
seq.append(["mstore", placeholder, bytez_length])
@@ -157,6 +162,7 @@ def _make_bytelike(cls, context, typeclass, bytez):
157162
bytes_to_int((bytez + b"\x00" * 31)[i : i + 32]),
158163
]
159164
)
165+
160166
return IRnode.from_list(
161167
["seq"] + seq + [placeholder],
162168
typ=btype,

vyper/codegen/ir_node.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,14 @@ def __deepcopy__(self, memo):
363363
def gas(self):
364364
return self._gas + self.add_gas_estimate
365365

366+
@property
367+
def is_empty_intrinsic(self):
368+
if self.value == "~empty":
369+
return True
370+
if self.value == "seq":
371+
return len(self.args) == 1 and self.args[0].is_empty_intrinsic
372+
return False
373+
366374
# the IR should be cached and/or evaluated exactly once
367375
@property
368376
def is_complex_ir(self):
@@ -376,6 +384,7 @@ def is_complex_ir(self):
376384
isinstance(self.value, str)
377385
and (self.value.lower() in VALID_IR_MACROS or self.value.upper() in get_ir_opcodes())
378386
and self.value.lower() not in do_not_cache
387+
and not self.is_empty_intrinsic
379388
)
380389

381390
# set an error message and push down to its children that don't have error_msg set

0 commit comments

Comments
 (0)