121 changes: 94 additions & 27 deletions torch_xla/experimental/custom_kernel.py
@@ -1306,46 +1306,120 @@ def repeat_with_fixed_output_size(input: torch.Tensor, repeats: torch.Tensor,
return res


def _round_up_to_multiple_of_128_within_limit(x: int, limit: int) -> int:
"""
Rounds the given integer `x` up to the nearest multiple of 128, without exceeding
the specified `limit`.

If `x` is less than or equal to 128, returns 128.
If `x` is less than `limit`, returns the smallest multiple of 128 greater than or
equal to `x`.
If `x` is greater than or equal to `limit`, searches for the largest multiple of
128 less than or equal to `limit` (down to 512) that divides `x` evenly, and
returns it.
If no such candidate is found, returns `limit`.

Args:
x (int): The integer to round up.
limit (int): The upper bound (must be a multiple of 128 and at least 128).

Returns:
int: The rounded value according to the rules above.

Raises:
AssertionError: If `limit` is less than 128 or not a multiple of 128.
"""
assert limit >= 128 and limit % 128 == 0
if x <= 128:
return 128
if x < limit:
return (x + 127) // 128 * 128
for candidate in range(limit, 511, -128):
if x % candidate == 0:
return candidate
return limit
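

A few worked values for this helper (illustrative only; each result just traces the rules above):

# _round_up_to_multiple_of_128_within_limit(100, 512)   -> 128   (x <= 128)
# _round_up_to_multiple_of_128_within_limit(200, 512)   -> 256   (x < limit: round up to a multiple of 128)
# _round_up_to_multiple_of_128_within_limit(2560, 2048) -> 1280  (x >= limit: largest multiple of 128 <= limit that divides x)
# _round_up_to_multiple_of_128_within_limit(3000, 2048) -> 2048  (no such divisor found: fall back to limit)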


def _get_tiling_size_for_gmm_kernel(m: int, k: int, n: int,
g: int) -> tuple[int, int, int]:
"""
Calculate optimal tiling sizes for a GMM kernel in a Mixture of Experts
(MoE) setting.

Args:
m (int): The total number of tokens.
k (int): The input feature dimension.
n (int): The output feature dimension.
g (int): The number of experts.

Returns:
tuple[int, int, int]: The tile sizes (tm, tk, tn) for the m, k and n dimensions.
"""

# TODO(Chengji): increase the upper limit tiling size of m when we can set
# the vmem size to be used for gmm kernel.
# NOTE: On average each expert has m // g tokens, but the distribution can be
# unbalanced, so we double the per-expert token count when choosing the tiling
# size of m. 2*m//g can be either greater or less than 512: with 32 tokens and
# topk=2, m = topk * num_tokens = 64, so 2*m//g is less than 512.
tm = _round_up_to_multiple_of_128_within_limit(2 * m // g, 512)
tm = min(tm, m) # there's a requirement that m % tm == 0
# k and n correspond to the input/output feature dimensions of the matmul, so
# they are normally greater than 2048 unless the number of shards is large.
tk = _round_up_to_multiple_of_128_within_limit(k, 2048)
tn = _round_up_to_multiple_of_128_within_limit(n, 2048)
return tm, tk, tn
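

A quick worked example of this heuristic (illustrative only; the values follow directly from the helper above):

# m=4096 tokens, k=7168 input features, n=2048 output features, g=8 experts:
#   tm: 2*4096//8 = 1024, which 512 divides evenly -> 512; min(512, 4096) = 512
#   tk: 7168 >= 2048, and 1792 is the largest qualifying divisor -> 1792
#   tn: 2048 rounds to itself -> 2048
_get_tiling_size_for_gmm_kernel(4096, 7168, 2048, 8)  # -> (512, 1792, 2048)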


@requires_jax
def gmm(
lhs: torch.Tensor,
rhs: torch.Tensor,
group_sizes: torch.Tensor,
tiling: Optional[tuple[int, int, int]] = None,
group_offset: torch.Tensor | None = None,
transpose_rhs: bool = False,
) -> torch.Tensor:
"""Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.

Args:
lhs: A 2d, torch.Tensor with shape [m, k].
rhs: A 3d, torch.Tensor with shape [num_groups, k, n], or [num_groups, n, k]
when transpose_rhs is True.
group_sizes: A 1d, torch.Tensor with shape [num_groups] and torch.int32 dtype.
tiling: 3-tuple of ints. The m, k and n-dimension tile sizes. If None, tile
sizes are chosen automatically based on the problem shape.
group_offset: The group in group sizes to start computing from. This is
particularly useful for when rhs num_groups is sharded.
transpose_rhs: True if the rhs needs to be transposed.
Returns:
A 2d, torch.Tensor with shape [m, n].
"""
# Import JAX within the function such that we don't need to call the jax_import_guard()
# in the global scope which could cause problems for xmp.spawn.
from jax.experimental.pallas.ops.tpu.megablox.gmm import gmm

m, k, g = lhs.shape[0], lhs.shape[1], rhs.shape[0]
n = rhs.shape[1] if transpose_rhs else rhs.shape[2]

if tiling is None:
tm, tk, tn = _get_tiling_size_for_gmm_kernel(m, k, n, g)
else:
tm, tk, tn = tiling
tm = min(tm, m)
preferred_element_type = lhs.dtype
return xb.call_jax(
gmm,
(lhs, rhs, group_sizes, preferred_element_type,
(tm, tk, tn), group_offset),
{"transpose_rhs": transpose_rhs},
)
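

A minimal usage sketch (assumes torch_xla with a TPU/XLA device available; rows of lhs are grouped contiguously and group_sizes sums to m):

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.custom_kernel import gmm

device = xm.xla_device()
m, k, n, g = 512, 1024, 2048, 4
lhs = torch.randn(m, k, dtype=torch.bfloat16, device=device)
rhs = torch.randn(g, k, n, dtype=torch.bfloat16, device=device)
# Each group i owns a contiguous block of group_sizes[i] rows of lhs.
group_sizes = torch.tensor([128, 128, 128, 128], dtype=torch.int32, device=device)
out = gmm(lhs, rhs, group_sizes)  # tiling=None, so tile sizes are derived automatically
assert out.shape == (m, n)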


@requires_jax
def tgmm(
lhs: torch.Tensor,
rhs: torch.Tensor,
group_sizes: torch.Tensor,
tiling: tuple[int, int, int] = (512, 512, 512),
group_offset: torch.Tensor | None = None,
num_actual_groups: int | None = None,
) -> torch.Tensor:
@@ -1592,33 +1666,26 @@ def gmm_xla(
rhs: torch.Tensor,
group_sizes: torch.Tensor,
# pytorch custom op does not allow tuple type, use list instead
tiling: Optional[List[int]] = None,
group_offset: torch.Tensor | None = None,
transpose_rhs: bool = False):
if tiling is None:
tiling = [512, 512, 512]
assert len(tiling) == 3, "tiling must be a list with 3 integers"
assert lhs.dim() == 2, "lhs must be a 2d, torch.Tensor with shape [m, k]"
assert rhs.dim(
) == 3, "rhs must be a 3d, torch.Tensor with shape [num_groups, k, n]"
tiling = tuple(tiling)
return gmm(lhs, rhs, group_sizes, tiling, group_offset, transpose_rhs)
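

The same entry point should also be reachable through the custom-op registry (hypothetical call, assuming the schema registered on XLA_LIB exposes it as xla::gmm):

# Per the wrapper above, tiling=None falls back to [512, 512, 512] on this path.
out = torch.ops.xla.gmm(lhs, rhs, group_sizes, None, None, False)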


@impl(XLA_LIB, "gmm", "CompositeExplicitAutograd")
def gmm_non_xla(lhs: torch.Tensor,
rhs: torch.Tensor,
group_sizes: torch.Tensor,
tiling: Optional[List[int]] = None,
group_offset: torch.Tensor | None = None,
transpose_rhs: bool = False):
# This will be called when dynamo uses fake tensors to construct the fake output.
# We need to make sure the output tensor's shape is correct.
if lhs.device != torch.device("meta"):
warnings.warn('XLA gmm should only be applied to tensors on XLA device')
if tiling is None:
tiling = [512, 512, 512]
assert len(tiling) == 3, "tiling must be a list with 3 integers"
assert lhs.dim() == 2, "lhs must be a 2d, torch.Tensor with shape [m, k]"
assert rhs.dim(
) == 3, "rhs must be a 3d, torch.Tensor with shape [num_groups, k, n] or [num_groups, n, k] when transpose_rhs is True"