Can we support outputting checkpoints directly in .pt format? #1177

Open
@andrewor14

Description

Today we need to do an extra conversion step, as described in this README: https://github.com/pytorch/torchtitan/blob/main/docs/checkpoint.md

```bash
python -m torch.distributed.checkpoint.format_utils dcp_to_torch outputs/checkpoint/step-100 /tmp/checkpoint.pt
```

I think we should instead provide an option that lets users specify the output format for their checkpoints, and have torchtitan call this conversion for them as part of writing the checkpoint.
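
To make the request concrete, here is a rough sketch of what this could look like on torchtitan's side. The maybe_export_pt helper and the export_format knob are hypothetical names, not existing torchtitan APIs; the conversion itself uses dcp_to_torch_save from torch.distributed.checkpoint.format_utils, the programmatic counterpart of the CLI call above:

```python
# Sketch only: the config knob and function name are hypothetical,
# not existing torchtitan APIs.
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save


def maybe_export_pt(checkpoint_dir: str, export_format: str = "dcp") -> None:
    """Optionally convert a freshly written DCP checkpoint folder to a single .pt file.

    export_format: "dcp" keeps today's behavior (no-op); "pt" also writes a
    torch.save-compatible file next to the DCP folder.
    """
    if export_format == "pt":
        # e.g. outputs/checkpoint/step-100 -> outputs/checkpoint/step-100.pt
        # (the conversion loads the full state dict, so it would likely need
        # to run on a single rank)
        dcp_to_torch_save(checkpoint_dir, checkpoint_dir + ".pt")


# Called by the checkpointer after it finishes writing step-100:
# maybe_export_pt("outputs/checkpoint/step-100", export_format="pt")
```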


Bonus: this conversion step currently fails if FP8 training was used. I had to manually add the following to the dcp_to_torch function as a hack to get it to work:

```python
import torchao.float8.fsdp_utils
torch.serialization.add_safe_globals([torchao.float8.fsdp_utils.WeightWithDynamicFloat8CastTensor])
```

It would be great if we could either implicitly add the safe globals when torchtitan outputs the checkpoint, or simply remove WeightWithDynamicFloat8CastTensor from the BC surface.
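
For the FP8 bonus item, the "implicitly add the safe globals" option could amount to registering the torchao tensor subclass just before the conversion runs. Since torch.serialization.add_safe_globals updates a process-wide allowlist, doing it from torchtitan should behave the same as patching the function by hand. A minimal sketch, reusing the hypothetical export hook from above:

```python
import torch
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save


def export_pt_with_fp8_allowlist(checkpoint_dir: str, pt_path: str) -> None:
    try:
        from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor

        # Same allowlisting as the manual hack above, just applied before the
        # conversion instead of edited into format_utils.
        torch.serialization.add_safe_globals([WeightWithDynamicFloat8CastTensor])
    except ImportError:
        # Non-FP8 runs (or no torchao installed) don't need the allowlist.
        pass
    dcp_to_torch_save(checkpoint_dir, pt_path)
```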
