Skip to content

Commit 608bb8a

Browse files
committed
This is a third commit of Issue #2498
1 parent 6cd3a7e commit 608bb8a

File tree

1 file changed

+50
-15
lines changed

1 file changed

+50
-15
lines changed

vyper/builtins/functions.py

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -213,17 +213,32 @@ def _try_fold(self, node):
213213

214214
elif isinstance(target_typedef, AddressT):
215215
if isinstance(value, vy_ast.Hex):
216-
if len(value.value) != 42: # 0x + 40 hex chars
217-
raise InvalidLiteral("Address must be exactly 20 bytes", node.args[0])
218-
result = value.value
216+
if not value.value.startswith("0x"):
217+
raise InvalidLiteral("Address must start with 0x", node.args[0])
218+
try:
219+
addr_bytes = value.bytes_value
220+
if len(addr_bytes) != 20:
221+
raise InvalidLiteral("Address must be exactly 20 bytes", node.args[0])
222+
result = value.value
223+
except ValueError:
224+
raise InvalidLiteral("Invalid hex literal for address", node.args[0])
225+
elif isinstance(value, vy_ast.Int):
226+
# Convert integer to address (right-padded with zeros)
227+
addr_bytes = value.value.to_bytes(32, "big")[-20:]
228+
result = f"0x{addr_bytes.hex()}"
219229
else:
220230
raise UnfoldableNode
221231
return vy_ast.Hex.from_node(node, value=result)
222232

223233
elif isinstance(target_typedef, IntegerT):
224-
if not isinstance(value, (vy_ast.Int, vy_ast.Decimal)):
234+
if isinstance(value, vy_ast.Decimal):
235+
if value.value % 1 != 0:
236+
raise InvalidLiteral("Cannot truncate decimal when converting to integer", node.args[0])
237+
result = int(value.value)
238+
elif isinstance(value, vy_ast.Int):
239+
result = value.value
240+
else:
225241
raise UnfoldableNode
226-
result = int(value.value)
227242
lo, hi = target_typedef.ast_bounds
228243
if result < lo or result > hi:
229244
raise InvalidLiteral(
@@ -233,12 +248,30 @@ def _try_fold(self, node):
233248

234249
elif isinstance(target_typedef, BytesM_T):
235250
if isinstance(value, vy_ast.Hex):
236-
if len(value.bytes_value) != target_typedef.length:
251+
if not value.value.startswith("0x"):
252+
raise InvalidLiteral("Bytes must start with 0x", node.args[0])
253+
try:
254+
bytes_value = value.bytes_value
255+
if len(bytes_value) > target_typedef.length:
256+
raise InvalidLiteral(
257+
f"Bytes literal too long for {target_typedef}", node.args[0]
258+
)
259+
# Left-pad with zeros if shorter than target length
260+
if len(bytes_value) < target_typedef.length:
261+
bytes_value = bytes_value.rjust(target_typedef.length, b"\x00")
262+
result = f"0x{bytes_value.hex()}"
263+
except ValueError:
264+
raise InvalidLiteral("Invalid hex literal for bytes", node.args[0])
265+
elif isinstance(value, vy_ast.Bytes):
266+
bytes_value = value.value
267+
if len(bytes_value) > target_typedef.length:
237268
raise InvalidLiteral(
238-
f"Expected {target_typedef.length} bytes, got {len(value.bytes_value)}",
239-
node.args[0]
269+
f"Bytes literal too long for {target_typedef}", node.args[0]
240270
)
241-
result = value.value
271+
# Left-pad with zeros if shorter than target length
272+
if len(bytes_value) < target_typedef.length:
273+
bytes_value = bytes_value.rjust(target_typedef.length, b"\x00")
274+
result = f"0x{bytes_value.hex()}"
242275
else:
243276
raise UnfoldableNode
244277
return vy_ast.Hex.from_node(node, value=result)
@@ -247,15 +280,13 @@ def _try_fold(self, node):
247280

248281
def fetch_call_return(self, node):
249282
_, target_typedef = self.infer_arg_types(node)
250-
251-
# note: more type conversion validation happens in convert.py
252283
return target_typedef.typedef
253284

254285
def infer_arg_types(self, node, expected_return_typ=None):
255286
self._validate_arg_types(node)
256287
possible = sorted(
257288
get_possible_types_from_node(node.args[0]),
258-
key=lambda t: (t.typ, getattr(t, "bits", 0))
289+
key=lambda t: (str(t.typ), getattr(t, "bits", 0))
259290
)
260291
value_type = possible[0]
261292
target_type = type_from_annotation(node.args[1])
@@ -419,7 +450,7 @@ def infer_arg_types(self, node, expected_return_typ=None):
419450
self._validate_arg_types(node)
420451
possible = sorted(
421452
get_possible_types_from_node(node.args[0]),
422-
key=lambda t: (t.typ, getattr(t, "bits", 0))
453+
key=lambda t: (str(t.typ), getattr(t, "bits", 0))
423454
)
424455
b_type = possible[0]
425456
return [b_type, self._inputs[1][1], self._inputs[2][1]]
@@ -950,11 +981,15 @@ def _try_fold(self, node):
950981
raise UnfoldableNode
951982

952983
if isinstance(output_type, BytesM_T):
984+
expected = output_type.length
985+
if expected != 32:
986+
result = result[:expected]
953987
return vy_ast.Hex.from_node(node, value=f"0x{result.hex()}")
954988
elif isinstance(output_type, IntegerT):
955989
return vy_ast.Int.from_node(node, value=int.from_bytes(result, "big"))
956990
elif isinstance(output_type, AddressT):
957-
return vy_ast.Hex.from_node(node, value=f"0x{result.hex()}")
991+
# right-align as per ABI: take the last 20 bytes
992+
return vy_ast.Hex.from_node(node, value=f"0x{result[-20:].hex()}")
958993
else:
959994
raise UnfoldableNode
960995

@@ -967,7 +1002,7 @@ def infer_arg_types(self, node, expected_return_typ=None):
9671002
self._validate_arg_types(node)
9681003
possible = sorted(
9691004
get_possible_types_from_node(node.args[0]),
970-
key=lambda t: (t.typ, getattr(t, "bits", 0))
1005+
key=lambda t: (str(t.typ), getattr(t, "bits", 0))
9711006
)
9721007
input_type = possible[0]
9731008
return [input_type, UINT256_T]

0 commit comments

Comments
 (0)