Skip to content

Commit 9acafaa

Browse files
authored
Merge branch 'main' into check_vector_size_for_the_power_of_2
2 parents a1dda5a + 0ed9251 commit 9acafaa

File tree

3 files changed

+355
-50
lines changed

3 files changed

+355
-50
lines changed

python/cudaq/kernel/ast_bridge.py

+70-50
Original file line numberDiff line numberDiff line change
@@ -3374,76 +3374,96 @@ def visit_Compare(self, node):
33743374
self.visit(node.left)
33753375
left = self.popValue()
33763376
self.visit(node.comparators[0])
3377-
comparator = self.popValue()
3377+
right = self.popValue()
33783378
op = node.ops[0]
33793379

3380-
if isinstance(op, ast.Gt):
3381-
if IntegerType.isinstance(left.type):
3382-
if F64Type.isinstance(comparator.type):
3383-
self.emitFatalError(
3384-
"invalid rhs for comparison (f64 type and not i64 type).",
3385-
node)
3380+
left_type = left.type
3381+
right_type = right.type
3382+
3383+
if IntegerType.isinstance(left_type) and F64Type.isinstance(right_type):
3384+
left = arith.SIToFPOp(self.getFloatType(), left).result
3385+
elif F64Type.isinstance(left_type) and IntegerType.isinstance(
3386+
right_type):
3387+
right = arith.SIToFPOp(self.getFloatType(), right).result
3388+
elif IntegerType.isinstance(left_type) and IntegerType.isinstance(
3389+
right_type):
3390+
if IntegerType(left_type).width < IntegerType(right_type).width:
3391+
zeroext = IntegerType(left_type).width == 1
3392+
left = cc.CastOp(right_type,
3393+
left,
3394+
sint=not zeroext,
3395+
zint=zeroext).result
3396+
elif IntegerType(left_type).width > IntegerType(right_type).width:
3397+
zeroext = IntegerType(right_type).width == 1
3398+
right = cc.CastOp(left_type,
3399+
right,
3400+
sint=not zeroext,
3401+
zint=zeroext).result
33863402

3387-
self.pushValue(
3388-
arith.CmpIOp(self.getIntegerAttr(iTy, 4), left,
3389-
comparator).result)
3390-
elif F64Type.isinstance(left.type):
3391-
if IntegerType.isinstance(comparator.type):
3392-
comparator = arith.SIToFPOp(self.getFloatType(),
3393-
comparator).result
3403+
if isinstance(op, ast.Gt):
3404+
if F64Type.isinstance(left.type):
33943405
self.pushValue(
33953406
arith.CmpFOp(self.getIntegerAttr(iTy, 2), left,
3396-
comparator).result)
3407+
right).result)
3408+
else:
3409+
self.pushValue(
3410+
arith.CmpIOp(self.getIntegerAttr(iTy, 4), left,
3411+
right).result)
33973412
return
33983413

33993414
if isinstance(op, ast.GtE):
3400-
self.pushValue(
3401-
arith.CmpIOp(self.getIntegerAttr(iTy, 5), left,
3402-
comparator).result)
3415+
if F64Type.isinstance(left.type):
3416+
self.pushValue(
3417+
arith.CmpFOp(self.getIntegerAttr(iTy, 3), left,
3418+
right).result)
3419+
else:
3420+
self.pushValue(
3421+
arith.CmpIOp(self.getIntegerAttr(iTy, 5), left,
3422+
right).result)
34033423
return
34043424

34053425
if isinstance(op, ast.Lt):
3406-
self.pushValue(
3407-
arith.CmpIOp(self.getIntegerAttr(iTy, 2), left,
3408-
comparator).result)
3426+
if F64Type.isinstance(left.type):
3427+
self.pushValue(
3428+
arith.CmpFOp(self.getIntegerAttr(iTy, 4), left,
3429+
right).result)
3430+
else:
3431+
self.pushValue(
3432+
arith.CmpIOp(self.getIntegerAttr(iTy, 2), left,
3433+
right).result)
34093434
return
34103435

34113436
if isinstance(op, ast.LtE):
3412-
self.pushValue(
3413-
arith.CmpIOp(self.getIntegerAttr(iTy, 7), left,
3414-
comparator).result)
3437+
if F64Type.isinstance(left.type):
3438+
self.pushValue(
3439+
arith.CmpFOp(self.getIntegerAttr(iTy, 5), left,
3440+
right).result)
3441+
else:
3442+
self.pushValue(
3443+
arith.CmpIOp(self.getIntegerAttr(iTy, 7), left,
3444+
right).result)
34153445
return
34163446

