fp8 quantization with FSDP2 error #1929

Open
@happynear

Description

When I quantize a model to fp8 and then shard it with FSDP2, it fails with the following error:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/mnt/teams/algo-teams/shared/code/wanx-inference/generate.py", line 461, in <module>
[rank1]:     generate(args)
[rank1]:   File "/mnt/teams/algo-teams/shared/code/wanx-inference/generate.py", line 375, in generate
[rank1]:     wan_i2v = wan.WanI2V(
[rank1]:               ^^^^^^^^^^^
[rank1]:   File "/mnt/teams/algo-teams/shared/code/wanx-inference/wan/image2video.py", line 218, in __init__
[rank1]:     self.model = shard_dit_fn(self.model, param_dtype=torch.float8_e4m3fn)
[rank1]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/teams/algo-teams/shared/code/wanx-inference/wan/distributed/fsdp.py", line 112, in shard_dit_model
[rank1]:     fully_shard_with_ignore_param(block, mesh=pm.get_dp_with_cp_mesh(), reshard_after_forward=True, mp_policy=mixed_fsdp2, ignored_params=ignored_states_set)
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/_composable/contract.py", line 125, in wrapper
[rank1]:     updated = func(inp_module, *args, **kwargs)
[rank1]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/mnt/teams/algo-teams/shared/pytorch_distributed_examples/src/tu_pth_dist/fsdp_compat.py", line 200, in fully_shard_with_ignore_param
[rank1]:     state._fsdp_param_group = FSDPParamGroup(
[rank1]:                               ^^^^^^^^^^^^^^^
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 132, in __init__
[rank1]:     FSDPParam(
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py", line 239, in __init__
[rank1]:     self._init_sharded_param(param, device, shard_placement_fn)
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py", line 368, in _init_sharded_param
[rank1]:     chunks = _chunk_with_empty(param_data, shard_world_size, dim=shard_dim)
[rank1]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_common.py", line 124, in _chunk_with_empty
[rank1]:     chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
[rank1]:                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 436, in _dispatch__torch_function__
[rank1]:     return func(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 455, in _dispatch__torch_dispatch__
[rank1]:     raise NotImplementedError(
[rank1]: NotImplementedError: LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.split', overload='Tensor')>, types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>,), arg_types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>, <class 'int'>), kwarg_types={}

I can see that there is no handler for aten.split (which torch.chunk lowers to) in https://github.com/pytorch/ao/blob/ab3792e3d91e04f85992a659c1664a6a1a6d733c/torchao/quantization/linear_activation_quantized_tensor.py . Could anyone provide an implementation for it?
