
Commit e51af25

[Kernel] add heuristic gmm block sizes choosing logic (#9289)
1 parent 275f6e9 commit e51af25

File tree

1 file changed (+94, −27)

torch_xla/experimental/custom_kernel.py

Lines changed: 94 additions & 27 deletions
@@ -1306,46 +1306,120 @@ def repeat_with_fixed_output_size(input: torch.Tensor, repeats: torch.Tensor,
   return res


+def _round_up_to_multiple_of_128_within_limit(x: int, limit: int) -> int:
+  """
+  Rounds the given integer `x` up to the nearest multiple of 128, without exceeding
+  the specified `limit`.
+
+  If `x` is less than or equal to 128, returns 128.
+  If `x` is less than `limit`, returns the smallest multiple of 128 greater than or
+  equal to `x`.
+  If `x` is greater than or equal to `limit`, searches for the largest multiple of
+  128 less than or equal to `limit` (down to 512) that divides `x` evenly, and
+  returns it.
+  If no such candidate is found, returns `limit`.
+
+  Args:
+    x (int): The integer to round up.
+    limit (int): The upper bound (must be a multiple of 128 and at least 128).
+
+  Returns:
+    int: The rounded value according to the rules above.
+
+  Raises:
+    AssertionError: If `limit` is less than 128 or not a multiple of 128.
+  """
+  assert limit >= 128 and limit % 128 == 0
+  if x <= 128:
+    return 128
+  if x < limit:
+    return (x + 127) // 128 * 128
+  for candidate in range(limit, 511, -128):
+    if x % candidate == 0:
+      return candidate
+  return limit
+
+
+def _get_tiling_size_for_gmm_kernel(m: int, k: int, n: int,
+                                    g: int) -> tuple[int, int, int]:
+  """
+  Calculate tiling sizes for a GMM kernel in a Mixture of Experts
+  (MoE) setting.
+
+  Args:
+    m (int): The total number of tokens.
+    k (int): The input feature dimension.
+    n (int): The output feature dimension.
+    g (int): The number of experts.
+
+  Returns:
+    tuple[int, int, int]: A tuple (tm, tk, tn) of tile sizes for the m, k and n dimensions.
+  """
+
+  # TODO(Chengji): increase the upper limit tiling size of m when we can set
+  # the vmem size to be used for the gmm kernel.
+  # NOTE: On average each expert gets m // g tokens, but since the load may be
+  # unbalanced, we double the token count when choosing the tiling size of m.
+  # 2 * m // g can be either greater or less than 512: with 32 tokens and
+  # topk=2, m = topk * num_tokens = 64, so 2 * m // g is less than 512.
+  tm = _round_up_to_multiple_of_128_within_limit(2 * m // g, 512)
+  tm = min(tm, m)  # there's a requirement that m % tm == 0
+  # k/n correspond to n_input_features/n_output_features in the matmul, so they
+  # are normally greater than 2048 unless the number of shards is large.
+  tk = _round_up_to_multiple_of_128_within_limit(k, 2048)
+  tn = _round_up_to_multiple_of_128_within_limit(n, 2048)
+  return tm, tk, tn
+
+
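As a rough illustration of the heuristic added above, the following sketch exercises the two new helpers on hypothetical MoE shapes (the numbers are examples chosen for this note, not values from the commit; the import assumes the private helpers introduced by this change are available):

from torch_xla.experimental.custom_kernel import (
    _get_tiling_size_for_gmm_kernel, _round_up_to_multiple_of_128_within_limit)

# Rounding rule: small values clamp to 128, mid-range values round up to the
# next multiple of 128, and values at or above the limit fall back to the
# largest multiple of 128 <= limit (searched down to 512) that divides them
# evenly, else the limit itself.
assert _round_up_to_multiple_of_128_within_limit(16, 512) == 128
assert _round_up_to_multiple_of_128_within_limit(300, 512) == 384
assert _round_up_to_multiple_of_128_within_limit(7168, 2048) == 1792  # 7168 % 1792 == 0

# Balanced case: m=2048 tokens after top-k expansion, k=4096, n=2048, g=8 experts.
assert _get_tiling_size_for_gmm_kernel(2048, 4096, 2048, 8) == (512, 2048, 2048)

# Tiny-batch case from the NOTE above: 32 tokens with topk=2 gives m=64, so the
# m-tile is clamped to m because m % tm must be 0.
assert _get_tiling_size_for_gmm_kernel(64, 4096, 1408, 8) == (64, 2048, 1408)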
 @requires_jax
 def gmm(
     lhs: torch.Tensor,
     rhs: torch.Tensor,
     group_sizes: torch.Tensor,
-    tiling: Tuple[int, int, int] = (512, 512, 512),
+    tiling: Optional[tuple[int, int, int]] = None,
     group_offset: torch.Tensor | None = None,
     transpose_rhs: bool = False,
 ) -> torch.Tensor:
   """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.

-  Args:
-    lhs: A 2d, torch.Tensor with shape [m, k].
-    rhs: A 3d, torch.Tensor with shape [num_groups, k, n].
-    group_sizes: A 1d, torch.Tensor with shape [num_groups] and torch.int32 dtype.
-    tiling: 3-tuple of ints. The m, k and n-dimension tile sizes.
-    group_offset: The group in group sizes to start computing from. This is
-      particularly useful for when rhs num_groups is sharded.
-    transpose_rhs: True if the rhs needs to be transposed.
-  Returns:
-    A 2d, torch.Tensor with shape [m, n].
-  """
+  Args:
+    lhs: A 2d, torch.Tensor with shape [m, k].
+    rhs: A 3d, torch.Tensor with shape [num_groups, k, n].
+    group_sizes: A 1d, torch.Tensor with shape [num_groups] and torch.int32 dtype.
+    tiling: 3-tuple of ints. The m, k and n-dimension tile sizes. If None, tile sizes are chosen heuristically.
+    group_offset: The group in group sizes to start computing from. This is
+      particularly useful for when rhs num_groups is sharded.
+    transpose_rhs: True if the rhs needs to be transposed.
+  Returns:
+    A 2d, torch.Tensor with shape [m, n].
+  """
   # Import JAX within the function such that we don't need to call the jax_import_guard()
   # in the global scope which could cause problems for xmp.spawn.
   from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm

-  m, k, n = lhs.shape[0], lhs.shape[1], rhs.shape[2]
-  tm, tk, tn = min(tiling[0], m), min(tiling[1], k), min(tiling[2], n)
+  m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
+  n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
+
+  if tiling is None:
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+  else:
+    tm, tk, tn = tiling
+  tm = min(tm, m)
   preferred_element_type = lhs.dtype
-  return xb.call_jax(gmm, (lhs, rhs, group_sizes, preferred_element_type,
-                           (tm, tk, tn), group_offset),
-                     {"transpose_rhs": transpose_rhs})
+  return xb.call_jax(
+      gmm,
+      (lhs, rhs, group_sizes, preferred_element_type,
+       (tm, tk, tn), group_offset),
+      {"transpose_rhs": transpose_rhs},
+  )


 @requires_jax
 def tgmm(
     lhs: torch.Tensor,
     rhs: torch.Tensor,
     group_sizes: torch.Tensor,
-    tiling: Tuple[int, int, int] = (512, 512, 512),
+    tiling: tuple[int, int, int] = (512, 512, 512),
     group_offset: torch.Tensor | None = None,
     num_actual_groups: int | None = None,
 ) -> torch.Tensor:
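For context, here is a minimal usage sketch of the updated gmm wrapper with the new default (not part of the commit; shapes are hypothetical and an available XLA device is assumed):

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import gmm

device = xm.xla_device()
m, k, n, g = 512, 1024, 2048, 4
lhs = torch.randn(m, k, dtype=torch.bfloat16, device=device)
rhs = torch.randn(g, k, n, dtype=torch.bfloat16, device=device)
# group_sizes must be int32 and sum to m.
group_sizes = torch.tensor([128, 128, 128, 128], dtype=torch.int32, device=device)

# tiling=None (the new default) lets _get_tiling_size_for_gmm_kernel pick the
# block sizes; passing an explicit 3-tuple keeps the previous behavior.
out = gmm(lhs, rhs, group_sizes)
out_fixed = gmm(lhs, rhs, group_sizes, tiling=(512, 512, 512))
assert out.shape == (m, n)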
@@ -1592,33 +1666,26 @@ def gmm_xla(
15921666
rhs: torch.Tensor,
15931667
group_sizes: torch.Tensor,
15941668
# pytorch custom op does not allow tuple type, use list instead
1595-
tiling: Optional[List[int]] = [512, 512, 512],
1669+
tiling: Optional[List[int]] = None,
15961670
group_offset: torch.Tensor | None = None,
15971671
transpose_rhs: bool = False):
1598-
if tiling is None:
1599-
tiling = [512, 512, 512]
1600-
assert len(tiling) == 3, "tiling must be a list with 3 integers"
16011672
assert lhs.dim() == 2, "lhs must be a 2d, torch.Tensor with shape [k, m]"
16021673
assert rhs.dim(
16031674
) == 3, "rhs must be a A 3d torch.Tensor with shape [num_groups, k, n]"
1604-
tiling = tuple(tiling)
16051675
return gmm(lhs, rhs, group_sizes, tiling, group_offset, transpose_rhs)
16061676

16071677

16081678
@impl(XLA_LIB, "gmm", "CompositeExplicitAutograd")
16091679
def gmm_non_xla(lhs: torch.Tensor,
16101680
rhs: torch.Tensor,
16111681
group_sizes: torch.Tensor,
1612-
tiling: Optional[List[int]] = [512, 512, 512],
1682+
tiling: Optional[List[int]] = None,
16131683
group_offset: torch.Tensor | None = None,
16141684
transpose_rhs: bool = False):
16151685
# This will be called when dynamo use fake tensor to construct the fake output.
16161686
# We need to make sure output tensor's shape is correct.
16171687
if lhs.device != torch.device("meta"):
16181688
warnings.warn(f'XLA gmm should only be applied to tensors on XLA device')
1619-
if tiling is None:
1620-
tiling = [512, 512, 512]
1621-
assert len(tiling) == 3, "tiling must be a list with 3 integers"
16221689
assert lhs.dim() == 2, "lhs must be a 2d, torch.Tensor with shape [k, m]"
16231690
assert rhs.dim(
16241691
) == 3, "rhs must be a A 3d torch.Tensor with shape [num_groups, k, n] or [num_groups, n, k] when transpose_rhs is True"
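Likewise, for the custom-op path touched by this hunk, a sketch of how the None default now flows through gmm_xla into the heuristic (a sketch only, assuming XLA_LIB is the `xla` torch.library namespace used elsewhere in custom_kernel.py; shapes are hypothetical):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
lhs = torch.randn(256, 512, dtype=torch.bfloat16, device=device)
rhs = torch.randn(2, 512, 1024, dtype=torch.bfloat16, device=device)
group_sizes = torch.tensor([100, 156], dtype=torch.int32, device=device)

# With the `tiling is None` fallback removed, None is forwarded to gmm(),
# which now selects block sizes heuristically; a 3-int list still overrides it.
out = torch.ops.xla.gmm(lhs, rhs, group_sizes, None)
out_fixed = torch.ops.xla.gmm(lhs, rhs, group_sizes, [512, 512, 512])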
