
FSDP with custom state_dict() #21124

@LiuTaowen-Tony


Bug description

Checkpointing fails when saving a partial checkpoint with FSDP.

import lightning as L
from lightning.pytorch.strategies import FSDPStrategy

# Placeholders: from_pretrained, EncoderLayer, and TrainableHead stand in
# for the real frozen encoder and the small trainable head.

class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()

        # Only the head is trained; the frozen encoder should not be saved
        self.encoder = from_pretrained(...).requires_grad_(False)
        self.trainable_head = TrainableHead()

        # Set to False because only the trainable head is loaded/saved
        self.strict_loading = False

    def state_dict(self, *args, **kwargs):
        # Don't save the encoder, it is not being trained
        return self.trainable_head.state_dict()


trainer = L.Trainer(
    strategy=FSDPStrategy(auto_wrap_policy={EncoderLayer}),
)
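
Judging from the traceback below, the failure comes from FSDP's post-state-dict hook: after unsharding, it asserts that every parameter FQN it manages is present in the returned dict (assert fqn in state_dict), so an override that returns only the head's keys trips the assertion on the first encoder parameter. A minimal sketch of a possible workaround, assuming the goal is just to keep the frozen encoder out of the saved file: leave state_dict() alone and prune the gathered checkpoint in on_save_checkpoint instead.

class LitModel(L.LightningModule):
    # ... __init__ as above, but without the state_dict() override ...

    def on_save_checkpoint(self, checkpoint):
        # Runs after FSDP has gathered the full state dict, so the
        # post-state-dict hook has already seen every FQN it expects.
        checkpoint["state_dict"] = {
            k: v
            for k, v in checkpoint["state_dict"].items()
            if not k.startswith("encoder.")
        }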

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

Run the snippet above with the FSDP strategy shown; the failure is triggered when the ModelCheckpoint callback saves at the end of the validation loop.

Error messages and logs

[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 599, in _fit_impl
[rank0]:     self._run(model, ckpt_path=ckpt_path)
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 1012, in _run
[rank0]:     results = self._run_stage()
[rank0]:               ^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 1056, in _run_stage
[rank0]:     self.fit_loop.run()
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py", line 216, in run
[rank0]:     self.advance()
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py", line 455, in advance
[rank0]:     self.epoch_loop.run(self._data_fetcher)
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 153, in run
[rank0]:     self.on_advance_end(data_fetcher)
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 394, in on_advance_end
[rank0]:     self.val_loop.run()
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/loops/utilities.py", line 179, in _decorator
[rank0]:     return loop_run(self, *args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 152, in run
[rank0]:     return self.on_run_end()
[rank0]:            ^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 309, in on_run_end
[rank0]:     self._on_evaluation_end()
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 354, in _on_evaluation_end
[rank0]:     call._call_callback_hooks(trainer, hook_name, *args, **kwargs)
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 227, in _call_callback_hooks
[rank0]:     fn(trainer, trainer.lightning_module, *args, **kwargs)
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 338, in on_validation_end
[rank0]:     self._save_topk_checkpoint(trainer, monitor_candidates)
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 391, in _save_topk_checkpoint
[rank0]:     self._save_none_monitor_checkpoint(trainer, monitor_candidates)
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 719, in _save_none_monitor_checkpoint
[rank0]:     self._save_checkpoint(trainer, filepath)
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 394, in _save_checkpoint
[rank0]:     trainer.save_checkpoint(filepath, self.save_weights_only)
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 1396, in save_checkpoint
[rank0]:     checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
[rank0]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 445, in dump_checkpoint
[rank0]:     "state_dict": self._get_lightning_module_state_dict(),
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/checkpoint_connector.py", line 506, in _get_lightning_module_state_dict
[rank0]:     return self.trainer.strategy.lightning_module_state_dict()
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/lightning/pytorch/strategies/fsdp.py", line 511, in lightning_module_state_dict
[rank0]:     return self.model.state_dict()
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/torch/nn/modules/module.py", line 2266, in state_dict
[rank0]:     hook_result = hook(self, destination, prefix, local_metadata)
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 714, in _post_state_dict_hook
[rank0]:     processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type](
[rank0]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 342, in _full_post_state_dict_hook
[rank0]:     return _common_unshard_post_state_dict_hook(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/tl2020/.conda/envs/kv-compressor/lib/python3.12/site-packages/torch/distributed/fsdp/_state_dict_utils.py", line 222, in _common_unshard_post_state_dict_hook
[rank0]:     assert fqn in state_dict, (
[rank0]:            ^^^^^^^^^^^^^^^^^
[rank0]: AssertionError: FSDP assumes patched_llama.model.model.embed_tokens.weight is in the state_dict but the state_dict only has odict_keys(
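
For context, the assertion is raised inside torch.distributed.fsdp itself, not Lightning: _common_unshard_post_state_dict_hook walks the FQNs of every parameter the wrapper shards (including the frozen encoder, since it is covered by the auto-wrap policy) and requires each one to be present in the dict returned by state_dict(). With strict_loading already set to False, a checkpoint pruned via on_save_checkpoint (as sketched above) should still restore; a hypothetical load might look like:

# Hypothetical: restoring a checkpoint that lacks the encoder weights.
# strict=False tolerates the missing keys; "last.ckpt" is a placeholder path.
model = LitModel.load_from_checkpoint("last.ckpt", strict=False)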

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):

More info

No response

cc @lantiga

Metadata

Assignees: no one assigned

Labels: bug (Something isn't working), checkpointing (Related to checkpointing), strategy: fsdp (Fully Sharded Data Parallel), ver: 2.5.x, waiting on author (Waiting on user action, correction, or update)
