Skip to content

Commit bcaf757

Browse files
committed
new line added
1 parent 04dfff3 commit bcaf757

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

src/maxdiffusion/checkpointing/wan_checkpointer2_2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
6060
return None, None
6161
max_logging.log(f"Loading WAN checkpoint from step {step}")
6262
metadatas = self.checkpoint_manager.item_metadata(step)
63-
63+
6464
low_noise_transformer_metadata = metadatas.low_noise_transformer_state
6565
abstract_tree_structure_low_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, low_noise_transformer_metadata)
6666
low_params_restore = ocp.args.PyTreeRestore(
@@ -69,7 +69,7 @@ def load_wan_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[dic
6969
abstract_tree_structure_low_params,
7070
)
7171
)
72-
72+
7373
high_noise_transformer_metadata = metadatas.high_noise_transformer_state
7474
abstract_tree_structure_high_params = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, high_noise_transformer_metadata)
7575
high_params_restore = ocp.args.PyTreeRestore(

src/maxdiffusion/tests/wan_checkpointer2_2_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,4 +110,4 @@ def test_load_checkpoint_with_optimizer(self, mock_wan_pipeline, mock_create_man
110110

111111

112112
if __name__ == "__main__":
113-
unittest.main()
113+
unittest.main()

0 commit comments

Comments
 (0)