-
Notifications
You must be signed in to change notification settings - Fork 72
/
Copy pathsharding_utils.py
82 lines (59 loc) · 2.37 KB
/
sharding_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""Utilities for dealing with sharding in JAX."""
import jax
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec
def get_mesh() -> jax.sharding.Mesh:
    """Builds a one-dimensional device mesh with a single "batch" axis.

    The mesh spans every device returned by `jax.devices()`. Which devices
    those are (GPU, TPU or CPU) depends on the active JAX backend.

    Returns:
      A `jax.sharding.Mesh` whose only axis is named "batch".
    """
    devices = jax.devices()
    axis_names = ("batch",)
    return jax.sharding.Mesh(devices, axis_names)
def get_replicated_sharding(mesh=None):
    """Returns a sharding that keeps a full copy of the data on every device.

    Args:
      mesh: Mesh to build the sharding on. When None, the mesh returned by
        `get_mesh()` is used.

    Returns:
      A `NamedSharding` over the mesh with an empty `PartitionSpec`, i.e. no
      array axis is partitioned.
    """
    target_mesh = get_mesh() if mesh is None else mesh
    no_partitioning = PartitionSpec()
    return NamedSharding(target_mesh, no_partitioning)
def shard_replicated(x, mesh=None):
    """Places every leaf of a pytree on the mesh with a replicated sharding.

    Despite the name, nothing is partitioned: each leaf is passed to
    `jax.device_put` with the sharding from `get_replicated_sharding`.

    Args:
      x: Pytree whose leaves are accepted by `jax.device_put`.
      mesh: Mesh to place the data on. When None, the mesh returned by
        `get_mesh()` is used.

    Returns:
      A pytree with the same structure as `x` whose leaves are the results of
      `jax.device_put`.
    """
    if mesh is None:
        mesh = get_mesh()
    # The sharding is the same for every leaf, so build it once instead of
    # once per leaf inside the mapped function.
    replicated = get_replicated_sharding(mesh)
    return jax.tree.map(lambda leaf: jax.device_put(leaf, replicated), x)
def get_naive_sharding_spec(mesh=None):
    """Returns a sharding that partitions the leading axis over "batch".

    Args:
      mesh: Mesh to build the sharding on. When None, the mesh returned by
        `get_mesh()` is used.

    Returns:
      A `NamedSharding` over the mesh with `PartitionSpec("batch")`.
    """
    target_mesh = get_mesh() if mesh is None else mesh
    leading_axis_spec = PartitionSpec("batch")
    return NamedSharding(target_mesh, leading_axis_spec)
def get_naive_sharding(x, mesh=None):
    """Picks a sharding for `x` on a mesh that has a "batch" axis.

    If `x` has at least one dimension and its leading dimension is divisible
    by the size of the mesh's "batch" axis, the leading dimension is
    partitioned over "batch". Otherwise `x` is replicated.

    Args:
      x: Object exposing a `shape` attribute.
      mesh: Mesh with a "batch" axis. When None, the mesh returned by
        `get_mesh()` is used.

    Returns:
      A `NamedSharding` with either `PartitionSpec("batch")` or an empty
      `PartitionSpec`.
    """
    target_mesh = get_mesh() if mesh is None else mesh
    num_batch_devices = target_mesh.shape["batch"]
    dims = x.shape
    # Zero-dimensional inputs have no leading axis to split.
    splittable = len(dims) > 0 and dims[0] % num_batch_devices == 0
    spec = PartitionSpec("batch") if splittable else PartitionSpec()
    return NamedSharding(target_mesh, spec)
def shard_params(params, mesh=None):
    """Places a parameter pytree on devices using naive per-leaf sharding.

    Each leaf is passed to `jax.device_put` with the sharding chosen by
    `get_naive_sharding` for that leaf.

    Args:
      params: Pytree of leaves that expose `shape` and are accepted by
        `jax.device_put`.
      mesh: Mesh to shard over. When None, the mesh returned by `get_mesh()`
        is used.

    Returns:
      A pytree with the same structure as `params` whose leaves are the
      results of `jax.device_put`.
    """
    if mesh is None:
        mesh = get_mesh()
    # Forward the resolved mesh: calling get_naive_sharding(leaf) without it
    # would silently ignore a caller-supplied mesh and rebuild the default
    # mesh for every leaf.
    return jax.tree.map(
        lambda leaf: jax.device_put(leaf, get_naive_sharding(leaf, mesh)),
        params,
    )
def shard_naive(x, mesh=None):
    """Alias for `shard_params`: forwards `x` and `mesh` and returns its result."""
    return shard_params(x, mesh=mesh)
def get_naive_sharding_tree(input_tree, mesh=None):
    """Builds a pytree of shardings that mirrors the structure of `input_tree`.

    Args:
      input_tree: Pytree whose leaves expose a `shape` attribute.
      mesh: Mesh to build the shardings on. When None, the mesh returned by
        `get_mesh()` is resolved once and shared by all leaves.

    Returns:
      A pytree with the same structure as `input_tree` in which each leaf is
      replaced by the result of `get_naive_sharding` for that leaf.
    """
    resolved_mesh = get_mesh() if mesh is None else mesh

    def _leaf_sharding(leaf):
        return get_naive_sharding(leaf, resolved_mesh)

    return jax.tree.map(_leaf_sharding, input_tree)
def get_sharding_tree(params, mesh=None):
    """Returns a pytree of shardings matching the structure of `params`.

    Args:
      params: Pytree whose leaves expose a `shape` attribute.
      mesh: Forwarded unchanged (including None) to `get_naive_sharding` for
        every leaf.

    Returns:
      A pytree with the same structure as `params` in which each leaf is
      replaced by the result of `get_naive_sharding` for that leaf.
    """

    def _sharding_for(leaf):
        return get_naive_sharding(leaf, mesh)

    return jax.tree.map(_sharding_for, params)
def get_empty_sharding(mesh=None):
    """Returns a sharding with an empty partition spec (fully replicated).

    Args:
      mesh: Mesh to build the sharding on. When None, the mesh returned by
        `get_mesh()` is used.

    Returns:
      A `NamedSharding` over the mesh with an empty `PartitionSpec`.
    """
    chosen_mesh = mesh if mesh is not None else get_mesh()
    return NamedSharding(chosen_mesh, PartitionSpec())
def disp_shard_info(x: jax.Array):
    """Prints the device, index and replica id of each addressable shard of `x`.

    Args:
      x: Array whose `addressable_shards` are listed, one printed entry per
        shard.
    """
    for piece in x.addressable_shards:
        # The message ends with "\n" and print adds its own newline, so each
        # entry is followed by a blank line.
        message = (
            f"shard.device: {piece.device}, index: {piece.index}, replica_id:"
            f" {piece.replica_id}.\n"
        )
        print(message)