
Commit c9229c3 (parent 88611cf)

save memory by deleting params. lint. (#277)

Deletes the params pytree after creating the TrainState, to reduce HBM usage, and lints the code. The only functional change is this line: https://github.com/AI-Hypercomputer/maxdiffusion/compare/deallocate_params_tree?expand=1#diff-69cda939e98b489aca1cc8aa543ecc537d9f4bc6c58fa767cbe01cd3636aacf3R311
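This page does not show max_utils.delete_pytree itself. As a rough sketch of what such a helper typically looks like in JAX — an assumption on my part, grounded only in the public jax.Array.delete() API, not in the maxdiffusion source:

    import jax
    import jax.numpy as jnp

    def delete_pytree(tree):
      # Explicitly free the device buffer behind every jax.Array leaf.
      for leaf in jax.tree_util.tree_leaves(tree):
        if isinstance(leaf, jax.Array) and not leaf.is_deleted():
          leaf.delete()

    params = {"w": jnp.ones((8, 8)), "b": jnp.zeros((8,))}
    delete_pytree(params)
    print(params["w"].is_deleted())  # True: the HBM behind these leaves is released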

File tree: 7 files changed, +153 −121 lines


src/maxdiffusion/max_utils.py

Lines changed: 17 additions & 1 deletion
@@ -489,12 +489,14 @@ def get_precision(config):
     retval = jax.lax.Precision.HIGHEST
   return retval
 
+
 def value_or_none(flash_block_sizes, key):
   if key in flash_block_sizes:
     return flash_block_sizes[key]
   else:
     return None
 
+
 def get_flash_block_sizes(config):
   """Create custom flash attention BlockSizes."""
   flash_block_sizes = None
@@ -508,7 +510,7 @@ def get_flash_block_sizes(config):
         block_kv_dkv_compute=config.flash_block_sizes["block_kv_dkv_compute"],
         block_q_dq=value_or_none(config.flash_block_sizes, "block_q_dq"),
         block_kv_dq=value_or_none(config.flash_block_sizes, "block_kv_dq"),
-        use_fused_bwd_kernel=value_or_none(config.flash_block_sizes, "use_fused_bwd_kernel")
+        use_fused_bwd_kernel=value_or_none(config.flash_block_sizes, "use_fused_bwd_kernel"),
     )
   return flash_block_sizes
 
@@ -528,6 +530,20 @@ def get_memory_allocations():
   )
 
 
+def get_live_arrays():
+
+  backend = jax.extend.backend.get_backend()
+  live_arrays = backend.live_arrays()
+
+  max_logging.log(f"Total live arrays: {len(live_arrays)}\n")
+
+  for i, arr in enumerate(live_arrays):
+    max_logging.log(f"Array {i}:")
+    max_logging.log(f"  Shape: {arr.shape}")
+    max_logging.log(f"  Dtype: {arr.dtype}")
+    max_logging.log(f"  Devices: {arr.devices()}")
+
+
 # Taking inspiration from flax's https://flax.readthedocs.io/en/v0.5.3/_modules/flax/linen/summary.html#tabulate
 # to retrieve layer parameters and calculate
 def calculate_model_tflops(module: module_lib.Module, rngs: Union[PRNGKey, RNGSequences], train, **kwargs):
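The new get_live_arrays is a debugging aid: it walks the backend's list of live device buffers. A quick interactive use of the same calls it relies on (jax.extend.backend.get_backend() and backend.live_arrays(), both visible in the diff above) to confirm that a deletion actually freed memory:

    import jax
    import jax.numpy as jnp

    x = jnp.ones((4096, 4096))  # ~64 MB in float32
    backend = jax.extend.backend.get_backend()
    print("live arrays before:", len(backend.live_arrays()))
    x.delete()
    print("live arrays after: ", len(backend.live_arrays()))  # one fewer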

src/maxdiffusion/models/attention_flax.py

Lines changed: 18 additions & 3 deletions
@@ -215,8 +215,14 @@ def _tpu_flash_attention(
   def wrap_flash_attention(query, key, value):
 
     uses_fused_kernel = block_sizes.use_fused_bwd_kernel
-    block_q_sizes = (block_sizes.block_q, block_sizes.block_q_dkv,)
-    block_kv_sizes = (block_sizes.block_kv, block_sizes.block_kv_dkv,)
+    block_q_sizes = (
+        block_sizes.block_q,
+        block_sizes.block_q_dkv,
+    )
+    block_kv_sizes = (
+        block_sizes.block_kv,
+        block_sizes.block_kv_dkv,
+    )
     if uses_fused_kernel:
       block_q_sizes += (block_sizes.block_q_dkv,)
       block_kv_sizes += (block_sizes.block_kv_dkv,)
@@ -455,7 +461,16 @@ def _apply_attention(
     )
   elif attention_kernel == "flash":
     return _tpu_flash_attention(
-        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, residual_checkpoint_name=residual_checkpoint_name
+        query,
+        key * scale,
+        value,
+        heads,
+        mesh,
+        axis_names_q,
+        axis_names_kv,
+        flash_block_sizes,
+        dtype,
+        residual_checkpoint_name=residual_checkpoint_name,
     )
   elif attention_kernel == "ring":
     return _tpu_flash_attention(
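For readability, the reformatted block-size logic amounts to the following restatement (a stand-in dataclass replaces the real splash-attention BlockSizes): when the fused backward kernel is enabled, the dkv block sizes are appended once more to each tuple.

    from dataclasses import dataclass

    @dataclass
    class BlockSizes:  # stand-in for the real splash-attention BlockSizes
      block_q: int
      block_kv: int
      block_q_dkv: int
      block_kv_dkv: int
      use_fused_bwd_kernel: bool

    def block_size_tuples(bs: BlockSizes):
      block_q_sizes = (bs.block_q, bs.block_q_dkv)
      block_kv_sizes = (bs.block_kv, bs.block_kv_dkv)
      if bs.use_fused_bwd_kernel:
        block_q_sizes += (bs.block_q_dkv,)
        block_kv_sizes += (bs.block_kv_dkv,)
      return block_q_sizes, block_kv_sizes

    print(block_size_tuples(BlockSizes(512, 1024, 512, 1024, True)))
    # ((512, 512, 512), (1024, 1024, 1024))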

src/maxdiffusion/models/gradient_checkpoint.py

Lines changed: 1 addition & 1 deletion
@@ -80,7 +80,7 @@ def to_jax_policy(self, names_which_can_be_saved: list = [], names_which_can_be_
       case GradientCheckpointType.HIDDEN_STATE_WITH_OFFLOAD:
         return jax.checkpoint_policies.save_and_offload_only_these_names(
             names_which_can_be_saved=[],
-            names_which_can_be_offloaded=["hidden_states","self_attn","cross_attn"],
+            names_which_can_be_offloaded=["hidden_states", "self_attn", "cross_attn"],
             offload_src="device",
             offload_dst="pinned_host",
         )
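Context for this policy: save_and_offload_only_these_names offloads only activations tagged with a matching name via jax.ad_checkpoint.checkpoint_name. A minimal, self-contained sketch — the layer and the "hidden_states" tag here are invented, and host offload must be supported by the platform:

    import jax
    import jax.numpy as jnp
    from jax.ad_checkpoint import checkpoint_name

    policy = jax.checkpoint_policies.save_and_offload_only_these_names(
        names_which_can_be_saved=[],
        names_which_can_be_offloaded=["hidden_states"],
        offload_src="device",
        offload_dst="pinned_host",
    )

    def layer(x):
      h = jnp.tanh(x @ x.T)
      h = checkpoint_name(h, "hidden_states")  # tag must match the offload list
      return jnp.sum(h * h)

    grad_fn = jax.jit(jax.grad(jax.checkpoint(layer, policy=policy)))
    print(grad_fn(jnp.ones((64, 64))).shape)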

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 2 additions & 2 deletions
@@ -283,7 +283,7 @@ def __init__(
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
-        residual_checkpoint_name='self_attn',
+        residual_checkpoint_name="self_attn",
     )
 
     # 1. Cross-attention
@@ -302,7 +302,7 @@ def __init__(
         precision=precision,
         attention_kernel=attention,
         dropout=dropout,
-        residual_checkpoint_name='cross_attn',
+        residual_checkpoint_name="cross_attn",
     )
     assert cross_attn_norm is True
     self.norm2 = FP32LayerNorm(rngs=rngs, dim=dim, eps=eps, elementwise_affine=True)
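The strings "self_attn" and "cross_attn" passed as residual_checkpoint_name are the same names listed in the offload policy above; presumably the attention op tags its output with them, roughly like this (a guess at the mechanism, not the maxdiffusion source):

    from jax.ad_checkpoint import checkpoint_name

    def tag_residual(attn_output, residual_checkpoint_name=None):
      # A named activation becomes matchable by save_and_offload_only_these_names.
      if residual_checkpoint_name is not None:
        attn_output = checkpoint_name(attn_output, residual_checkpoint_name)
      return attn_output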

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -131,9 +131,9 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
     # This helps with loading sharded weights directly into the accelerators without first copying them
     # all to one device and then distributing them, thus using low HBM memory.
     if restored_checkpoint:
-      if "params" in restored_checkpoint["wan_state"]: # if checkpointed with optimizer
+      if "params" in restored_checkpoint["wan_state"]:  # if checkpointed with optimizer
         params = restored_checkpoint["wan_state"]["params"]
-      else: # if not checkpointed with optimizer
+      else:  # if not checkpointed with optimizer
         params = restored_checkpoint["wan_state"]
     else:
       params = load_wan_transformer(
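The branch handles two checkpoint layouts, which the test file below also encodes: wan_state either nests params next to opt_state (saved with optimizer) or is the params pytree itself. A toy restatement of the lookup:

    def extract_params(restored_checkpoint: dict) -> dict:
      wan_state = restored_checkpoint["wan_state"]
      if "params" in wan_state:  # {"params": ..., "opt_state": ...}
        return wan_state["params"]
      return wan_state  # the state is the params pytree itself

    print(extract_params({"wan_state": {"params": {"w": 1}, "opt_state": {}}}))  # {'w': 1}
    print(extract_params({"wan_state": {"w": 1}}))  # {'w': 1}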

src/maxdiffusion/tests/wan_checkpointer_test.py

Lines changed: 100 additions & 101 deletions
@@ -16,107 +16,106 @@
 
 from maxdiffusion.checkpointing.wan_checkpointer import WanCheckpointer, WAN_CHECKPOINT
 
+
 class WanCheckpointerTest(unittest.TestCase):
-  def setUp(self):
-    self.config = MagicMock()
-    self.config.checkpoint_dir = "/tmp/wan_checkpoint_test"
-    self.config.dataset_type = "test_dataset"
-
-  @patch('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager')
-  @patch('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline')
-  def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager):
-    mock_manager = MagicMock()
-    mock_manager.latest_step.return_value = None
-    mock_create_manager.return_value = mock_manager
-
-    mock_pipeline_instance = MagicMock()
-    mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance
-
-    checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
-    pipeline, opt_state, step = checkpointer.load_checkpoint(step=None)
-
-    mock_manager.latest_step.assert_called_once()
-    mock_wan_pipeline.from_pretrained.assert_called_once_with(self.config)
-    self.assertEqual(pipeline, mock_pipeline_instance)
-    self.assertIsNone(opt_state)
-    self.assertIsNone(step)
-
-  @patch('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager')
-  @patch('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline')
-  def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager):
-    mock_manager = MagicMock()
-    mock_manager.latest_step.return_value = 1
-    metadata_mock = MagicMock()
-    metadata_mock.wan_state = {}
-    mock_manager.item_metadata.return_value = metadata_mock
-
-    restored_mock = MagicMock()
-    restored_mock.wan_state = {'params': {}}
-    restored_mock.wan_config = {}
-    restored_mock.keys.return_value = ['wan_state', 'wan_config']
-    def getitem_side_effect(key):
-      if key == 'wan_state':
-        return restored_mock.wan_state
-      raise KeyError(key)
-    restored_mock.__getitem__.side_effect = getitem_side_effect
-    mock_manager.restore.return_value = restored_mock
-
-    mock_create_manager.return_value = mock_manager
-
-    mock_pipeline_instance = MagicMock()
-    mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance
-
-    checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
-    pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
-
-    mock_manager.restore.assert_called_once_with(
-      directory=unittest.mock.ANY,
-      step=1,
-      args=unittest.mock.ANY
-    )
-    mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
-    self.assertEqual(pipeline, mock_pipeline_instance)
-    self.assertIsNone(opt_state)
-    self.assertEqual(step, 1)
-
-  @patch('maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager')
-  @patch('maxdiffusion.checkpointing.wan_checkpointer.WanPipeline')
-  def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_manager):
-    mock_manager = MagicMock()
-    mock_manager.latest_step.return_value = 1
-    metadata_mock = MagicMock()
-    metadata_mock.wan_state = {}
-    mock_manager.item_metadata.return_value = metadata_mock
-
-    restored_mock = MagicMock()
-    restored_mock.wan_state = {'params': {}, 'opt_state': {'learning_rate': 0.001}}
-    restored_mock.wan_config = {}
-    restored_mock.keys.return_value = ['wan_state', 'wan_config']
-    def getitem_side_effect(key):
-      if key == 'wan_state':
-        return restored_mock.wan_state
-      raise KeyError(key)
-    restored_mock.__getitem__.side_effect = getitem_side_effect
-    mock_manager.restore.return_value = restored_mock
-
-    mock_create_manager.return_value = mock_manager
-
-    mock_pipeline_instance = MagicMock()
-    mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance
-
-    checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
-    pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
-
-    mock_manager.restore.assert_called_once_with(
-      directory=unittest.mock.ANY,
-      step=1,
-      args=unittest.mock.ANY
-    )
-    mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
-    self.assertEqual(pipeline, mock_pipeline_instance)
-    self.assertIsNotNone(opt_state)
-    self.assertEqual(opt_state['learning_rate'], 0.001)
-    self.assertEqual(step, 1)
+
+  def setUp(self):
+    self.config = MagicMock()
+    self.config.checkpoint_dir = "/tmp/wan_checkpoint_test"
+    self.config.dataset_type = "test_dataset"
+
+  @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager")
+  @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline")
+  def test_load_from_diffusers(self, mock_wan_pipeline, mock_create_manager):
+    mock_manager = MagicMock()
+    mock_manager.latest_step.return_value = None
+    mock_create_manager.return_value = mock_manager
+
+    mock_pipeline_instance = MagicMock()
+    mock_wan_pipeline.from_pretrained.return_value = mock_pipeline_instance
+
+    checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
+    pipeline, opt_state, step = checkpointer.load_checkpoint(step=None)
+
+    mock_manager.latest_step.assert_called_once()
+    mock_wan_pipeline.from_pretrained.assert_called_once_with(self.config)
+    self.assertEqual(pipeline, mock_pipeline_instance)
+    self.assertIsNone(opt_state)
+    self.assertIsNone(step)
+
+  @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager")
+  @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline")
+  def test_load_checkpoint_no_optimizer(self, mock_wan_pipeline, mock_create_manager):
+    mock_manager = MagicMock()
+    mock_manager.latest_step.return_value = 1
+    metadata_mock = MagicMock()
+    metadata_mock.wan_state = {}
+    mock_manager.item_metadata.return_value = metadata_mock
+
+    restored_mock = MagicMock()
+    restored_mock.wan_state = {"params": {}}
+    restored_mock.wan_config = {}
+    restored_mock.keys.return_value = ["wan_state", "wan_config"]
+
+    def getitem_side_effect(key):
+      if key == "wan_state":
+        return restored_mock.wan_state
+      raise KeyError(key)
+
+    restored_mock.__getitem__.side_effect = getitem_side_effect
+    mock_manager.restore.return_value = restored_mock
+
+    mock_create_manager.return_value = mock_manager
+
+    mock_pipeline_instance = MagicMock()
+    mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance
+
+    checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
+    pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
+
+    mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY)
+    mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
+    self.assertEqual(pipeline, mock_pipeline_instance)
+    self.assertIsNone(opt_state)
+    self.assertEqual(step, 1)
+
+  @patch("maxdiffusion.checkpointing.wan_checkpointer.create_orbax_checkpoint_manager")
+  @patch("maxdiffusion.checkpointing.wan_checkpointer.WanPipeline")
+  def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_manager):
+    mock_manager = MagicMock()
+    mock_manager.latest_step.return_value = 1
+    metadata_mock = MagicMock()
+    metadata_mock.wan_state = {}
+    mock_manager.item_metadata.return_value = metadata_mock
+
+    restored_mock = MagicMock()
+    restored_mock.wan_state = {"params": {}, "opt_state": {"learning_rate": 0.001}}
+    restored_mock.wan_config = {}
+    restored_mock.keys.return_value = ["wan_state", "wan_config"]
+
+    def getitem_side_effect(key):
+      if key == "wan_state":
+        return restored_mock.wan_state
+      raise KeyError(key)
+
+    restored_mock.__getitem__.side_effect = getitem_side_effect
+    mock_manager.restore.return_value = restored_mock
+
+    mock_create_manager.return_value = mock_manager
+
+    mock_pipeline_instance = MagicMock()
+    mock_wan_pipeline.from_checkpoint.return_value = mock_pipeline_instance
+
+    checkpointer = WanCheckpointer(self.config, WAN_CHECKPOINT)
+    pipeline, opt_state, step = checkpointer.load_checkpoint(step=1)
+
+    mock_manager.restore.assert_called_once_with(directory=unittest.mock.ANY, step=1, args=unittest.mock.ANY)
+    mock_wan_pipeline.from_checkpoint.assert_called_with(self.config, mock_manager.restore.return_value)
+    self.assertEqual(pipeline, mock_pipeline_instance)
+    self.assertIsNotNone(opt_state)
+    self.assertEqual(opt_state["learning_rate"], 0.001)
+    self.assertEqual(step, 1)
+
 
 if __name__ == "__main__":
-  unittest.main()
+  unittest.main()

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 13 additions & 11 deletions
@@ -213,7 +213,7 @@ def start_training(self):
     pipeline, opt_state, step = self.load_checkpoint()
     restore_args = {}
     if opt_state and step:
-      restore_args = {"opt_state": opt_state, "step":step}
+      restore_args = {"opt_state": opt_state, "step": step}
     del opt_state
     if self.config.enable_ssim:
       # Generate a sample before training to compare against generated sample after training.
@@ -285,28 +285,30 @@ def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, wr
     if writer:
       writer.add_scalar("learning/eval_loss", final_eval_loss, step)
 
-  def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args:dict={}):
+  def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator, restore_args: dict = {}):
     mesh = pipeline.mesh
     graphdef, params, rest_of_state = nnx.split(pipeline.transformer, nnx.Param, ...)
 
     with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
       state = TrainState.create(
-        apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state)
+          apply_fn=graphdef.apply, params=params, tx=optimizer, graphdef=graphdef, rest_of_state=rest_of_state
+      )
       if restore_args:
         step = restore_args.get("step", 0)
         max_logging.log(f"Restoring optimizer and resuming from step {step}")
-        state.replace(opt_state=restore_args.get("opt_state"), step = restore_args.get("step", 0))
+        state.replace(opt_state=restore_args.get("opt_state"), step=restore_args.get("step", 0))
         del restore_args["opt_state"]
       del optimizer
       state = jax.tree.map(_to_array, state)
       state_spec = nnx.get_partition_spec(state)
      state = jax.lax.with_sharding_constraint(state, state_spec)
       state_shardings = nnx.get_named_sharding(state, mesh)
       if jax.process_index() == 0 and restore_args:
-       max_logging.log("--- Optimizer State Sharding Spec (opt_state) ---")
-       pretty_string = pprint.pformat(state_spec.opt_state, indent=4, width=60)
-       max_logging.log(pretty_string)
-       max_logging.log("------------------------------------------------")
+        max_logging.log("--- Optimizer State Sharding Spec (opt_state) ---")
+        pretty_string = pprint.pformat(state_spec.opt_state, indent=4, width=60)
+        max_logging.log(pretty_string)
+        max_logging.log("------------------------------------------------")
+      max_utils.delete_pytree(params)
       data_shardings = self.get_data_shardings(mesh)
       eval_data_shardings = self.get_eval_data_shardings(mesh)
 
@@ -349,9 +351,9 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
       last_profiling_step = np.clip(
           first_profiling_step + self.config.profiler_steps - 1, first_profiling_step, self.config.max_train_steps - 1
       )
-      if restore_args.get("step",0):
-        max_logging.log(f"Resuming training from step {step}")
-        start_step = restore_args.get("step",0)
+      if restore_args.get("step", 0):
+        max_logging.log(f"Resuming training from step {step}")
+        start_step = restore_args.get("step", 0)
       per_device_tflops, _, _ = WanTrainer.calculate_tflops(pipeline)
       scheduler_state = pipeline.scheduler_state
       example_batch = load_next_batch(train_data_iterator, None, self.config)
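The new max_utils.delete_pytree(params) call is safe only because state has been rebuilt by that point (jax.tree.map(_to_array, ...) followed by with_sharding_constraint), so its leaves no longer alias the original params buffers. A toy check of that invariant (not from the repo):

    import jax.numpy as jnp

    a = jnp.ones((16, 16))
    b = a + 0               # forces a fresh buffer, like the resharding step does
    a.delete()
    print(a.is_deleted())   # True: the original buffer is gone
    print(float(b.sum()))   # 256.0: the rebuilt copy is unaffected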
