Open
Labels
bug · checkpointing · strategy: fsdp · ver: 2.5.x · waiting on author
Description
Bug description
Checkpointing fails when saving a partial checkpoint (a `state_dict()` override that returns only a submodule's weights) with the FSDP strategy.
```python
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy


class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        # Only the head is trained; the frozen encoder should not be saved
        self.encoder = from_pretrained(...).requires_grad_(False)
        self.trainable_head = TrainableHead()
        # Loading is non-strict because the checkpoint omits the encoder
        self.strict_loading = False

    def state_dict(self):
        # Save only the trainable head, not the frozen encoder
        return self.trainable_head.state_dict()


trainer = L.Trainer(
    strategy=FSDPStrategy(auto_wrap_policy={EncoderLayer}),
)
```
What version are you seeing the problem on?
v2.5
Reproduced in studio
No response
How to reproduce the bug
Error messages and logs
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 599, in _fit_impl
[rank0]: self._run(model, ckpt_path=ckpt_path)
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 1012, in _run
[rank0]: results = self._run_stage()
[rank0]: ^^^^^^^^^^^^^^^^^
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 1056, in _run_stage
[rank0]: self.fit_loop.run()
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py", line 216, in run
[rank0]: self.advance()
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py", line 455, in advance
[rank0]: self.epoch_loop.run(self._data_fetcher)
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 153, in run
[rank0]: self.on_advance_end(data_fetcher)
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 394, in on_advance_end
[rank0]: self.val_loop.run()
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/loops/utilities.py", line 179, in _decorator
[rank0]: return loop_run(self, *args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 152, in run
[rank0]: return self.on_run_end()
[rank0]: ^^^^^^^^^^^^^^^^^
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 309, in on_run_end
[rank0]: self._on_evaluation_end()
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 354, in _on_evaluation_end
[rank0]: call._call_callback_hooks(trainer, hook_name, *args, **kwargs)
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 227, in _call_callback_hooks
[rank0]: fn(trainer, trainer.lightning_module, *args, **kwargs)
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 338, in on_validation_end
[rank0]: self._save_topk_checkpoint(trainer, monitor_candidates)
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 391, in _save_topk_checkpoint
[rank0]: self._save_none_monitor_checkpoint(trainer, monitor_candidates)
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 719, in _save_none_monitor_checkpoint
[rank0]: self._save_checkpoint(trainer, filepath)
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 394, in _save_checkpoint
[rank0]: trainer.save_checkpoint(filepath, self.save_weights_only)
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 1396, in save_checkpoint
[rank0]: checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 445, in dump_checkpoint
[rank0]: "state_dict": self._get_lightning_module_state_dict(),
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 506, in _get_lightning_module_state_dict
[rank0]: return self.trainer.strategy.lightning_module_state_dict()
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/strategies/fsdp.py", line 511, in lightning_module_state_dict
[rank0]: return self.model.state_dict()
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/torch/nn/modules/module.py", line 2266, in state_dict
[rank0]: hook_result = hook(self, destination, prefix, local_metadata)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: ^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 714, in _post_state_dict_hook
[rank0]: processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type](
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 342, in _full_post_state_dict_hook
[rank0]: return _common_unshard_post_state_dict_hook(
[rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 222, in _common_unshard_post_state_dict_hook
[rank0]: assert fqn in state_dict, (
[rank0]: ^^^^^^^^^^^^^^^^^
[rank0]: AssertionError: FSDP assumes patched_llama.model.model.embed_tokens.weight is in the state_dict but the state_dict only has odict_keys(
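The assertion fires because FSDP's `_post_state_dict_hook` expects every wrapped parameter's fully qualified name in the returned dict. A possible workaround (an untested sketch, hypothetical key names): leave `state_dict()` intact so FSDP's consolidation hook sees every FQN, and instead prune the checkpoint in `LightningModule.on_save_checkpoint`, which receives the checkpoint dict before it is written:

```python
def trim_checkpoint(checkpoint: dict, keep_prefix: str = "trainable_head.") -> dict:
    # Drop everything except the trainable head from an already-assembled
    # checkpoint dict; call this from LightningModule.on_save_checkpoint
    checkpoint["state_dict"] = {
        k: v
        for k, v in checkpoint["state_dict"].items()
        if k.startswith(keep_prefix)
    }
    return checkpoint


ckpt = {"state_dict": {"encoder.weight": 0.0, "trainable_head.weight": 1.0}}
trim_checkpoint(ckpt)
print(sorted(ckpt["state_dict"]))  # ['trainable_head.weight']
```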
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):
More info
No response
cc @lantiga