src/MaxText/gradient_accumulation.py (2 changes: 1 addition & 1 deletion)

@@ -68,7 +68,7 @@ def _maybe_shard_with_name(inputs, sharding_names):
     return maybe_shard_with_name(inputs, sharding_names, config.shard_mode)
 
   # For more efficient DP/ZeRO-1 + GA
-  if config.shard_mode == ShardMode.EXPLICIT and config.ici_data_parallelism > 1:
+  if config.shard_mode == ShardMode.EXPLICIT and model.mesh.shape.get("data", 1) > 1:
     ga_params_shardings = jax.tree.map(update_sharding_for_reduced, params_shardings)
     grad_shardings = jax.tree.map(update_sharding_for_unreduced, params_shardings)
   else:
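The rewritten condition reads the data-parallel degree directly from the model's mesh rather than from the `ici_data_parallelism` knob, presumably so that any configuration that shards the `data` axis is picked up. A minimal sketch of the mesh query, assuming a hypothetical 1-D mesh with a `data` axis:

```python
import jax
import numpy as np
from jax.sharding import Mesh

# Hypothetical 1-D mesh over all local devices, named like MaxText's "data" axis.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# mesh.shape maps axis name -> size; .get(axis, 1) treats a missing axis
# as unsharded, which is what the new condition relies on.
if mesh.shape.get("data", 1) > 1:
  print("data-parallel degree:", mesh.shape["data"])
else:
  print("no data parallelism on this mesh")
```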
src/MaxText/kernels/megablox/backend.py (6 changes: 4 additions & 2 deletions)

@@ -298,6 +298,7 @@ def _calculate_bytes(x: jax.Array | qpl.QArray) -> int:
         "tiling",
         "transpose_rhs",
         "interpret",
+        "vma_axes",
     ],
 )
 def gmm(
@@ -310,6 +311,7 @@ def gmm(
     existing_out: jnp.ndarray | None = None,
     transpose_rhs: bool = False,
     interpret: bool = False,
+    vma_axes: tuple = tuple(),
 ) -> jnp.ndarray:
   """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'.
 
@@ -522,7 +524,7 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
   }
   call_gmm = qpl.pallas_call(
       kernel,
-      out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type),
+      out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type, vma=set(vma_axes)),
       grid_spec=pltpu.PrefetchScalarGridSpec(
           num_scalar_prefetch=2,
           in_specs=[
@@ -775,7 +777,7 @@ def out_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset):
   }
   call_gmm = qpl.pallas_call(
       kernel,
-      out_shape=jax.ShapeDtypeStruct((num_actual_groups, k, n), preferred_element_type),
+      out_shape=jax.ShapeDtypeStruct((num_actual_groups, k, n), preferred_element_type, vma=set()),
       grid_spec=pltpu.PrefetchScalarGridSpec(
           num_scalar_prefetch=2,
           in_specs=[
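Under `shard_map` with `check_vma=True`, JAX tracks the set of mesh axes each value varies over ("varying manual axes"), and a Pallas kernel's outputs must declare theirs explicitly, since nothing can be inferred from the opaque kernel body. A minimal sketch of such an output spec, assuming a recent JAX where `jax.ShapeDtypeStruct` accepts `vma=` (the shapes and axis name here are made up):

```python
import jax
import jax.numpy as jnp

# Hypothetical axes the kernel output varies over; an empty tuple
# degenerates to vma=set(), as in the tgmm call above.
vma_axes = ("data",)

# `vma=` on ShapeDtypeStruct is a recent JAX addition used by shard_map's
# check_vma machinery; this requires a JAX version that supports it.
out_shape = jax.ShapeDtypeStruct((128, 256), jnp.float32, vma=set(vma_axes))
print(out_shape)
```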
src/MaxText/kernels/megablox/ops.py (16 changes: 10 additions & 6 deletions)

@@ -42,6 +42,7 @@ def gmm(
     use_qwix_quantization: bool = False,
     use_tokamax_backend: bool = False,
     weight_gather_axes: List[Tuple[str, int]] | None = None,
+    vma_axes: tuple = tuple(),
 ):
   """Grouped matrix multiplication operation."""
   quantization_rule = None
@@ -60,9 +61,11 @@
         act_calibration_method="absmax",
     )
 
-  gmm_fwd_bwd = lambda *args: _gmm_fwd(*args)[0]  # pylint: disable=C3001
+  _gmm_fwd_vma = functools.partial(_gmm_fwd, vma_axes=vma_axes)
+  _gmm_bwd_vma = functools.partial(_gmm_bwd, lhs.dtype, rhs.dtype, vma_axes=vma_axes)
+  gmm_fwd_bwd = lambda *args: _gmm_fwd_vma(*args)[0]  # pylint: disable=C3001
   gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 7, 8, 9, 10, 11))
-  gmm_fwd_bwd.defvjp(_gmm_fwd, functools.partial(_gmm_bwd, lhs.dtype, rhs.dtype))
+  gmm_fwd_bwd.defvjp(_gmm_fwd_vma, _gmm_bwd_vma)
   return gmm_fwd_bwd(
       lhs,
       rhs,
@@ -92,6 +95,7 @@ def _gmm_fwd(
     quantization_rule: qwix.QtRule | None = None,
     use_tokamax_backend: bool = False,
     weight_gather_axes: List[Tuple[str, int]] | None = None,
+    vma_axes: tuple = tuple(),
 ) -> tuple[
     jnp.ndarray,
     tuple[
@@ -125,10 +129,7 @@
   # QAG is only supported for following conditions
   if use_tokamax_backend:
     if quantization_rule and quantization_rule.bwd_qtype:
-      if (
-          quantization_rule.weight_calibration_method.startswith("fixed")
-          and isinstance(rhs, qpl.QArray)
-      ):
+      if quantization_rule.weight_calibration_method.startswith("fixed") and isinstance(rhs, qpl.QArray):
         if weight_gather_axes:
           for axis_name, axis_idx in weight_gather_axes:
             rhs_qvalue = jax.lax.all_gather(rhs.qvalue, axis_name, axis=axis_idx, tiled=True)
@@ -155,6 +156,7 @@
       existing_out,
       transpose_rhs=transpose_rhs,
       interpret=interpret,
+      vma_axes=vma_axes,
   )
   return out, (lhs, rhs, group_sizes, group_offset)
 
@@ -176,6 +178,7 @@ def _gmm_bwd(
         jnp.ndarray | None,
     ],
     grad: jnp.ndarray,
+    vma_axes: tuple = tuple(),
 ) -> tuple[jnp.ndarray, jnp.ndarray, None, None, jnp.ndarray]:
   """Backward function for throughput GMM VJP."""
   del preferred_element_type
@@ -256,6 +259,7 @@ def _gmm_bwd(
       group_offset,
       transpose_rhs=not transpose_rhs,
      interpret=interpret,
+      vma_axes=vma_axes,
   )
   drhs = backend.tgmm(
       lhs.swapaxes(0, 1),
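`jax.custom_vjp` offers no direct way to thread an extra configuration argument into both halves of the VJP, so the diff pre-binds `vma_axes` with `functools.partial` before registering the pair via `defvjp`. A toy reduction of the same pattern, with a hypothetical `scale` standing in for `vma_axes`:

```python
import functools
import jax
import jax.numpy as jnp

def _fwd(x, *, scale):
  # Forward pass: returns (output, residuals).
  return x * scale, (x,)

def _bwd(scale, residuals, g):
  # Backward pass: scale was bound positionally via functools.partial.
  del residuals
  return (g * scale,)

def make_scaled(scale):
  fwd = functools.partial(_fwd, scale=scale)
  bwd = functools.partial(_bwd, scale)
  f = jax.custom_vjp(lambda x: fwd(x)[0])
  f.defvjp(fwd, bwd)
  return f

f = make_scaled(3.0)
print(jax.grad(f)(jnp.float32(2.0)))  # 3.0
```

Binding the value up front keeps it out of `custom_vjp`'s argument list entirely, so the existing `nondiff_argnums` tuple does not need to change.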
src/MaxText/layers/moe.py (6 changes: 5 additions & 1 deletion)

@@ -820,6 +820,9 @@ def sparse_matmul(
   ):
     """Perform sparse matrix multiplication of inputs and Experts."""
 
+    vma_axes = tuple(axis for axis in self.config.mesh_axes if self.mesh.shape[axis] > 1)
+    use_vma = not self.config.use_tokamax_gmm
+
     def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes):
       pad_length = self.config.wi_tile_fwd_batch_seq
       hs_shape = inputs.shape
@@ -882,6 +885,7 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes):
             use_qwix_quantization=self.config.use_qwix_quantization,
             use_tokamax_backend=self.config.use_tokamax_gmm,
             weight_gather_axes=weight_gather_axes,
+            vma_axes=vma_axes,
         )
       else:
         rhs_inputs = kernel
@@ -1006,7 +1010,7 @@ def gmm(inputs, kernel, tiling, group_sizes, expert_assignments, weight_gather_axes):
             None,
         ),
         out_specs=(self._logical_to_mesh_axes((batch_logical_axis, "activation_norm_length", "activation_embed"))),
-        check_vma=False,
+        check_vma=use_vma,
     )
     def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, rngs):
       batch_size, sequence_length, _ = x.shape
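`vma_axes` is computed once per call: only mesh axes with more than one device can make values vary across shards. `check_vma` is then switched on for the megablox path only, presumably because the tokamax backend does not declare output `vma`. A small sketch of the filter, with hypothetical axis names standing in for `config.mesh_axes`:

```python
import jax
import numpy as np
from jax.sharding import Mesh

# Hypothetical 2-D mesh; the axis names stand in for config.mesh_axes.
mesh = Mesh(np.array(jax.devices()).reshape(-1, 1), axis_names=("data", "model"))
mesh_axes = ("data", "model")

# Same filter as the diff: only axes that are actually sharded (size > 1)
# can make a value vary across devices, so only those need vma tracking.
vma_axes = tuple(axis for axis in mesh_axes if mesh.shape[axis] > 1)
print(vma_axes)  # e.g. ("data",) when more than one device is visible
```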
src/MaxText/layers/normalizations.py (6 changes: 3 additions & 3 deletions)

@@ -79,9 +79,9 @@ def __call__(self, x: jnp.ndarray, out_sharding: NamedSharding | None = None) ->
     effective_scale = scale + self.scale_offset  # Apply offset
     # y: (B, S, E)
     # effective_scale: (E,) -> (1, 1, E) -> (B, S, E)
-    effective_scale = jnp.expand_dims(effective_scale, axis=tuple(range(y.ndim - effective_scale.ndim)))
-    effective_scale = jnp.broadcast_to(effective_scale, y.shape, out_sharding=out_sharding)
-    return jnp.multiply(y, effective_scale)
+    # effective_scale = jnp.expand_dims(effective_scale, axis=tuple(range(y.ndim - effective_scale.ndim)))
+    # effective_scale = jnp.broadcast_to(effective_scale, y.shape, out_sharding=out_sharding)
+    return jnp.einsum("i...k,...k->i...k", y, effective_scale)
 
 
 def Qwen3NextRMSNorm(num_features: int, eps: float, dtype: DType, weight_dtype: DType, *, rngs: nnx.Rngs):
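The expand/broadcast/multiply sequence is kept as comments and replaced by a single einsum; `...` dimensions broadcast between operands, so the (E,)-shaped scale is applied per feature without materializing a broadcast copy at the activation shape. A quick equivalence check on toy shapes (the values here are made up):

```python
import jax.numpy as jnp

# y: (B, S, E) activations; scale: (E,) per-feature scale, matching the
# shape comment in the RMSNorm code above.
y = jnp.arange(2 * 3 * 4, dtype=jnp.float32).reshape(2, 3, 4)
scale = jnp.linspace(0.5, 2.0, 4)

# "i...k,...k->i...k": keep y's leading and trailing dims, broadcast the
# scale over everything except the last (feature) axis.
out_einsum = jnp.einsum("i...k,...k->i...k", y, scale)

# Reference: plain broadcasting multiply over trailing axes.
out_ref = y * scale
print(jnp.allclose(out_einsum, out_ref))  # True
```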