Commit f3daad1

1 parent e55eef4 commit f3daad1

File tree: 1 file changed (+4, −1)


examples/llama/2d_llama.py

Lines changed: 4 additions & 1 deletion
@@ -6,12 +6,15 @@
 from torch.distributed._tensor import init_device_mesh
 
 
+# Utility
 def modify_view(
     gm: torch.fx.GraphModule,
     tp: int
 ):
     """
-    Adjust dimension size of view ops to make them compatible with tensor parallelism.
+    Adjust dimension size of view ops to make them compatible with tensor
+    parallelism. For example, when TP is 4, we need to adjust `num_heads` from
+    32 to 8. This is needed for attention layers.
     """
     for node in gm.graph.nodes:
         if node.op == "call_method" and (
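
The diff cuts off inside the loop body. For context, here is a minimal sketch of what the complete helper could look like. It assumes the Llama attention layers reshape activations as x.view(bsz, seqlen, num_heads, head_dim), so the head count sits at argument position 3; that argument position, the length check, and the recompile() call are assumptions for illustration, not code shown in this commit.

import torch
import torch.fx


def modify_view(gm: torch.fx.GraphModule, tp: int):
    """
    Adjust dimension size of view ops to make them compatible with tensor
    parallelism. For example, when TP is 4, we need to adjust `num_heads`
    from 32 to 8. This is needed for attention layers.
    """
    for node in gm.graph.nodes:
        if node.op == "call_method" and (
            node.target == "view" or node.target == "reshape"
        ):
            # Assumed calling convention: x.view(bsz, seqlen, num_heads, head_dim),
            # i.e. node.args == (x, bsz, seqlen, num_heads, head_dim). After the
            # attention weights are sharded over `tp` ranks, each rank holds only
            # num_heads // tp heads, so the graph's hard-coded head count must
            # shrink to match (e.g. 32 -> 8 when tp == 4).
            if len(node.args) == 5 and isinstance(node.args[3], int):
                node.update_arg(3, node.args[3] // tp)
    # Regenerate the GraphModule's forward() from the mutated graph.
    gm.recompile()

Rewriting only the head-count position, rather than every integer divisible by tp, avoids accidentally shrinking unrelated dimensions (such as seqlen) that happen to share a divisor with the TP degree.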
