Refactor: Decouple Core Transformer Blocks #1852

Open · wants to merge 2 commits into base: main
702 changes: 702 additions & 0 deletions MaxText/layers/blocks.py

Large diffs are not rendered by default.
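
The per-model layer files below all follow the same decoupling pattern: instead of importing the monolithic MaxText.layers.models module, they pull Config from MaxText.common_types and RMSNorm from MaxText.layers.normalizations directly, so layer definitions no longer depend on the model assembly code. A minimal sketch of the resulting decoder-layer skeleton, assuming only the import paths and constructor arguments visible in the diffs below; the class name SketchDecoderLayer and its simplified __call__ body are illustrative, not part of this PR:

# Sketch only: illustrates the decoupled imports this PR moves the layer files to.
# SketchDecoderLayer and its simplified __call__ are hypothetical, not MaxText code.
from typing import Optional

from flax import linen as nn
from jax.sharding import Mesh

from MaxText.common_types import Config                    # previously models.Config
from MaxText.layers.normalizations import RMSNorm          # previously models.RMSNorm
from MaxText.layers.quantizations import AqtQuantization as Quant


class SketchDecoderLayer(nn.Module):
  """Illustrative decoder layer using the new, narrower dependencies."""

  config: Config
  mesh: Mesh
  quant: Optional[Quant] = None

  @nn.compact
  def __call__(self, inputs):
    cfg = self.config
    # Normalization now comes straight from layers.normalizations.
    lnx = RMSNorm(
        dtype=cfg.dtype,
        weight_dtype=cfg.weight_dtype,
        name="pre_self_attention_layer_norm",
    )(inputs)
    return lnx

The same Config and RMSNorm swaps recur in deepseek.py, llama2.py, and llama4.py below; the new MaxText/layers/blocks.py (not rendered above) is presumably where the shared block definitions now live, per the PR title.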

11 changes: 6 additions & 5 deletions MaxText/layers/deepseek.py
@@ -30,7 +30,8 @@
 from MaxText.layers import attentions
 from MaxText.layers import initializers
 from MaxText.layers import linears
-from MaxText.layers import models
+from MaxText.common_types import Config
+from MaxText.layers.normalizations import RMSNorm
 from MaxText.layers import moe
 from MaxText.layers import quantizations
 from MaxText.layers.quantizations import AqtQuantization as Quant
@@ -43,7 +44,7 @@
 def self_attention_with_norm(inputs, cfg, mesh, quant, decoder_segment_ids, decoder_positions, deterministic, model_mode):
   """self-attention with normalization"""
   # Normalization
-  lnx_rms = models.RMSNorm(
+  lnx_rms = RMSNorm(
       dtype=cfg.dtype,
       weight_dtype=cfg.weight_dtype,
       name="pre_self_attention_layer_norm",
@@ -94,7 +95,7 @@ def self_attention_with_norm(inputs, cfg, mesh, quant, decoder_segment_ids, decoder_positions, deterministic, model_mode):
   intermediate_inputs = inputs + attention_lnx
 
   # Normalization
-  hidden_states = models.RMSNorm(
+  hidden_states = RMSNorm(
       dtype=cfg.dtype,
       weight_dtype=cfg.weight_dtype,
       name="post_self_attention_layer_norm",
@@ -127,7 +128,7 @@ def post_process(cfg, layer_output, sow):
 class DeepSeekDenseLayer(nn.Module):
   """DeepSeek-style dense layer with Multi-Head Latent Attention."""
 
-  config: models.Config
+  config: Config
   mesh: Mesh
   quant: Optional[Quant] = None
 
@@ -177,7 +178,7 @@ class DeepSeekMoELayer(nn.Module):
   Uses a bias in routing instead of load balancing loss.
   """
 
-  config: models.Config
+  config: Config
   mesh: Mesh
   quant: Optional[Quant] = None
 
21 changes: 5 additions & 16 deletions MaxText/layers/linears.py
@@ -84,9 +84,7 @@ def __init__(
       axis: Union[Iterable[int], int] = -1,
       weight_dtype: DType = jnp.float32,
       dtype: DType = jnp.float32,
-      kernel_init: NdInitializer = nd_dense_init(
-          1.0, "fan_in", "truncated_normal"
-      ),
+      kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"),
       kernel_axes: Tuple[Optional[str], ...] = (),
       quant: Optional[Quant] = None,
       use_bias: bool = False,
@@ -127,9 +125,7 @@ def __init__(
     # Parameter initialization
     kernel_shape = self.in_features_shape + self.out_features_shape
     kernel_in_axis = np.arange(len(self.axis))
-    kernel_out_axis = np.arange(
-        len(self.axis), len(self.axis) + len(self.out_features_shape)
-    )
+    kernel_out_axis = np.arange(len(self.axis), len(self.axis) + len(self.out_features_shape))
 
     if not quantizations.in_serve_mode(self.quant):
       self.kernel = nnx.Param(
@@ -218,9 +214,7 @@ def dense_general(
     axis: Union[Iterable[int], int] = -1,
     weight_dtype: DType = jnp.float32,
     dtype: DType = jnp.float32,
-    kernel_init: NdInitializer = nd_dense_init(
-        1.0, "fan_in", "truncated_normal"
-    ),
+    kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"),
     kernel_axes: Tuple[Optional[str], ...] = (),
     quant: Optional[Quant] = None,
     use_bias: bool = False,
@@ -247,15 +241,11 @@ def dense_general(
     name: name passed to the ToLinen Module
   """
   if not (inputs_shape is not None) ^ (in_features_shape is not None):
-    raise ValueError(
-        "Exactly one of inputs_shape or in_features must be specified."
-    )
+    raise ValueError("Exactly one of inputs_shape or in_features must be specified.")
 
   if inputs_shape is not None:
     axis = _canonicalize_tuple(axis)
-    in_features_shape = tuple(
-        inputs_shape[ax] for ax in _normalize_axes(axis, len(inputs_shape))
-    )
+    in_features_shape = tuple(inputs_shape[ax] for ax in _normalize_axes(axis, len(inputs_shape)))
   else:
     assert in_features_shape is not None
   module = nnx.bridge.to_linen(
@@ -401,4 +391,3 @@ def __call__(self, inputs, decode: bool = False, deterministic: bool = False):

     output = checkpoint_name(output, "mlpwo")
     return output
-
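
As a side note on the dense_general hunk above: the reflowed validation enforces that callers provide exactly one of inputs_shape or in_features_shape, using an XOR over the two "is provided" flags. A standalone sketch of that check follows; the helper name check_shape_args and its simplified signature are illustrative, not the MaxText API.

# Standalone sketch of the exactly-one-of check used in dense_general above.
# check_shape_args and its signature are illustrative, not MaxText code.
from typing import Optional, Sequence


def check_shape_args(
    inputs_shape: Optional[Sequence[int]] = None,
    in_features_shape: Optional[Sequence[int]] = None,
) -> None:
  # XOR: valid only when exactly one of the two arguments is given.
  if not (inputs_shape is not None) ^ (in_features_shape is not None):
    raise ValueError("Exactly one of inputs_shape or in_features must be specified.")


check_shape_args(inputs_shape=(4, 16))        # ok: only inputs_shape provided
check_shape_args(in_features_shape=(16,))     # ok: only in_features_shape provided
# check_shape_args()                          # raises: neither provided
# check_shape_args((4, 16), (16,))            # raises: both provided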
6 changes: 3 additions & 3 deletions MaxText/layers/llama2.py
@@ -29,7 +29,7 @@

 from MaxText.inference import page_manager
 from MaxText.layers import linears
-from MaxText.layers import models
+from MaxText.common_types import Config
 from MaxText.layers import quantizations
 from MaxText.layers.attentions import Attention
 from MaxText.layers.quantizations import AqtQuantization as Quant
@@ -44,7 +44,7 @@
 class LlamaDecoderLayer(nn.Module):
   """Transformer decoder layer that attends to the encoder."""
 
-  config: models.Config
+  config: Config
   mesh: Mesh
   quant: Optional[Quant] = None
 
@@ -120,7 +120,7 @@ def __call__(
     intermediate_inputs = inputs + attention_lnx
 
     # Fully Connected
-    hidden_states = models.RMSNorm(
+    hidden_states = RMSNorm(
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         name="post_self_attention_layer_norm",
15 changes: 7 additions & 8 deletions MaxText/layers/llama4.py
@@ -31,7 +31,6 @@
 from MaxText.inference import page_manager
 from MaxText.layers import initializers
 from MaxText.layers import linears
-from MaxText.layers import models
 from MaxText.layers import moe
 from MaxText.layers import quantizations
 from MaxText.layers import attentions
@@ -58,7 +57,7 @@ class Llama4UnfoldConvolution(nn.Module):

   def setup(self):
     """
-    Initialize Llama4UnfoldConvolution
+    Initialize Llama4UnfoldConvolution
     """
     cfg = self.config
     # Linear projection layer using dense_general.
@@ -190,7 +189,7 @@ class Llama4VisionMLP2(nn.Module):

   def setup(self):
     """
-    Initialize Llama4VisionMLP2
+    Initialize Llama4VisionMLP2
     """
     cfg = self.config
     self.fc1 = linears.dense_general(
@@ -348,14 +347,14 @@ class Llama4DecoderLayer(nn.Module):
"""Transformer decoder layer for Llama4.

Attributes:
config: models.Config, MaxText model config
config: Config, MaxText model config
mesh: Mesh, JAX device mesh (used for sharding)
quant: Optional[Quant], quantization config
is_nope_layer: bool, whether to use RoPE or not on this layer
is_moe_layer: bool, whether this layer operates as a MoE layer
"""

config: models.Config
config: Config
mesh: Mesh
quant: Optional[Quant] = None
is_nope_layer: bool = False
@@ -446,7 +445,7 @@ def __call__(
     intermediate_inputs = inputs + attention_lnx
 
     # Fully Connected
-    hidden_states = models.RMSNorm(
+    hidden_states = RMSNorm(
         dtype=cfg.dtype,
         weight_dtype=cfg.weight_dtype,
         name="post_self_attention_layer_norm",
@@ -518,15 +517,15 @@ class Llama4ScannableBlock(nn.Module):
   A repeatable block given nope_layer_interval and interleave_moe_layer_step
 
   Attributes:
-    config: models.Config, MaxText model config
+    config: Config, MaxText model config
     mesh: Mesh, JAX device mesh (used for sharding)
     quant: Optional[Quant], quantization config
     nope_layer_interval: int, the interval at which layers should use NoPE.
     interleave_moe_layer_step: int, the interval or stride for placing MoE layers.
   """
   '''
 
-  config: models.Config
+  config: Config
   mesh: Mesh
   quant: Optional[Quant] = None
   nope_layer_interval: int = 1