Remove noop spmd_mode check, correct type annotations in attention flax #170

Draft · wants to merge 2 commits into main

42 changes: 21 additions & 21 deletions src/maxdiffusion/models/attention_flax.py
@@ -52,17 +52,17 @@ def _maybe_aqt_einsum(quant: Quant):
class AttentionOp(nn.Module):
mesh: Mesh
attention_kernel: str
- scale: int
+ scale: float
heads: int
dim_head: int
use_memory_efficient_attention: bool = False
split_head_dim: bool = False
float32_qk_product: bool = True
flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV)
flash_min_seq_length: int = 4096
- flash_block_sizes: BlockSizes = None
+ flash_block_sizes: BlockSizes | None = None
dtype: DType = jnp.float32
- quant: Quant = None
+ quant: Quant | None = None

def setup(self):
if self.attention_kernel == "cudnn_flash_te":
@@ -79,7 +79,7 @@ def setup(self):
dtype=self.dtype,
# float32_logits=self.float32_logits,
qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
- scale_factor=self.scale,
+ scale_factor=float(self.scale),
transpose_batch_sequence=False,
)

@@ -415,15 +415,15 @@ class FlaxFluxAttention(nn.Module):
split_head_dim: bool = False
attention_kernel: str = "dot_product"
flash_min_seq_length: int = 4096
- flash_block_sizes: BlockSizes = None
- mesh: jax.sharding.Mesh = None
+ flash_block_sizes: BlockSizes | None = None
+ mesh: jax.sharding.Mesh | None = None
dtype: jnp.dtype = jnp.float32
weights_dtype: jnp.dtype = jnp.float32
query_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
key_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
value_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
out_axis_names: AxisNames = (BATCH, LENGTH, EMBED)
- precision: jax.lax.Precision = None
+ precision: jax.lax.Precision | None = None
qkv_bias: bool = False

def setup(self):
@@ -619,16 +619,16 @@ class FlaxAttention(nn.Module):
split_head_dim: bool = False
attention_kernel: str = "dot_product"
flash_min_seq_length: int = 4096
- flash_block_sizes: BlockSizes = None
- mesh: jax.sharding.Mesh = None
+ flash_block_sizes: BlockSizes | None = None
+ mesh: jax.sharding.Mesh | None = None
dtype: jnp.dtype = jnp.float32
weights_dtype: jnp.dtype = jnp.float32
query_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
key_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
value_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
out_axis_names: AxisNames = (BATCH, LENGTH, HEAD)
- precision: jax.lax.Precision = None
- quant: Quant = None
+ precision: jax.lax.Precision | None = None
+ quant: Quant | None = None

def setup(self):

@@ -762,10 +762,10 @@ class FlaxBasicTransformerBlock(nn.Module):
split_head_dim: bool = False
attention_kernel: str = "dot_product"
flash_min_seq_length: int = 4096
- flash_block_sizes: BlockSizes = None
- mesh: jax.sharding.Mesh = None
- precision: jax.lax.Precision = None
- quant: Quant = None
+ flash_block_sizes: BlockSizes | None = None
+ mesh: jax.sharding.Mesh | None = None
+ precision: jax.lax.Precision | None = None
+ quant: Quant | None = None

def setup(self):
# self attention (or cross_attention if only_cross_attention is True)
@@ -890,12 +890,12 @@ class FlaxTransformer2DModel(nn.Module):
split_head_dim: bool = False
attention_kernel: str = "dot_product"
flash_min_seq_length: int = 4096
- flash_block_sizes: BlockSizes = None
- mesh: jax.sharding.Mesh = None
+ flash_block_sizes: BlockSizes | None = None
+ mesh: jax.sharding.Mesh | None = None
norm_num_groups: int = 32
- precision: jax.lax.Precision = None
+ precision: jax.lax.Precision | None = None
hidden_state_axis_names: AxisNames = (BATCH, LENGTH, D_KV)
- quant: Quant = (None,)
+ quant: Quant | tuple[None] = (None,)

def setup(self):
self.norm = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-5, dtype=self.dtype, param_dtype=self.weights_dtype)
@@ -1019,7 +1019,7 @@ class FlaxFeedForward(nn.Module):
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
weights_dtype: jnp.dtype = jnp.float32
- precision: jax.lax.Precision = None
+ precision: jax.lax.Precision | None = None

def setup(self):
# The second linear layer needs to be called
@@ -1051,7 +1051,7 @@ class FlaxGEGLU(nn.Module):
dropout: float = 0.0
dtype: jnp.dtype = jnp.float32
weights_dtype: jnp.dtype = jnp.float32
- precision: jax.lax.Precision = None
+ precision: jax.lax.Precision | None = None

def setup(self):
inner_dim = self.dim * 4
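
The attention_flax.py changes are about typing: fields that default to `None` are now annotated `T | None` (PEP 604 optional syntax) instead of a bare `T`, `scale` becomes a `float`, and the cuDNN path casts with `float(self.scale)` to match. A minimal, self-contained sketch of the annotation pattern (an illustrative module, not one of the maxdiffusion classes):

```python
# Illustrative sketch of the annotation pattern; the module and its fields are hypothetical.
import jax
import jax.numpy as jnp
import flax.linen as nn


class TinySelfAttention(nn.Module):
  features: int
  scale: float = 1.0  # attention scales are fractional (e.g. dim_head**-0.5), so float rather than int
  precision: jax.lax.Precision | None = None  # defaults to None, so the annotation is `T | None`
  dtype: jnp.dtype = jnp.float32

  @nn.compact
  def __call__(self, x):
    q = nn.Dense(self.features, dtype=self.dtype, precision=self.precision)(x)
    k = nn.Dense(self.features, dtype=self.dtype, precision=self.precision)(x)
    v = nn.Dense(self.features, dtype=self.dtype, precision=self.precision)(x)
    weights = jax.nn.softmax(self.scale * jnp.einsum("...qd,...kd->...qk", q, k), axis=-1)
    return jnp.einsum("...qk,...kd->...qd", weights, v)
```

With the explicit `| None`, static checkers such as pytype no longer flag the `None` defaults, and callers can see at a glance which fields are optional.
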
43 changes: 21 additions & 22 deletions src/maxdiffusion/train_utils.py
@@ -22,8 +22,7 @@


def get_first_step(state):
-  with jax.spmd_mode("allow_all"):
-    return int(state.step)
+  return int(state.step)


def load_next_batch(train_iter, example_batch, config):
@@ -101,27 +100,27 @@ def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step

def write_metrics_to_tensorboard(writer, metrics, step, config):
  """Writes metrics to tensorboard"""
-  with jax.spmd_mode("allow_all"):
-    if jax.process_index() == 0:
-      for metric_name in metrics.get("scalar", []):
-        writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
-      for metric_name in metrics.get("scalars", []):
-        writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)
-
-    full_log = step % config.log_period == 0
-    if jax.process_index() == 0:
-      max_logging.log(
-          "completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}".format(
-              step,
-              metrics["scalar"]["perf/step_time_seconds"],
-              metrics["scalar"]["perf/per_device_tflops_per_sec"],
-              float(metrics["scalar"]["learning/loss"]),
-          )
-      )
-
-    if full_log and jax.process_index() == 0:
-      max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'")
-      writer.flush()
+  if jax.process_index() == 0:
+    for metric_name in metrics.get("scalar", []):
+      writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step)
+    for metric_name in metrics.get("scalars", []):
+      writer.add_scalars(metric_name, metrics["scalars"][metric_name], step)
+
+  full_log = step % config.log_period == 0
+  if jax.process_index() == 0:
+    max_logging.log(
+        "completed step: {}, seconds: {:.3f}, TFLOP/s/device: {:.3f}, loss: {:.3f}".format(
+            step,
+            metrics["scalar"]["perf/step_time_seconds"],
+            metrics["scalar"]["perf/per_device_tflops_per_sec"],
+            float(metrics["scalar"]["learning/loss"]),
+        )
+    )
+
+  if full_log and jax.process_index() == 0:
+    max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'")
+    writer.flush()


def get_params_to_save(params):
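
In train_utils.py the `jax.spmd_mode("allow_all")` context manager is dropped; as the PR title says, it was a no-op for this host-side code. Reading `state.step` and writing TensorBoard metrics operate on already-materialized values, so the only guard that matters is restricting writes to process 0. A hedged sketch of the simplified logging pattern (the helper name is hypothetical):

```python
# Hypothetical helper showing the simplified pattern: no spmd_mode wrapper,
# just a process-0 guard around host-side TensorBoard writes.
import jax
import numpy as np


def log_scalars(writer, scalars, step):
  """Write a dict of scalar metrics from the lead host process only."""
  if jax.process_index() != 0:
    return
  for name, value in scalars.items():
    writer.add_scalar(name, np.array(value), step)  # `writer` is assumed to be a SummaryWriter-like object
```
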
3 changes: 2 additions & 1 deletion src/maxdiffusion/utils/outputs.py
@@ -16,7 +16,7 @@
"""

from collections import OrderedDict
- from dataclasses import fields, is_dataclass
+ from dataclasses import dataclass, fields, is_dataclass
from typing import Any, Tuple

import numpy as np
@@ -37,6 +37,7 @@ def is_tensor(x):
return isinstance(x, np.ndarray)


+ @dataclass
class BaseOutput(OrderedDict):
"""
Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
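
Decorating `BaseOutput` itself with `@dataclass` matches its docstring ("Base class for all model outputs as dataclass") and keeps it consistent with how its subclasses are declared. A hedged sketch of a typical subclass (the class name, field, and import path are assumptions for illustration):

```python
# Hypothetical subclass; the import path and field name are illustrative only.
from dataclasses import dataclass
from typing import Optional

import numpy as np

from maxdiffusion.utils.outputs import BaseOutput  # assumed import path


@dataclass
class ExamplePipelineOutput(BaseOutput):
  images: Optional[np.ndarray] = None


out = ExamplePipelineOutput(images=np.zeros((1, 64, 64, 3)))
print(out.images.shape)  # ordinary dataclass attribute access
```
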
12 changes: 6 additions & 6 deletions src/maxdiffusion/utils/pil_utils.py
@@ -16,11 +16,11 @@
}
else:
PIL_INTERPOLATION = {
"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
"nearest": PIL.Image.NEAREST,
"linear": PIL.Image.LINEAR, # pytype: disable=module-attr
"bilinear": PIL.Image.BILINEAR, # pytype: disable=module-attr
"bicubic": PIL.Image.BICUBIC, # pytype: disable=module-attr
"lanczos": PIL.Image.LANCZOS, # pytype: disable=module-attr
"nearest": PIL.Image.NEAREST, # pytype: disable=module-attr
}


@@ -50,7 +50,7 @@ def numpy_to_pil(images):
return pil_images


- def make_image_grid(images: List[PIL.Image.Image], rows: int, cols: int, resize: int = None) -> PIL.Image.Image:
+ def make_image_grid(images: List[PIL.Image.Image], rows: int, cols: int, resize: int | None = None) -> PIL.Image.Image:
"""
Prepares a single grid of images. Useful for visualization purposes.
"""
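
In pil_utils.py, `resize` is annotated `int | None` to match its `None` default, and the legacy `PIL.Image.LINEAR`-style constants get `# pytype: disable=module-attr` comments, presumably because those module-level attributes are missing from newer Pillow releases (which moved resampling filters to `PIL.Image.Resampling`), so the type checker cannot see them even though this `else:` branch appears to target older Pillow. A usage sketch for `make_image_grid` (the solid-colour tiles and import path are placeholders):

```python
# Usage sketch; assumes make_image_grid is importable from maxdiffusion.utils.pil_utils.
import PIL.Image

from maxdiffusion.utils.pil_utils import make_image_grid  # assumed import path

tiles = [PIL.Image.new("RGB", (64, 64), c) for c in ("red", "green", "blue", "white")]
grid = make_image_grid(tiles, rows=2, cols=2)               # resize: int | None defaults to None
thumbs = make_image_grid(tiles, rows=2, cols=2, resize=32)  # or pass an int to resize each tile
```
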