1 parent e55eef4 commit f3daad1
examples/llama/2d_llama.py
@@ -6,12 +6,15 @@
 from torch.distributed._tensor import init_device_mesh
 
 
+# Utility
 def modify_view(
     gm: torch.fx.GraphModule,
     tp: int
 ):
     """
-    Adjust dimension size of view ops to make them compatible with tensor parallelism.
+    Adjust dimension size of view ops to make them compatible with tensor
+    parallelism. For example, when TP is 4, we need to adjust `num_heads` from
+    32 to 8. This is needed for attention layers.
     """
     for node in gm.graph.nodes:
         if node.op == "call_method" and (
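
As the updated docstring explains, `modify_view` shrinks the head-count literal baked into view ops so the traced graph stays consistent once attention weights are sharded across TP ranks (32 heads become 8 per rank when TP is 4). The hunk cuts off before the body of the loop, so the following is only a minimal, hypothetical sketch of that idea over a torch.fx graph; the function name, the `view`/`reshape` matching condition, and the assumed argument position of `num_heads` are all assumptions, not the actual implementation in examples/llama/2d_llama.py.

import torch.fx

def adjust_views_for_tp(gm: torch.fx.GraphModule, tp: int) -> None:
    # Hypothetical sketch: divide the num_heads literal in view/reshape calls by tp.
    for node in gm.graph.nodes:
        if node.op == "call_method" and node.target in ("view", "reshape"):
            args = list(node.args)
            # Assumed target shape (batch, seq, num_heads, head_dim):
            # args[0] is the tensor itself, so num_heads would sit at index 3.
            if len(args) == 5 and isinstance(args[3], int) and args[3] % tp == 0:
                args[3] = args[3] // tp  # e.g. 32 heads -> 8 per rank when tp == 4
                node.args = tuple(args)
    gm.recompile()  # regenerate forward() from the mutated graph

A pass like this would presumably run on the traced GraphModule before (or alongside) the tensor-parallel sharding, so that the per-rank view shapes match the reduced head count.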