Skip to content

[PoC] Enable flexible different layout for same mesh via a util function #1550

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 195 additions & 0 deletions torchtitan/distributed/parallel_dims.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
from dataclasses import dataclass
from functools import cached_property

Expand All @@ -27,6 +28,7 @@ class ParallelDims:
world_size: int

_world_mesh: DeviceMesh = None
intermediate_num: int = 0

def __post_init__(self):
self._validate()
Expand Down Expand Up @@ -175,6 +177,199 @@ def _build_mesh_without_ep(self) -> DeviceMesh:

return mesh

def _merge_mesh_config(
self,
config_l: dict[str, int],
config_r: dict[str, int],
inter_flatten_names_map_list: list[dict[str, list[str]]],
) -> tuple[list[int], list[str]]:
final_mesh_shape: list[int] = []
final_mesh_dim_names: list[str] = []
it_l = iter(config_l.items())
it_r = iter(config_r.items())
val_l: int = 0
val_r: int = 0
name_l: str = ""
name_r: str = ""
while True:
if not val_l and (next_item := next(it_l, None)):
name_l, val_l = next_item
if not val_r and (next_item := next(it_r, None)):
name_r, val_r = next_item
if not bool(val_l) and not bool(val_r):
return final_mesh_shape, final_mesh_dim_names
if bool(val_l) ^ bool(val_r):
raise ValueError("Cannot merge two mesh configuration")
if val_l == val_r:
# Case one: both mesh dim names are the same and same dim shape
if val_l == val_r:
final_mesh_shape.append(val_l)
final_mesh_dim_names.append(name_l)
# Case two: both mesh dim names happen to have same dim shape, we unflatten it into (n, 1)
else:
final_mesh_shape.append(val_l)
final_mesh_shape.append(1)
final_mesh_dim_names.append(
f"inter_dim_{ParallelDims.intermediate_num}"
)
ParallelDims.intermediate_num += 1
final_mesh_dim_names.append(
f"inter_dim_{ParallelDims.intermediate_num}"
)
inter_flatten_names_map_list.append(
{
name_l: [
f"inter_dim_{ParallelDims.intermediate_num}",
f"inter_dim_{ParallelDims.intermediate_num-1}",
],
name_r: [
f"inter_dim_{ParallelDims.intermediate_num}",
f"inter_dim_{ParallelDims.intermediate_num-1}",
],
}
)
ParallelDims.intermediate_num += 1
val_l = 0
val_r = 0
else:
# Case three: we need to further unflatten one of the mesh dim.
# If left and right (after divided by their gcd) are coprime with each other,
# there is no way we can unflatten into a common mesh configuration.
gcd = math.gcd(val_l, val_r)
if gcd != min(val_l, val_r):
raise ValueError(
"Cannot merge two mesh configuration because there are dims which are coprime with each other."
)
elif gcd == val_l:
final_mesh_shape.append(val_l)
final_mesh_dim_names.append(name_l)
val_r = val_r // gcd
inter_flatten_names_map_list.append(
{
name_r: [
f"inter_dim_{ParallelDims.intermediate_num}",
name_l,
],
}
)
name_r = f"inter_dim_{ParallelDims.intermediate_num}"
ParallelDims.intermediate_num += 1
else:
final_mesh_shape.append(val_r)
final_mesh_dim_names.append(name_r)
val_l = val_l // gcd
inter_flatten_names_map_list.append(
{
name_r: [
f"inter_dim_{ParallelDims.intermediate_num}",
name_l,
],
}
)
name_r = f"inter_dim_{ParallelDims.intermediate_num}"
ParallelDims.intermediate_num += 1

