|
96 | 96 | }
|
97 | 97 |
|
98 | 98 |
|
| 99 | +try: |
| 100 | + # torch >=2.3 |
| 101 | + _uint_promotion_table = { |
| 102 | + # uints |
| 103 | + (torch.uint8, torch.uint16): torch.uint16, |
| 104 | + (torch.uint8, torch.uint32): torch.uint32, |
| 105 | + (torch.uint8, torch.uint64): torch.uint64, |
| 106 | + (torch.uint16, torch.uint8): torch.uint16, |
| 107 | + (torch.uint16, torch.uint16): torch.uint16, |
| 108 | + (torch.uint16, torch.uint32): torch.uint32, |
| 109 | + (torch.uint16, torch.uint64): torch.uint64, |
| 110 | + (torch.uint32, torch.uint8): torch.uint32, |
| 111 | + (torch.uint32, torch.uint16): torch.uint32, |
| 112 | + (torch.uint32, torch.uint32): torch.uint32, |
| 113 | + (torch.uint32, torch.uint64): torch.uint64, |
| 114 | + (torch.uint64, torch.uint8): torch.uint64, |
| 115 | + (torch.uint64, torch.uint16): torch.uint64, |
| 116 | + (torch.uint64, torch.uint32): torch.uint64, |
| 117 | + (torch.uint64, torch.uint64): torch.uint64, |
| 118 | + # ints and uints (mixed sign) |
| 119 | + (torch.int8, torch.uint16): torch.int32, |
| 120 | + (torch.int8, torch.uint32): torch.int64, |
| 121 | + (torch.int16, torch.uint8): torch.int16, |
| 122 | + (torch.int16, torch.uint16): torch.int32, |
| 123 | + (torch.int16, torch.uint32): torch.int64, |
| 124 | + (torch.int32, torch.uint8): torch.int32, |
| 125 | + (torch.int32, torch.uint16): torch.int32, |
| 126 | + (torch.int32, torch.uint32): torch.int64, |
| 127 | + (torch.int64, torch.uint8): torch.int64, |
| 128 | + (torch.int64, torch.uint16): torch.int64, |
| 129 | + (torch.int64, torch.uint32): torch.int64, |
| 130 | + (torch.uint16, torch.int8): torch.int32, |
| 131 | + (torch.uint16, torch.int16): torch.int32, |
| 132 | + (torch.uint16, torch.int32): torch.int32, |
| 133 | + (torch.uint16, torch.int64): torch.int64, |
| 134 | + (torch.uint32, torch.int8): torch.int64, |
| 135 | + (torch.uint32, torch.int16): torch.int64, |
| 136 | + (torch.uint32, torch.int32): torch.int64, |
| 137 | + (torch.uint32, torch.int64): torch.int64, |
| 138 | +} |
| 139 | +except AttributeError: |
| 140 | + pass |
| 141 | + |
| 142 | +_promotion_table.update(**_uint_promotion_table) |
| 143 | + |
| 144 | + |
99 | 145 | def _two_arg(f):
|
100 | 146 | @_wraps(f)
|
101 | 147 | def _f(x1, x2, /, **kwargs):
|
|
0 commit comments