@@ -1306,46 +1306,120 @@ def repeat_with_fixed_output_size(input: torch.Tensor, repeats: torch.Tensor,
   return res
 
 
+def _round_up_to_multiple_of_128_within_limit(x: int, limit: int) -> int:
+  """
+  Rounds the given integer `x` up to the nearest multiple of 128, without exceeding
+  the specified `limit`.
+
+  If `x` is less than or equal to 128, returns 128.
+  If `x` is less than `limit`, returns the smallest multiple of 128 greater than or
+  equal to `x`.
+  If `x` is greater than or equal to `limit`, searches for the largest multiple of
+  128 less than or equal to `limit` (down to 512) that divides `x` evenly, and
+  returns it.
+  If no such candidate is found, returns `limit`.
+
+  Args:
+    x (int): The integer to round up.
+    limit (int): The upper bound (must be a multiple of 128 and at least 128).
+
+  Returns:
+    int: The rounded value according to the rules above.
+
+  Raises:
+    AssertionError: If `limit` is less than 128 or not a multiple of 128.
+  """
+  assert limit >= 128 and limit % 128 == 0
+  if x <= 128:
+    return 128
+  if x < limit:
+    return (x + 127) // 128 * 128
+  for candidate in range(limit, 511, -128):
+    if x % candidate == 0:
+      return candidate
+  return limit
+
+
+def _get_tiling_size_for_gmm_kernel(m: int, k: int, n: int,
+                                    g: int) -> tuple[int, int, int]:
+  """
+  Calculate optimal tiling sizes for a GMM kernel in a Mixture of Experts
+  (MoE) setting.
+
+  Args:
+    m (int): The total number of tokens.
+    n (int): The output feature dimension.
+    k (int): The input feature dimension.
+    g (int): The number of experts.
+
+  Returns:
+    tuple[int, int, int]: A tuple (tm, tk, tn).
+  """
+
+  # TODO(Chengji): increase the upper limit tiling size of m when we can set
+  # the vmem size to be used for the gmm kernel.
+  # NOTE: On average each expert has m // g tokens, but since the load may be
+  # unbalanced, we double the token count when choosing the tiling size of m.
+  # 2*m//g can be either greater or less than 512, e.g. with 32 tokens and
+  # topk=2, m = topk * num_tokens = 64, so 2*m//g is less than 512.
+  tm = _round_up_to_multiple_of_128_within_limit(2 * m // g, 512)
+  tm = min(tm, m)  # there's a requirement that m % tm == 0
+  # k/n correspond to n_input_features/n_output_features in the matmul, so they
+  # are normally greater than 2048 unless the number of shards is large.
+  tk = _round_up_to_multiple_of_128_within_limit(k, 2048)
+  tn = _round_up_to_multiple_of_128_within_limit(n, 2048)
+  return tm, tk, tn
+
+
 @requires_jax
 def gmm(
     lhs: torch.Tensor,
     rhs: torch.Tensor,
     group_sizes: torch.Tensor,
-    tiling: Tuple[int, int, int] = (512, 512, 512),
+    tiling: Optional[tuple[int, int, int]] = None,
     group_offset: torch.Tensor | None = None,
     transpose_rhs: bool = False,
 ) -> torch.Tensor:
   """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.
 
-  Args:
-    lhs: A 2d, torch.Tensor with shape [m, k].
-    rhs: A 3d, torch.Tensor with shape [num_groups, k, n].
-    group_sizes: A 1d, torch.Tensor with shape [num_groups] and torch.int32 dtype.
-    tiling: 3-tuple of ints. The m, k and n-dimension tile sizes.
-    group_offset: The group in group sizes to start computing from. This is
-      particularly useful for when rhs num_groups is sharded.
-    transpose_rhs: True if the rhs needs to be transposed.
-  Returns:
-    A 2d, torch.Tensor with shape [m, n].
-  """
+  Args:
+    lhs: A 2d, torch.Tensor with shape [m, k].
+    rhs: A 3d, torch.Tensor with shape [num_groups, k, n].
+    group_sizes: A 1d, torch.Tensor with shape [num_groups] and torch.int32 dtype.
+    tiling: 3-tuple of ints. The m, k and n-dimension tile sizes; None selects them automatically.
+    group_offset: The group in group sizes to start computing from. This is
+      particularly useful for when rhs num_groups is sharded.
+    transpose_rhs: True if the rhs needs to be transposed.
+  Returns:
+    A 2d, torch.Tensor with shape [m, n].
+  """
   # Import JAX within the function such that we don't need to call the jax_import_guard()
   # in the global scope which could cause problems for xmp.spawn.
   from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm
 
-  m, k, n = lhs.shape[0], lhs.shape[1], rhs.shape[2]
-  tm, tk, tn = min(tiling[0], m), min(tiling[1], k), min(tiling[2], n)
+  m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
+  n = rhs.shape[1] if transpose_rhs else rhs.shape[2]
+
+  if tiling is None:
+    tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
+  else:
+    tm, tk, tn = tiling
+    tm = min(tm, m)
   preferred_element_type = lhs.dtype
-  return xb.call_jax(gmm, (lhs, rhs, group_sizes, preferred_element_type,
-                           (tm, tk, tn), group_offset),
-                     {"transpose_rhs": transpose_rhs})
+  return xb.call_jax(
+      gmm,
+      (lhs, rhs, group_sizes, preferred_element_type,
+       (tm, tk, tn), group_offset),
+      {"transpose_rhs": transpose_rhs},
+  )
 
 
 @requires_jax
 def tgmm(
     lhs: torch.Tensor,
     rhs: torch.Tensor,
     group_sizes: torch.Tensor,
-    tiling: Tuple[int, int, int] = (512, 512, 512),
+    tiling: tuple[int, int, int] = (512, 512, 512),
     group_offset: torch.Tensor | None = None,
     num_actual_groups: int | None = None,
 ) -> torch.Tensor:
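
For reference, a quick sketch (not part of the diff) of what the new automatic tiling computes. The helper names come from the hunk above; the import path and the m/k/n/g values are assumptions, not taken from the PR:

# Assumed import path; both helpers are private to the modified module.
from torch_xla.experimental.custom_kernel import (
    _get_tiling_size_for_gmm_kernel, _round_up_to_multiple_of_128_within_limit)

# 32 tokens routed with topk=2 to 8 experts, k=4096 input features, n=14336 output features.
m, k, n, g = 64, 4096, 14336, 8

# 2*m//g = 16 rounds up to the 128 floor, then tm is capped at m so that m % tm == 0;
# k and n are already multiples of 2048, so tk = tn = 2048.
assert _get_tiling_size_for_gmm_kernel(m, k, n, g) == (64, 2048, 2048)

# Rounding behaviour of the helper on its own:
assert _round_up_to_multiple_of_128_within_limit(300, 2048) == 384    # next multiple of 128
assert _round_up_to_multiple_of_128_within_limit(100, 512) == 128     # floor at 128
assert _round_up_to_multiple_of_128_within_limit(4096, 2048) == 2048  # largest dividing tile
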
@@ -1592,33 +1666,26 @@ def gmm_xla(
     rhs: torch.Tensor,
     group_sizes: torch.Tensor,
     # pytorch custom op does not allow tuple type, use list instead
-    tiling: Optional[List[int]] = [512, 512, 512],
+    tiling: Optional[List[int]] = None,
     group_offset: torch.Tensor | None = None,
     transpose_rhs: bool = False):
-  if tiling is None:
-    tiling = [512, 512, 512]
-  assert len(tiling) == 3, "tiling must be a list with 3 integers"
   assert lhs.dim() == 2, "lhs must be a 2d, torch.Tensor with shape [k, m]"
   assert rhs.dim(
   ) == 3, "rhs must be a A 3d torch.Tensor with shape [num_groups, k, n]"
-  tiling = tuple(tiling)
   return gmm(lhs, rhs, group_sizes, tiling, group_offset, transpose_rhs)
 
 
 @impl(XLA_LIB, "gmm", "CompositeExplicitAutograd")
 def gmm_non_xla(lhs: torch.Tensor,
                 rhs: torch.Tensor,
                 group_sizes: torch.Tensor,
-                tiling: Optional[List[int]] = [512, 512, 512],
+                tiling: Optional[List[int]] = None,
                 group_offset: torch.Tensor | None = None,
                 transpose_rhs: bool = False):
   # This will be called when dynamo use fake tensor to construct the fake output.
   # We need to make sure output tensor's shape is correct.
   if lhs.device != torch.device("meta"):
     warnings.warn(f'XLA gmm should only be applied to tensors on XLA device')
-  if tiling is None:
-    tiling = [512, 512, 512]
-  assert len(tiling) == 3, "tiling must be a list with 3 integers"
   assert lhs.dim() == 2, "lhs must be a 2d, torch.Tensor with shape [k, m]"
   assert rhs.dim(
   ) == 3, "rhs must be a A 3d torch.Tensor with shape [num_groups, k, n] or [num_groups, n, k] when transpose_rhs is True"
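
Usage sketch (not part of the diff) showing the effect of the new None default at the wrapper level. The import path and all shapes below are assumptions; only `gmm` and its signature come from the hunks above:

import torch
import torch_xla.core.xla_model as xm
# Assumed import path for the wrapper modified in this PR.
from torch_xla.experimental.custom_kernel import gmm

device = xm.xla_device()
m, k, n, g = 64, 4096, 14336, 8
lhs = torch.randn(m, k, dtype=torch.bfloat16, device=device)
rhs = torch.randn(g, k, n, dtype=torch.bfloat16, device=device)
group_sizes = torch.tensor([m // g] * g, dtype=torch.int32, device=device)  # sums to m

# Previously the default (512, 512, 512) was clamped per dimension; now omitting
# `tiling` routes through _get_tiling_size_for_gmm_kernel, giving (64, 2048, 2048) here.
out_auto = gmm(lhs, rhs, group_sizes)
# An explicit tiling still works; only tm is clamped to m on this path.
out_explicit = gmm(lhs, rhs, group_sizes, (128, 512, 1024))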