-**Loading state dicts**: We initialize the model under meta device and call ``fully_shard`` to convert ``model.parameters()`` from plain ``torch.Tensor`` to DTensor. After reading the full state dict from torch.load, we can call `distributed_tensor <https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.distribute_tensor>`_ to convert plain ``torch.Tensor`` into DTensor, using the same placements and device mesh from ``model.state_dict()``. Finally we can call `model.load_state_dict <https://docs.pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict>`_ to load DTensor state dicts into the model.
0 commit comments