Get Linter / CPU tests to succeed #1844

Open · wants to merge 1 commit into base: main

6 changes: 2 additions & 4 deletions MaxText/checkpointing.py
@@ -103,7 +103,7 @@ def create_orbax_emergency_checkpoint_manager(
 
   # Only create directories if running on GPUs as the previous
   # directory structure might be assumed by TPUs
-  if global_mesh.devices.flatten()[0].platform == 'gpu':
+  if global_mesh.devices.flatten()[0].platform == "gpu":
     # pylint: disable=protected-access
     local_checkpoint_dir = f"{local_checkpoint_dir}/{jax._src.distributed.global_state.process_id}"
     local_p = epath.Path(local_checkpoint_dir)
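
The quote-style fix above touches the per-process local directory logic; as context, a minimal sketch of that logic under the same imports checkpointing.py appears to use (`epath` from `etils`; the private `jax._src` access is what the `pylint` suppression covers):

```python
import jax
from etils import epath


def local_checkpoint_path(base_dir: str, global_mesh) -> epath.Path:
  # Only GPU runs get a process-id suffix; TPU runs may assume the flat layout.
  if global_mesh.devices.flatten()[0].platform == "gpu":
    base_dir = f"{base_dir}/{jax._src.distributed.global_state.process_id}"  # pylint: disable=protected-access
  return epath.Path(base_dir)
```
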
@@ -374,9 +374,7 @@ def setup_checkpoint_logger(config) -> Any | None:  # pytype: disable=attribute-
   max_logging.log("Setting up checkpoint logger...")
   if config.enable_checkpoint_cloud_logger:
     logger_name = f"goodput_{config.run_name}"
-    options = ocp.logging.CloudLoggerOptions(
-        job_name=config.run_name, logger_name=logger_name
-    )
+    options = ocp.logging.CloudLoggerOptions(job_name=config.run_name, logger_name=logger_name)
     orbax_cloud_logger = ocp.logging.CloudLogger(options=options)
     max_logging.log("Successfully set up checkpoint cloud logger.")
     return orbax_cloud_logger
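
For reference, a minimal sketch of the logger construction this hunk collapses onto one line, using only the `ocp.logging` calls already visible in the diff:

```python
import orbax.checkpoint as ocp


def make_cloud_logger(run_name: str) -> ocp.logging.CloudLogger:
  # Mirrors the job_name / goodput logger_name built from config.run_name above.
  options = ocp.logging.CloudLoggerOptions(job_name=run_name, logger_name=f"goodput_{run_name}")
  return ocp.logging.CloudLogger(options=options)
```
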
13 changes: 7 additions & 6 deletions MaxText/decode.py
@@ -140,12 +140,12 @@ def main(argv: Sequence[str]) -> None:
   for i in range(_NUM_STREAMS):
     with jax.profiler.StepTraceAnnotation("prefill", stream=i):
       prefill_result, first_token = engine.prefill(
-        params=params,
-        padded_tokens=tokens,
-        images=processor_output.pixel_values,
-        true_length=true_length,
-        rng=rng_prefill,
-        slot=i,
+          params=params,
+          padded_tokens=tokens,
+          images=processor_output.pixel_values,
+          true_length=true_length,
+          rng=rng_prefill,
+          slot=i,
       )
       prefill_result_list.append(prefill_result)
       first_token_list.append(first_token)
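
This hunk is indentation-only: the formatter wants wrapped keyword arguments on a hanging indent, with the closing parenthesis back at the call's level. A trivial sketch of that rule with a hypothetical function:

```python
def some_function(first_arg, second_arg):
  return first_arg + second_arg


# Wrapped arguments go on a hanging indent; the closing parenthesis
# returns to the indent of the calling line.
result = some_function(
    first_arg=1,
    second_arg=2,
)
```
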
@@ -178,6 +178,7 @@ def main(argv: Sequence[str]) -> None:
   # Deactivate profiler
   prof.deactivate()
 
+
 def _validate_config(config):
   assert config.load_full_state_path == "", (
       "Decode doesn't operate on full states! Convert to parameter checkpoint first." "Using generate_param_only_checkpoint."
21 changes: 5 additions & 16 deletions MaxText/layers/linears.py
@@ -84,9 +84,7 @@ def __init__(
       axis: Union[Iterable[int], int] = -1,
       weight_dtype: DType = jnp.float32,
       dtype: DType = jnp.float32,
-      kernel_init: NdInitializer = nd_dense_init(
-          1.0, "fan_in", "truncated_normal"
-      ),
+      kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"),
       kernel_axes: Tuple[Optional[str], ...] = (),
       quant: Optional[Quant] = None,
       use_bias: bool = False,
@@ -127,9 +125,7 @@ def __init__(
     # Parameter initialization
     kernel_shape = self.in_features_shape + self.out_features_shape
     kernel_in_axis = np.arange(len(self.axis))
-    kernel_out_axis = np.arange(
-        len(self.axis), len(self.axis) + len(self.out_features_shape)
-    )
+    kernel_out_axis = np.arange(len(self.axis), len(self.axis) + len(self.out_features_shape))
 
     if not quantizations.in_serve_mode(self.quant):
       self.kernel = nnx.Param(
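
A worked example of the axis bookkeeping in the collapsed `np.arange` line, assuming two contracted input axes and one output feature axis:

```python
import numpy as np

axis = (-2, -1)               # two contracted input axes
out_features_shape = (1024,)  # one output feature axis

kernel_in_axis = np.arange(len(axis))                                        # array([0, 1])
kernel_out_axis = np.arange(len(axis), len(axis) + len(out_features_shape))  # array([2])
```
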
@@ -218,9 +214,7 @@ def dense_general(
     axis: Union[Iterable[int], int] = -1,
     weight_dtype: DType = jnp.float32,
     dtype: DType = jnp.float32,
-    kernel_init: NdInitializer = nd_dense_init(
-        1.0, "fan_in", "truncated_normal"
-    ),
+    kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal"),
     kernel_axes: Tuple[Optional[str], ...] = (),
     quant: Optional[Quant] = None,
     use_bias: bool = False,
@@ -247,15 +241,11 @@ def dense_general(
     name: name passed to the ToLinen Module
   """
   if not (inputs_shape is not None) ^ (in_features_shape is not None):
-    raise ValueError(
-        "Exactly one of inputs_shape or in_features must be specified."
-    )
+    raise ValueError("Exactly one of inputs_shape or in_features must be specified.")
 
   if inputs_shape is not None:
     axis = _canonicalize_tuple(axis)
-    in_features_shape = tuple(
-        inputs_shape[ax] for ax in _normalize_axes(axis, len(inputs_shape))
-    )
+    in_features_shape = tuple(inputs_shape[ax] for ax in _normalize_axes(axis, len(inputs_shape)))
   else:
     assert in_features_shape is not None
   module = nnx.bridge.to_linen(
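
The `^` test above enforces that callers pass exactly one of the two shape arguments; a standalone sketch of the same pattern:

```python
def resolve_in_features(inputs_shape=None, in_features_shape=None):
  # XOR is False when both or neither argument is supplied.
  if not (inputs_shape is not None) ^ (in_features_shape is not None):
    raise ValueError("Exactly one of inputs_shape or in_features must be specified.")
  return inputs_shape if inputs_shape is not None else in_features_shape
```
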
@@ -401,4 +391,3 @@ def __call__(self, inputs, decode: bool = False, deterministic: bool = False):
 
     output = checkpoint_name(output, "mlpwo")
     return output
-
4 changes: 2 additions & 2 deletions MaxText/layers/llama4.py
@@ -58,7 +58,7 @@ class Llama4UnfoldConvolution(nn.Module):
 
   def setup(self):
     """
-        Initialize Llama4UnfoldConvolution
+    Initialize Llama4UnfoldConvolution
     """
     cfg = self.config
     # Linear projection layer using dense_general.
@@ -190,7 +190,7 @@ class Llama4VisionMLP2(nn.Module):
 
   def setup(self):
     """
-        Initialize Llama4VisionMLP2
+    Initialize Llama4VisionMLP2
     """
     cfg = self.config
     self.fc1 = linears.dense_general(
5 changes: 1 addition & 4 deletions MaxText/layers/models.py
@@ -666,9 +666,7 @@ def __call__(
         inputs_shape=y.shape,
         out_features_shape=cfg.vocab_size,
         weight_dtype=cfg.weight_dtype,
-        dtype=jnp.float32
-        if cfg.logits_dot_in_fp32
-        else cfg.dtype,  # for logit training stability
+        dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype,  # for logit training stability
         kernel_axes=("embed", "vocab"),
         name="logits_dense",
         matmul_precision=self.config.matmul_precision,
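
The collapsed conditional keeps the vocab projection in float32 when `logits_dot_in_fp32` is set, since a low-precision logits matmul can destabilize training; a minimal sketch of the choice:

```python
import jax.numpy as jnp


def logits_dtype(cfg):
  # Hidden layers can run in cfg.dtype (often bfloat16), but the final
  # logits matmul is promoted to float32 for training stability.
  return jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype
```
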
@@ -804,4 +802,3 @@ def __call__(
         image_embeddings=image_embeddings,
     )
     return logits
-
1 change: 0 additions & 1 deletion MaxText/layers/multi_token_prediction.py
@@ -136,4 +136,3 @@ def __call__(
     # Shape: [B, S, H]
     # --- Return Processed Hidden State ---
     return next_hidden_state
-
6 changes: 1 addition & 5 deletions MaxText/maxengine.py
@@ -178,11 +178,7 @@ def _layout(x, s, l):
        return x
      # Somehow this can be None sometimes.
      dll = l.device_local_layout if isinstance(l, Format) else l
-      f = (
-          jax.jit(self._identity, out_shardings=Format(dll, s))
-          .lower(x)
-          .compile(compiler_options=xla_flags)
-      )
+      f = jax.jit(self._identity, out_shardings=Format(dll, s)).lower(x).compile(compiler_options=xla_flags)
      y = f(x)
      # Achieves donation of the input argument, but allows for different memory
      # layouts and shapes.
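
The chain this hunk puts on one line lowers and compiles an identity function with an explicit output layout, then runs it to copy `x` into that layout. A hedged sketch: the `Format` import path varies across JAX versions, and `xla_flags` stands in for a dict of compiler options:

```python
import jax
from jax.experimental.layout import Format  # import path may differ by JAX version


def relayout(x, sharding, dll, xla_flags):
  # Compile an identity jit whose output must use the requested Format,
  # then execute it to produce x in the new layout.
  identity = lambda a: a
  f = jax.jit(identity, out_shardings=Format(dll, sharding)).lower(x).compile(compiler_options=xla_flags)
  return f(x)
```
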
1 change: 1 addition & 0 deletions MaxText/tests/hf_checkpoint_conversion_check.py
@@ -13,6 +13,7 @@
 See the License for the specific language governing permissions and
 limitations under the License.
 """
+
 import os
 import torch
 import torch.nn.functional as F
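
The added blank line follows the black/pyink convention that a module docstring (here, the license header) is separated from the first import by one empty line:

```python
"""Module docstring, e.g. a license header."""

import os  # first statement follows after exactly one blank line
```
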