Description
When I fp8-quantize a model and then shard it with FSDP2, sharding fails inside FSDP2's parameter-chunking code.
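Roughly, the setup boils down to something like the following (a simplified sketch, not the actual wan-inference code; I'm using torchao's float8 dynamic activation/weight config here as a stand-in for the fp8 quantization step, and plain `fully_shard` instead of `shard_dit_model`):

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

# Run with e.g.: torchrun --nproc_per_node=2 repro.py
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Stand-in for the DiT: the real code quantizes the model and then shards
# each transformer block via shard_dit_model().
model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096)).cuda()

# fp8 quantization with torchao; this replaces the Linear weights with
# LinearActivationQuantizedTensor subclass instances.
quantize_(model, float8_dynamic_activation_float8_weight())

# Sharding the quantized module with FSDP2 then fails, because FSDP2
# torch.chunk()s each parameter and the subclass has no aten.split handler.
for block in model:
    fully_shard(block)
fully_shard(model)
```

This is the error it reports: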
```
[rank1]: Traceback (most recent call last):
[rank1]: File "/mnt/teams/algo-teams/shared/code/wanx-inference/generate.py", line 461, in <module>
[rank1]: generate(args)
[rank1]: File "/mnt/teams/algo-teams/shared/code/wanx-inference/generate.py", line 375, in generate
[rank1]: wan_i2v = wan.WanI2V(
[rank1]: ^^^^^^^^^^^
[rank1]: File "/mnt/teams/algo-teams/shared/code/wanx-inference/wan/image2video.py", line 218, in __init__
[rank1]: self.model = shard_dit_fn(self.model, param_dtype=torch.float8_e4m3fn)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/mnt/teams/algo-teams/shared/code/wanx-inference/wan/distributed/fsdp.py", line 112, in shard_dit_model
[rank1]: fully_shard_with_ignore_param(block, mesh=pm.get_dp_with_cp_mesh(), reshard_after_forward=True, mp_policy=mixed_fsdp2, ignored_params=ignored_states_set)
[rank1]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/_composable/contract.py", line 125, in wrapper
[rank1]: updated = func(inp_module, *args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/mnt/teams/algo-teams/shared/pytorch_distributed_examples/src/tu_pth_dist/fsdp_compat.py", line 200, in fully_shard_with_ignore_param
[rank1]: state._fsdp_param_group = FSDPParamGroup(
[rank1]: ^^^^^^^^^^^^^^^
[rank1]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py", line 132, in __init__
[rank1]: FSDPParam(
[rank1]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py", line 239, in __init__
[rank1]: self._init_sharded_param(param, device, shard_placement_fn)
[rank1]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank1]: return func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py", line 368, in _init_sharded_param
[rank1]: chunks = _chunk_with_empty(param_data, shard_world_size, dim=shard_dim)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_common.py", line 124, in _chunk_with_empty
[rank1]: chunks = list(torch.chunk(tensor, num_chunks, dim=dim))
[rank1]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 436, in _dispatch__torch_function__
[rank1]: return func(*args, **kwargs)
[rank1]: ^^^^^^^^^^^^^^^^^^^^^
[rank1]: File "/opt/miniconda/envs/wan21/lib/python3.12/site-packages/torchao/utils.py", line 455, in _dispatch__torch_dispatch__
[rank1]: raise NotImplementedError(
[rank1]: NotImplementedError: LinearActivationQuantizedTensor dispatch: attempting to run unimplemented operator/function: func=<OpOverload(op='aten.split', overload='Tensor')>, types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>,), arg_types=(<class 'torchao.quantization.linear_activation_quantized_tensor.LinearActivationQuantizedTensor'>, <class 'int'>), kwarg_types={}
```
I can see that aten.split is not implemented in https://github.com/pytorch/ao/blob/ab3792e3d91e04f85992a659c1664a6a1a6d733c/torchao/quantization/linear_activation_quantized_tensor.py. Could anyone provide an implementation for it?
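For reference, here is roughly what I have in mind, following the pattern of the other ops registered in that file (aten.detach, etc.). This is only a sketch: I'm not sure it handles aliasing correctly, and the constructor arguments may differ across torchao versions.

```python
import torch
from torchao.quantization.linear_activation_quantized_tensor import (
    LinearActivationQuantizedTensor,
)

aten = torch.ops.aten
implements = LinearActivationQuantizedTensor.implements


@implements(aten.split.Tensor)
def _(func, types, args, kwargs):
    # FSDP2's _chunk_with_empty() calls torch.chunk(), which dispatches to
    # aten.split.Tensor, so this is the overload that needs a handler.
    lat, *rest = args
    # Split only the inner weight tensor ...
    chunks = func(lat.original_weight_tensor, *rest, **kwargs)
    # ... and re-wrap each chunk so it stays a LinearActivationQuantizedTensor
    # carrying the same activation-quantization function.
    # Assumption: the constructor takes (original_weight_tensor,
    # input_quant_func); newer torchao versions may have extra arguments
    # (e.g. quant_kwargs) that would need to be forwarded as well.
    return [
        LinearActivationQuantizedTensor(chunk, lat.input_quant_func)
        for chunk in chunks
    ]
```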