Today we need to do an extra conversion step according to this README: https://github.com/pytorch/torchtitan/blob/main/docs/checkpoint.md
```bash
python -m torch.distributed.checkpoint.format_utils dcp_to_torch outputs/checkpoint/step-100 /tmp/checkpoint.pt
```
I think we should instead provide an option for users to specify which format their checkpoints are written in, and have torchtitan call this conversion function on their behalf as part of saving the checkpoint.
Bonus: this conversion step actually fails today if FP8 training was used. I had to manually add the following line to the `dcp_to_torch` function as a hack to get it to work:

```python
torch.serialization.add_safe_globals([torchao.float8.fsdp_utils.WeightWithDynamicFloat8CastTensor])
```
It would be great if we could either implicitly register these safe globals when torchtitan outputs the checkpoint, or simply remove `WeightWithDynamicFloat8CastTensor` from the BC surface.