def _build_mesh_with_mega_ctr(
    self,
    device_type: str,
    mesh_configurations: list[dict[str, int]],
    flatten_names_map: list[dict[str, list[str]]],
) -> DeviceMesh:
    """Build one device mesh that serves several layouts of the same devices.

    All layouts in ``mesh_configurations`` are merged, pairwise and in order,
    into a single finer-grained layout via ``_merge_mesh_config``. The mesh is
    created with that merged layout, then every dim that only exists in one of
    the requested layouts is re-created by flattening, followed by the
    user-requested flattenings from ``flatten_names_map``.

    For now only layouts that can all be unflattened into one common layout
    are supported. For example, [8, 4, 4] and [4, 2, 2, 8] can both be
    flattened from [4, 2, 2, 2, 4], but [5, 2] and [2, 5] cannot.

    Args:
        device_type: Device type forwarded to ``init_device_mesh``.
        mesh_configurations: One ordered ``{dim_name: dim_size}`` mapping per
            layout. All layouts must cover the same total number of devices.
        flatten_names_map: List of ``{flattened_name: [component_names]}``
            mappings applied, in order, after the mesh is built. Every
            component name must be a dim of some layout or a flattened name
            defined in this list.

    Returns:
        The mesh created by ``init_device_mesh`` with the flattened dims
        registered on it.

    Raises:
        AssertionError: If the arguments fail the validation described above.
        ValueError: Propagated from ``_merge_mesh_config`` when the layouts
            cannot be merged.

    Illustrative shape of the arguments (not a validated layout)::

        mesh_configurations = [
            {"pp": 2, "cp": 4, "dp_shard": 4, "dp_replicate": 2, "tp": 8},
            {"pp": 2, "dp": 4, "ep": 16, "ep_tp": 4},
        ]
        flatten_names_map = [
            {"dp": ["dp_shard", "dp_replicate"]},
            {"loss_update": ["cp", "dp"]},
            {"dp_shard_cp": ["cp", "dp_shard"]},
        ]
    """
    logger.info(
        f"Building a mesh with {len(mesh_configurations)} layouts with {mesh_configurations}, {flatten_names_map}"
    )
    assert (
        len(mesh_configurations) >= 1
    ), "mesh_configurations should have at least one map."
    # Layouts may differ in their number of dims (that is the point of merging
    # them); what they must share is the total number of devices they span.
    assert (
        len({math.prod(c.values()) for c in mesh_configurations}) == 1
    ), "All maps within mesh_configurations should cover the same number of devices."
    flatten_names = {
        name
        for flatten_map in flatten_names_map
        for names in flatten_map.values()
        for name in names
    }
    valid_names = {name for flatten_map in flatten_names_map for name in flatten_map} | {
        key for c in mesh_configurations for key in c
    }
    assert (
        flatten_names <= valid_names
    ), f"Invalid dim names {flatten_names - valid_names} are specified in flatten_names_map"

    # Fold every further layout into the running merged layout. Iterating over
    # the list directly (rather than stopping on a falsy item) guarantees that
    # no configuration is silently skipped.
    inter_flatten_names_map_list: list[dict[str, list[str]]] = []
    first_config, *other_configs = mesh_configurations
    final_mesh_shape: list[int] = list(first_config.values())
    final_mesh_dim_names: list[str] = list(first_config.keys())
    for next_config in other_configs:
        final_mesh_shape, final_mesh_dim_names = self._merge_mesh_config(
            dict(zip(final_mesh_dim_names, final_mesh_shape)),
            next_config,
            inter_flatten_names_map_list,
        )

    logger.info(
        f"Building intermediate {len(final_mesh_shape)}-D device mesh with {final_mesh_dim_names}, {final_mesh_shape}"
    )
    mesh = init_device_mesh(
        device_type,
        tuple(final_mesh_shape),
        mesh_dim_names=tuple(final_mesh_dim_names),
    )
    # Intermediate flattenings are replayed newest-first: a generated remainder
    # dim is recorded after the entry that refers to it, so it has to be
    # created before that entry is flattened. User flattenings come last.
    for flatten_map in (*reversed(inter_flatten_names_map_list), *flatten_names_map):
        for key, dim_names in flatten_map.items():
            mesh[tuple(dim_names)]._flatten(mesh_dim_name=key)

    return mesh

@property
def world_mesh(self) -> str:
# doing late init so ParallelDims can still be used as a lightweight
Expand Down
Loading