Skip to content

Commit e463acd

Browse files
fix[codegen]: fix removal of side effects in concat (#4644)
`concat` would remove side effects for zero-length arguments. Fix this by removing the fastpath; as the test case shows, this pattern is not common in user code. This commit also reverts the optimization introduced in a06f7df and replaces it with a tag on the `IRnode`, since otherwise the `~empty` intrinsic can show up in the IR at the time we lower to assembly (which is an error). --------- Co-authored-by: cyberthirst <[email protected]>
1 parent a06f7df commit e463acd

File tree

4 files changed

+57
-10
lines changed

4 files changed

+57
-10
lines changed

tests/functional/builtins/codegen/test_concat.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,48 @@ def hoo(x: bytes32, y: bytes32) -> Bytes[64]:
170170
print("Passed second concat tests")
171171

172172

173+
def test_concat_zero_length_side_effects(get_contract):
174+
code = """
175+
counter: public(uint256)
176+
177+
@external
178+
def test() -> Bytes[256]:
179+
a: Bytes[256] = concat(b"" if self.sideeffect() else b"", b"aaaa")
180+
return a
181+
182+
def sideeffect() -> bool:
183+
self.counter += 1
184+
return True
185+
"""
186+
187+
c = get_contract(code)
188+
189+
assert c.counter() == 0
190+
assert c.test() == b"aaaa"
191+
assert c.counter() == 1
192+
193+
194+
def test_concat_zero_length_side_effects2(get_contract):
195+
code = """
196+
counter: public(uint256)
197+
198+
@external
199+
def test() -> Bytes[256]:
200+
a: Bytes[256] = concat(b"" if self.sideeffect() else b"", b"")
201+
return a
202+
203+
def sideeffect() -> bool:
204+
self.counter += 1
205+
return True
206+
"""
207+
208+
c = get_contract(code)
209+
210+
assert c.counter() == 0
211+
assert c.test() == b""
212+
assert c.counter() == 1
213+
214+
173215
def test_small_output(get_contract):
174216
code = """
175217
@external

vyper/builtins/functions.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -561,10 +561,6 @@ def build_IR(self, expr, context):
561561
dst_data = add_ofst(bytes_data_ptr(dst), ofst)
562562

563563
if isinstance(arg.typ, _BytestringT):
564-
# Ignore empty strings
565-
if arg.typ.maxlen == 0:
566-
continue
567-
568564
with arg.cache_when_complex("arg") as (b1, arg):
569565
argdata = bytes_data_ptr(arg)
570566

vyper/codegen/expr.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,6 @@ def _make_bytelike(cls, context, typeclass, bytez):
147147
bytez_length = len(bytez)
148148
btype = typeclass(bytez_length)
149149

150-
if bytez_length == 0:
151-
# optimization: handled specially by make_byte_array_copier
152-
return IRnode.from_list("~empty", typ=btype, annotation=f"empty {btype}")
153-
154150
placeholder = context.new_internal_variable(btype)
155151
seq = []
156152
seq.append(["mstore", placeholder, bytez_length])
@@ -163,12 +159,14 @@ def _make_bytelike(cls, context, typeclass, bytez):
163159
]
164160
)
165161

166-
return IRnode.from_list(
162+
ret = IRnode.from_list(
167163
["seq"] + seq + [placeholder],
168164
typ=btype,
169165
location=MEMORY,
170166
annotation=f"Create {btype}: {bytez}",
171167
)
168+
ret.is_source_bytes_literal = True
169+
return ret
172170

173171
# True, False, None constants
174172
def parse_NameConstant(self):

vyper/codegen/ir_node.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from vyper.evm.address_space import AddrSpace
1111
from vyper.evm.opcodes import get_ir_opcodes
1212
from vyper.exceptions import CodegenPanic, CompilerPanic
13-
from vyper.semantics.types import VyperType
13+
from vyper.semantics.types import VyperType, _BytestringT
1414
from vyper.utils import VALID_IR_MACROS, ceil32
1515

1616
# Set default string representation for ints in IR output.
@@ -139,6 +139,10 @@ class IRnode:
139139
func_ir: Any
140140
common_ir: Any
141141

142+
# for bytestrings, if we have `is_source_bytes_literal`, we can perform
143+
# certain optimizations like eliding the copy.
144+
is_source_bytes_literal: bool = False
145+
142146
def __init__(
143147
self,
144148
value: Union[str, int],
@@ -367,6 +371,13 @@ def gas(self):
367371
def is_empty_intrinsic(self):
368372
if self.value == "~empty":
369373
return True
374+
if (
375+
self.is_source_bytes_literal
376+
and isinstance(self.typ, _BytestringT)
377+
and self.typ.maxlen == 0
378+
):
379+
# special optimization case for empty `b""` literal
380+
return True
370381
if self.value == "seq":
371382
return len(self.args) == 1 and self.args[0].is_empty_intrinsic
372383
return False

0 commit comments

Comments
 (0)