Hi,
I tried running TP + FSDP + MXFP8 with torchtitan on a two-node GB200 setup (2x4 GPUs), but hit the following error in torchao during compile:
[rank0]: torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors: call_function <function produce_trampoline_autograd_apply.<locals>.trampoline_autograd_apply at 0x4003dde9a980>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(8, 8192, 8192), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3], mesh_dim_names=('tp',)), placements=(Replicate(),)), DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(2048, 8192), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3], mesh_dim_names=('tp',)), placements=(Shard(dim=0),)), torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn, 32, <MXGemmKernelChoice.CUBLAS: 'cublas'>, True), **{}): got RuntimeError('Attempting to broadcast a dimension of length 4194304 at -1! Mismatching argument at index 2 had torch.Size([4194304]); but expected shape should be broadcastable to [2097152]')
[rank0]: from user code:
[rank0]: File "/home/mreso/venv/lib/python3.12/site-packages/torch/distributed/algorithms/_checkpoint/checkpoint_wrapper.py", line 171, in forward
[rank0]: return self.checkpoint_fn( # type: ignore[misc]
[rank0]: File "/home/mreso/torchtitan/torchtitan/models/llama3/model.py", line 365, in forward
[rank0]: h = x + self.attention(self.attention_norm(x), freqs_cis)
[rank0]: File "/home/mreso/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1870, in _call_impl
[rank0]: return inner()
[rank0]: File "/home/mreso/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1821, in inner
[rank0]: result = forward_call(*args, **kwargs)
[rank0]: File "/home/mreso/torchtitan/torchtitan/models/llama3/model.py", line 238, in forward
[rank0]: xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
[rank0]: File "/home/mreso/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1870, in _call_impl
[rank0]: return inner()
[rank0]: File "/home/mreso/venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1821, in inner
[rank0]: result = forward_call(*args, **kwargs)
[rank0]: File "/home/mreso/ao/torchao/prototype/mx_formats/mx_linear.py", line 209, in forward
[rank0]: y = mx_mm.apply(
[rank0]: Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
Full error log
Repro steps:
Run the job using torchtitan's default slurm script with the following changes (an equivalent standalone launch is sketched below the diff):
-#SBATCH --ntasks=4
+#SBATCH --ntasks=2
-#SBATCH --nodes=4
+#SBATCH --nodes=2
-#SBATCH --gpus-per-task=8
+#SBATCH --gpus-per-task=4
-CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama/train_configs/llama3_8b.toml"}
+CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/llama3_70b.toml"}
-srun torchrun --nnodes 4 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" ./torchtitan/train.py --job.config_file ${CONFIG_FILE}
+srun torchrun --nnodes 2 --nproc_per_node 4 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" ./torchtitan/train.py --job.config_file ${CONFIG_FILE} --parallelism.tensor_parallel_degree=4 --training.compile --training.local_batch_size=8 --model.converters mx --mx.recipe_name "mxfp8"
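For a quicker manual repro, the same launch can be run directly on each of the two nodes without the slurm wrapper. This is only a sketch: HEAD_NODE_IP is a placeholder I added, while all torchrun/torchtitan flags are copied verbatim from the srun line above.

CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/llama3_70b.toml"}
# HEAD_NODE_IP is a placeholder for this sketch; run the same command on both nodes.
torchrun --nnodes 2 --nproc_per_node 4 --rdzv_id 101 --rdzv_backend c10d \
    --rdzv_endpoint "${HEAD_NODE_IP}:29500" ./torchtitan/train.py \
    --job.config_file "${CONFIG_FILE}" --parallelism.tensor_parallel_degree=4 \
    --training.compile --training.local_batch_size=8 \
    --model.converters mx --mx.recipe_name "mxfp8"

Prepending TORCHDYNAMO_VERBOSE=1 TORCH_LOGS="+dynamo" to the torchrun command produces the fuller internal stack trace mentioned in the error output above.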
Env:
PyTorch version: 2.8.0.dev20250609+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 24.04 LTS (aarch64)
GCC version: (GCC) 14.1.0
Clang version: Could not collect
CMake version: version 3.28.3
Libc version: glibc-2.39
Python version: 3.12.3 (main, Feb 4 2025, 14:48:35) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-1024-nvidia-64k-aarch64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 12.0.140
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: HGX GB200
GPU 1: HGX GB200
GPU 2: HGX GB200
GPU 3: HGX GB200
Nvidia driver version: 570.124.06
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: aarch64
CPU op-mode(s): 64-bit
Byte Order: Little Endian
CPU(s): 144
On-line CPU(s) list: 0-143
Vendor ID: ARM
Model name: Neoverse-V2
Model: 0
Thread(s) per core: 1
Core(s) per socket: 72
Socket(s): 2
Stepping: r0p0
Frequency boost: disabled
CPU(s) scaling MHz: 100%
CPU max MHz: 3393.0000
CPU min MHz: 81.0000
BogoMIPS: 2000.00
Flags: fp asimd evtstrm aes pmull sha1 sha2 crc32 atomics fphp asimdhp cpuid asimdrdm jscvt fcma lrcpc dcpop sha3 sm3 sm4 asimddp sha512 sve asimdfhm dit uscat ilrcpc flagm ssbs sb paca pacg dcpodp sve2 sveaes svepmull svebitperm svesha3 svesm4 flagm2 frint svei8mm svebf16 i8mm bf16 dgh bti
L1d cache: 9 MiB (144 instances)
L1i cache: 9 MiB (144 instances)
L2 cache: 144 MiB (144 instances)
L3 cache: 228 MiB (2 instances)
NUMA node(s): 34
NUMA node0 CPU(s): 0-71
NUMA node1 CPU(s): 72-143
NUMA node2 CPU(s):
NUMA node3 CPU(s):
NUMA node4 CPU(s):
NUMA node5 CPU(s):
NUMA node6 CPU(s):
NUMA node7 CPU(s):
NUMA node8 CPU(s):
NUMA node9 CPU(s):
NUMA node10 CPU(s):
NUMA node11 CPU(s):
NUMA node12 CPU(s):
NUMA node13 CPU(s):
NUMA node14 CPU(s):
NUMA node15 CPU(s):
NUMA node16 CPU(s):
NUMA node17 CPU(s):
NUMA node18 CPU(s):
NUMA node19 CPU(s):
NUMA node20 CPU(s):
NUMA node21 CPU(s):
NUMA node22 CPU(s):
NUMA node23 CPU(s):
NUMA node24 CPU(s):
NUMA node25 CPU(s):
NUMA node26 CPU(s):
NUMA node27 CPU(s):
NUMA node28 CPU(s):
NUMA node29 CPU(s):
NUMA node30 CPU(s):
NUMA node31 CPU(s):
NUMA node32 CPU(s):
NUMA node33 CPU(s):
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; __user pointer sanitization
Vulnerability Spectre v2: Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] Could not collect
[conda] Could not collect
Manually added:
pytorch-triton==3.3.1+gitc8757738
torch==2.8.0.dev20250609+cu128
torchaudio==2.8.0.dev20250610
torchdata==0.11.0
torchvision==0.23.0.dev20250610
torchao==5239ce7e64ff71f5b3f8affb95a137fe7200a6a0
torchtitan==b7c7ed7167fa72485d987fa69af53538ee5be900