Skip to content

Commit 60d64d6

Browse files
committed
ENH: torch: add type promotion for (uintN, uintM)
1 parent 84390f5 commit 60d64d6

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

array_api_compat/torch/_aliases.py

+46
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,52 @@
9696
}
9797

9898

99+
try:
100+
# torch >=2.3
101+
_uint_promotion_table = {
102+
# uints
103+
(torch.uint8, torch.uint16): torch.uint16,
104+
(torch.uint8, torch.uint32): torch.uint32,
105+
(torch.uint8, torch.uint64): torch.uint64,
106+
(torch.uint16, torch.uint8): torch.uint16,
107+
(torch.uint16, torch.uint16): torch.uint16,
108+
(torch.uint16, torch.uint32): torch.uint32,
109+
(torch.uint16, torch.uint64): torch.uint64,
110+
(torch.uint32, torch.uint8): torch.uint32,
111+
(torch.uint32, torch.uint16): torch.uint32,
112+
(torch.uint32, torch.uint32): torch.uint32,
113+
(torch.uint32, torch.uint64): torch.uint64,
114+
(torch.uint64, torch.uint8): torch.uint64,
115+
(torch.uint64, torch.uint16): torch.uint64,
116+
(torch.uint64, torch.uint32): torch.uint64,
117+
(torch.uint64, torch.uint64): torch.uint64,
118+
# ints and uints (mixed sign)
119+
(torch.int8, torch.uint16): torch.int32,
120+
(torch.int8, torch.uint32): torch.int64,
121+
(torch.int16, torch.uint8): torch.int16,
122+
(torch.int16, torch.uint16): torch.int32,
123+
(torch.int16, torch.uint32): torch.int64,
124+
(torch.int32, torch.uint8): torch.int32,
125+
(torch.int32, torch.uint16): torch.int32,
126+
(torch.int32, torch.uint32): torch.int64,
127+
(torch.int64, torch.uint8): torch.int64,
128+
(torch.int64, torch.uint16): torch.int64,
129+
(torch.int64, torch.uint32): torch.int64,
130+
(torch.uint16, torch.int8): torch.int32,
131+
(torch.uint16, torch.int16): torch.int32,
132+
(torch.uint16, torch.int32): torch.int32,
133+
(torch.uint16, torch.int64): torch.int64,
134+
(torch.uint32, torch.int8): torch.int64,
135+
(torch.uint32, torch.int16): torch.int64,
136+
(torch.uint32, torch.int32): torch.int64,
137+
(torch.uint32, torch.int64): torch.int64,
138+
}
139+
except AttributeError:
140+
pass
141+
142+
_promotion_table.update(**_uint_promotion_table)
143+
144+
99145
def _two_arg(f):
100146
@_wraps(f)
101147
def _f(x1, x2, /, **kwargs):

0 commit comments

Comments
 (0)