34173447
if isinstance(op, ast.NotEq):
3418-
if F64Type.isinstance(left.type) and IntegerType.isinstance(
3419-
comparator.type):
3420-
left = arith.FPToSIOp(comparator.type, left).result
3421-
if IntegerType(left.type).width < IntegerType(
3422-
comparator.type).width:
3423-
zeroext = IntegerType(left.type).width == 1
3424-
left = cc.CastOp(comparator.type,
3425-
left,
3426-
sint=not zeroext,
3427-
zint=zeroext).result
3428-
self.pushValue(
3429-
arith.CmpIOp(self.getIntegerAttr(iTy, 1), left,
3430-
comparator).result)
3448+
if F64Type.isinstance(left.type):
3449+
self.pushValue(
3450+
arith.CmpFOp(self.getIntegerAttr(iTy, 6), left,
3451+
right).result)
3452+
else:
3453+
self.pushValue(
3454+
arith.CmpIOp(self.getIntegerAttr(iTy, 1), left,
3455+
right).result)
34313456
return
34323457

34333458
if isinstance(op, ast.Eq):
3434-
if F64Type.isinstance(left.type) and IntegerType.isinstance(
3435-
comparator.type):
3436-
left = arith.FPToSIOp(comparator.type, left).result
3437-
if IntegerType(left.type).width < IntegerType(
3438-
comparator.type).width:
3439-
zeroext = IntegerType(left.type).width == 1
3440-
left = cc.CastOp(comparator.type,
3441-
left,
3442-
sint=not zeroext,
3443-
zint=zeroext).result
3444-
self.pushValue(
3445-
arith.CmpIOp(self.getIntegerAttr(iTy, 0), left,
3446-
comparator).result)
3459+
if F64Type.isinstance(left.type):
3460+
self.pushValue(
3461+
arith.CmpFOp(self.getIntegerAttr(iTy, 1), left,
3462+
right).result)
3463+
else:
3464+
self.pushValue(
3465+
arith.CmpIOp(self.getIntegerAttr(iTy, 0), left,
3466+
right).result)
34473467
return
34483468

34493469
def visit_AugAssign(self, node):
+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# ============================================================================ #
2+
# Copyright (c) 2022 - 2025 NVIDIA Corporation & Affiliates. #
3+
# All rights reserved. #
4+
# #
5+
# This source code and the accompanying materials are made available under #
6+
# the terms of the Apache License 2.0 which accompanies this distribution. #
7+
# ============================================================================ #
8+
9+
import os
10+
import pytest
11+
import cudaq
12+
13+
14+
def cmpfop(predicate, left, right):
15+
operations = {
16+
2: lambda l, r: l > r,
17+
3: lambda l, r: l >= r,
18+
4: lambda l, r: l < r,
19+
5: lambda l, r: l <= r,
20+
1: lambda l, r: l == r,
21+
6: lambda l, r: l != r,
22+
}
23+
return operations[predicate](left, right)
24+
25+
26+
def cmpiop(predicate, left, right):
27+
operations = {
28+
4: lambda l, r: l > r,
29+
5: lambda l, r: l >= r,
30+
2: lambda l, r: l < r,
31+
7: lambda l, r: l <= r,
32+
0: lambda l, r: l == r,
33+
1: lambda l, r: l != r,
34+
}
35+
return operations[predicate](left, right)
36+
37+
38+
@pytest.mark.parametrize(
39+
"left, right, operation, expected",
40+
[
41+
# Integer comparisons
42+
(3, 5, "Lt", True),
43+
(5, 3, "Gt", True),
44+
(3, 3, "Eq", True),
45+
(3, 5, "LtE", True),
46+
(5, 5, "GtE", True),
47+
(3, 5, "NotEq", True),
48+
(5, 5, "NotEq", False),
49+
50+
# Float comparisons
51+
(3.2, 4.5, "Lt", True),
52+
(4.5, 3.2, "Gt", True),
53+
(3.2, 3.2, "Eq", True),
54+
(3.2, 4.5, "LtE", True),
55+
(4.5, 4.5, "GtE", True),
56+
(3.2, 4.5, "NotEq", True),
57+
(4.5, 4.5, "NotEq", False),
58+
59+
# Mixed comparisons
60+
(3, 4.5, "Lt", True),
61+
(4.5, 3, "Gt", True),
62+
(3, 3.0, "Eq", True),
63+
(3, 4.5, "LtE", True),
64+
(4.5, 4, "GtE", True),
65+
(3, 4.5, "NotEq", True),
66+
],
67+
)
68+
def test_visit_compare(left, right, operation, expected):
69+
result = None
70+
71+
if operation in ["Gt", "GtE", "Lt", "LtE", "Eq", "NotEq"]:
72+
if isinstance(left, float) or isinstance(right, float):
73+
predicate = {
74+
"Gt": 2,
75+
"GtE": 3,
76+
"Lt": 4,
77+
"LtE": 5,
78+
"Eq": 1,
79+
"NotEq": 6,
80+
}[operation]
81+
result = cmpfop(predicate, left, right)
82+
else:
83+
predicate = {
84+
"Gt": 4,
85+
"GtE": 5,
86+
"Lt": 2,
87+
"LtE": 7,
88+
"Eq": 0,
89+
"NotEq": 1,
90+
}[operation]
91+
result = cmpiop(predicate, left, right)
92+
93+
assert result == expected
94+
95+
96+
if __name__ == "__main__":
97+
loc = os.path.abspath(__file__)
98+
pytest.main([loc, "-rP"])

0 commit comments

Comments
 (0)