Skip to content

Commit 4747ca8

Browse files
committed
[PyRTG] Add method to get random integer
1 parent bf5da55 commit 4747ca8

File tree

2 files changed

+31
-3
lines changed

2 files changed

+31
-3
lines changed

frontends/PyRTG/src/pyrtg/integers.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
from .circt import ir
88
from .core import Value
99
from .index import index
10+
from .rtg import rtg
1011

11-
import typing
12+
from typing import Union
1213

1314

1415
class Integer(Value):
@@ -19,13 +20,26 @@ class Integer(Value):
1920
away during randomization.
2021
"""
2122

22-
def __init__(self, value: typing.Union[ir.Value, int]) -> Integer:
23+
def __init__(self, value: Union[ir.Value, int]) -> Integer:
2324
"""
2425
Use this constructor to create an Integer from a builtin Python int.
2526
"""
2627

2728
self._value = value
2829

30+
def random(lower_bound: Union[int, Integer],
31+
upper_bound: Union[int, Integer]) -> Integer:
32+
"""
33+
Get a random number in the given range (lower inclusive, upper exclusive).
34+
"""
35+
36+
if isinstance(lower_bound, int):
37+
lower_bound = Integer(lower_bound)
38+
if isinstance(upper_bound, int):
39+
upper_bound = Integer(upper_bound)
40+
41+
return rtg.RandomNumberInRangeOp(lower_bound, upper_bound)
42+
2943
def __add__(self, other: Integer) -> Integer:
3044
return index.AddOp(self._get_ssa_value(), other._get_ssa_value())
3145

@@ -84,7 +98,7 @@ class Bool(Value):
8498
away during randomization.
8599
"""
86100

87-
def __init__(self, value: typing.Union[ir.Value, bool]) -> Bool:
101+
def __init__(self, value: Union[ir.Value, bool]) -> Bool:
88102
"""
89103
Use this constructor to create a Bool from a builtin Python bool.
90104
"""

frontends/PyRTG/test/basic.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,3 +402,17 @@ def test7_bools(a, b):
402402
consumer(a > b)
403403
consumer(a <= b)
404404
consumer(a >= b)
405+
406+
407+
# MLIR-LABEL: rtg.test @test8_random_integer
408+
# MLIR-NEXT: rtg.random_number_in_range [%a, %b)
409+
410+
411+
@sequence(Integer.type())
412+
def int_consumer(b):
413+
pass
414+
415+
416+
@test(("a", Integer.type()), ("b", Integer.type()))
417+
def test8_random_integer(a, b):
418+
int_consumer(Integer.random(a, b))

0 commit comments

Comments
 (0)