
Commit 776c1c4

Add shard on batch mode. Also update version of torchxla2 (#80)

* Add shard on batch mode. Also update version of torchxla2
* Address comments
1 parent 648bf48 commit 776c1c4

15 files changed: +144 -91 lines

benchmarks/prefill_offline.py

Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@
 import functools
 import humanize

+# pylint: disable-next=all
 from absl import app
 from absl import flags
 import numpy as np

benchmarks/run_offline.py

Lines changed: 1 addition & 0 deletions

@@ -15,6 +15,7 @@
 import logging
 import os
 import time
+# pylint: disable-next=all
 from absl import app
 from absl import flags

install_everything.sh

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-TORCHXLA_TAG=jetstream-pytorch
+TORCHXLA_TAG=f26c35c2fa5eb1d22d042a2a8a8dc34f11b99f60 # updated May 14, 2024
 JETSTREAM_TAG=v0.2.1

 # Uninstall existing jax

jetstream_pt/cache_manager.py

Lines changed: 14 additions & 16 deletions

@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import torch_xla2
 import jax
 import jax.numpy as jnp
 import torch
+from jetstream_pt import torchjax


 # pylint: disable-next=all

@@ -49,9 +49,7 @@ def update(self, key, value):
     self.cache_v = value
     if self.kv_quantize:  # pretend to be quantized
       bsz, _, seq, _ = key.shape
-      ones = torch_xla2.tensor.wrap(
-          jnp.ones((bsz, 1, seq, 1), dtype=jnp.bfloat16)
-      )
+      ones = torchjax.to_torch(jnp.ones((bsz, 1, seq, 1), dtype=jnp.bfloat16))
       return key, value, ones, ones

     return key, value

@@ -64,15 +62,15 @@ def state(self):
 # pylint: disable-next=all
 def KVCachePrefill_flatten(cache):
   return (
-      torch_xla2.tensor.unwrap((cache.cache_k, cache.cache_v)),
+      torchjax.from_torch((cache.cache_k, cache.cache_v)),
       cache.kv_quantize,
   )


 # pylint: disable-next=all
 def KVCachePrefill_unflatten(auxdata, data):
   cache = KVCachePrefill(auxdata)
-  cache_k, cache_v = torch_xla2.tensor.wrap(data)
+  cache_k, cache_v = torchjax.from_torch(data)
   cache.cache_k = cache_k
   cache.cache_v = cache_v


@@ -102,7 +100,7 @@ def __init__(

   def update(self, key, value):
     """Update kv cache"""
-    keyj, valuej = torch_xla2.tensor.unwrap((key, value))
+    keyj, valuej = torchjax.to_torch((key, value))
     # pylint: disable-next=all
     self.cache_k._elem = self.cache_k._elem.at[:, :, self.pos].set(keyj)
     # pylint: disable-next=all

@@ -112,30 +110,30 @@ def update(self, key, value):
   def state(self):
     """Get kv cache state"""
     # pylint: disable-next=all
-    return self.cache_k._elem, self.cache_v._elem
+    return self.cache_k.jax(), self.cache_v.jax()

   @classmethod
   def empty(cls, shape, device, bf16_enable):
     """Create empty kv caches"""
     default_dtype = jnp.bfloat16 if bf16_enable else jnp.float32
     k = jnp.zeros(shape, device=device, dtype=default_dtype)
     v = jnp.zeros(shape, device=device, dtype=default_dtype)
-    k, v = torch_xla2.tensor.wrap((k, v))
+    k, v = torchjax.to_torch((k, v))
     return cls(k, v, 0, device)


 # pylint: disable-next=all
 def KVCacheGenerate_flatten(cache):
-  return torch_xla2.tensor.unwrap((cache.cache_k, cache.cache_v)), (
-      cache.pos,
-      cache.sharding,
+  return ((cache.cache_k.jax(), cache.cache_v.jax())), (
+      cache.pos.jax(),
+      cache.sharding.jax(),
   )


 # pylint: disable-next=all
 def KVCacheGenerate_unflatten(auxdata, data):
   position, sharding = auxdata
-  cache_k, cache_v = torch_xla2.tensor.wrap(data)
+  cache_k, cache_v = torchjax.to_torch(data)
   cache = KVCacheGenerate(cache_k, cache_v, position, sharding)
   return cache


@@ -168,11 +166,11 @@ def __init__(

   def state(self):
     """Get kv cache state"""
-    return torch_xla2.tensor.unwrap((self.cache_k, self.cache_v))
+    return torchjax.from_torch((self.cache_k, self.cache_v))

   def scalers(self):
     """Get kv cache scalers"""
-    return torch_xla2.tensor.unwrap((self.k_scaler, self.v_scaler))
+    return torchjax.from_torch((self.k_scaler, self.v_scaler))

   @classmethod
   # pylint: disable-next=all

@@ -184,7 +182,7 @@ def empty(cls, shape, device, bf16_enable):
     kscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16)
     vscaler = jnp.ones((shape[0], 1, shape[2], 1), dtype=jnp.bfloat16)

-    cache_k, cache_v, kscaler, vscaler = torch_xla2.tensor.wrap(
+    cache_k, cache_v, kscaler, vscaler = torchjax.to_torch(
         (cache_k, cache_v, kscaler, vscaler)
     )
     return cls(cache_k, cache_v, kscaler, vscaler, 0, device)
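
Note: the new jetstream_pt/torchjax.py module is not included in this diff. Based only on the call sites above (to_torch, from_torch, jax_mode) and the torch_xla2 primitives they replace, a minimal sketch of such a helper could look like the following; the names and conversion directions are inferred from most (though not all) of the call sites, and the real implementation may differ.

# Hypothetical sketch of a torchjax helper module; not the actual
# jetstream_pt/torchjax.py from the repo.
import torch_xla2


def to_torch(pytree):
  # Wrap jax arrays (possibly nested in tuples/lists/dicts) as torch_xla2
  # torch tensors, mirroring the former torch_xla2.tensor.wrap call sites.
  return torch_xla2.tensor.wrap(pytree)


def from_torch(pytree):
  # Unwrap torch_xla2 tensors back into plain jax arrays, mirroring the
  # former torch_xla2.tensor.unwrap call sites.
  return torch_xla2.tensor.unwrap(pytree)


# Dispatch-mode context so torch ops executed inside the model (e.g. torch.ones)
# also produce jax-backed tensors, matching `with torchjax.jax_mode:` above.
jax_mode = torch_xla2.tensor.XLADispatchMode()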

jetstream_pt/engine.py

Lines changed: 26 additions & 15 deletions

@@ -33,6 +33,7 @@

 from jetstream_pt import cache_manager
 from jetstream_pt import quantize
+from jetstream_pt import torchjax
 from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
 from jetstream_pt.third_party.llama import model_exportable, model_args
 from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model

@@ -86,8 +87,11 @@ def __init__(
     self.y_sharding = env.sharding_by_axis(1)
     self.x_sharding = env.sharding_by_axis(0)
     self.replicated = env.sharding_by_axis(-1)  # replicated
+
     self.cache_sharding = self.env.cache_sharding

+    jax.config.update("jax_enable_x64", False)
+
     self.prefill = jax.jit(
         self.prefill, out_shardings=self.get_prefix_destination_sharding()
     )

@@ -147,7 +151,7 @@ def _call_model_generate(
     if self.env.enable_kv_quantization:
       caches_obj = [
           cache_manager.Int8KVCacheGenerate(k, v, ks, vs, input_indexes)
-          for (k, v), (ks, vs) in torch_xla2.tensor.wrap(
+          for (k, v), (ks, vs) in torchjax.to_torch(
               list(zip(caches, cache_scales))
           )
       ]

@@ -156,20 +160,22 @@ def _call_model_generate(
           cache_manager.KVCacheGenerate(
               k, v, input_indexes, self.cache_sharding
           )
-          for k, v in torch_xla2.tensor.wrap(caches)
+          for k, v in torchjax.to_torch(caches)
       ]
     mask = jnp.expand_dims(mask, (1, 2))

     args = (tokens, input_pos, caches_obj, mask)
-    paramst, argst = torch_xla2.tensor.wrap((weights, args))
+    paramst, argst = torchjax.to_torch((weights, args))
     with self._lock:
-      with torch_xla2.tensor.XLADispatchMode():
+      with torchjax.jax_mode:
+        # The mode is needed so that tensors created inside of
+        # the model (such as via torch.ones etc) also have the right type
         res = torch.func.functional_call(self.pt_model, paramst, argst)
     updated_caches = [c.state() for c in caches_obj]
     scales = []
     if self.env.enable_kv_quantization:
       scales = [c.scalers() for c in caches_obj]
-    return torch_xla2.tensor.unwrap((res, updated_caches, scales))
+    return torchjax.from_torch((res, updated_caches, scales))

   @functools.partial(
       jax.jit,

@@ -188,12 +194,12 @@ def _call_model_prefill(self, weights, tokens, input_indexes):
     mask = jnp.triu(mask, k=1)
     args = (tokens, input_indexes, caches, mask)

-    paramst, argst = torch_xla2.tensor.wrap((weights, args))
+    paramst, argst = torchjax.to_torch((weights, args))
     with self._lock:
-      with torch_xla2.tensor.XLADispatchMode():
+      with torchjax.jax_mode:
         res = torch.func.functional_call(self.pt_model, paramst, argst)[0]
     caches_res = [c.state() for c in caches]
-    return torch_xla2.tensor.unwrap((res, caches_res))
+    return torchjax.from_torch((res, caches_res))

   def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray:
     if len(logits.shape) == 2:

@@ -287,20 +293,20 @@ def insert(cache, new_entry):
     @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True)
     def insert(cache, scaler, new_entry):
       reduce_axis = (1, 3)
-      vals, scales = torch_xla2.extra.call_torch(
+      vals, scales = torch_xla2.interop.call_torch(
           quantize.quantize_torch_int8, new_entry, reduce_axis
       )
       new_scaler = jax.lax.dynamic_update_slice(
           scaler,
-          scales,
+          scales.jax(),
           [slot, 0, pos, 0],
       )
       new_scaler = jax.lax.with_sharding_constraint(
           new_scaler, self.replicated
       )
       res = jax.lax.dynamic_update_slice(
           cache,
-          vals,
+          vals.jax(),
           [slot, 0, pos, 0],
       )
       res = jax.lax.with_sharding_constraint(res, self.cache_sharding)

@@ -386,7 +392,7 @@ def insert(cache, new_entry):
     def insert(cache, scaler, new_entry):
       new_entry = jnp.transpose(new_entry.squeeze(0), (1, 0, 2))
       reduce_axis = (1, 2)
-      vals, scales = torch_xla2.extra.call_torch(
+      vals, scales = torch_xla2.interop.call_torch(
           quantize.quantize_torch_int8, new_entry, reduce_axis
       )
       new_scaler = scaler.at[slot, :, update_indexes, :].set(scales)

@@ -559,7 +565,7 @@ def _load_from_state_dict(self, path):
     for key, model_weights in self.pt_model.state_dict().items():
       assert key in state_dict, f"key: {key} not found"
       arr = jax.device_put(
-          torch_xla2.tensor.t2j(state_dict[key]), self.env.sharding_by_name(key)
+          torchjax.from_torch(state_dict[key]), self.env.sharding_by_name(key)
       )
       assert tuple(model_weights.shape) == tuple(
           arr.shape

@@ -602,14 +608,14 @@ def get_prefix_destination_sharding(self) -> Prefix:
     """Returns the shardings necessary to transfer data between engines."""
     return Prefix(
         self.replicated,
-        self.cache_sharding,
+        self.replicated if self.env.shard_on_batch else self.cache_sharding,
         self.replicated,
     )

   def get_decode_state_sharding(self) -> DecodeState:
     """Gets the shardings corresponding to the decode state."""
     return DecodeState(
-        self.replicated,
+        self.x_sharding if self.env.shard_on_batch else self.replicated,
         self.cache_sharding,
         self.replicated,
         self.replicated,

@@ -663,6 +669,7 @@ def create_pytorch_engine(
     quantize_kv=False,
     max_cache_length=1024,
     sharding_config=None,
+    shard_on_batch=False,
 ) -> PyTorchEngine:
   """Returns: The pytorch engine."""


@@ -718,8 +725,12 @@
       cache_sequence_length=max_cache_length,
       bf16_enable=bf16_enable,
       sharding_config_path=sharding_config,
+      shard_on_batch=shard_on_batch,
   )

+  if shard_on_batch and sharding_config:
+    print("WARNING: with sharding_on_batch sharding config is ignored.")
+
   if model_name.startswith("llama"):

     args = model_args.get_model_args(
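
For context on what the shard_on_batch switch changes above: when it is set, the decode state tokens get the batch-dimension sharding (x_sharding) instead of being replicated, and the prefill destination cache sharding falls back to replicated. The snippet below is an illustrative, standalone sketch of that choice using plain jax.sharding over a one-dimensional mesh named "x"; the mesh construction and array shapes are assumptions, not code from the repo.

# Illustrative sketch (not repo code): batch-dim sharding vs. replication.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
x_sharding = NamedSharding(mesh, P("x"))  # shard dim 0 (the batch) across devices
replicated = NamedSharding(mesh, P())     # full copy on every device

shard_on_batch = True
token_sharding = x_sharding if shard_on_batch else replicated

# Place a (batch, 1) token buffer the way a batch-sharded decode state would be.
batch = 4 * len(jax.devices())  # keep the batch divisible by the mesh size
tokens = jax.device_put(np.zeros((batch, 1), dtype=np.int32), token_sharding)
print(tokens.sharding)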

jetstream_pt/environment.py

Lines changed: 12 additions & 3 deletions

@@ -75,6 +75,9 @@ class JetEngineEnvironmentData:

   sharding_config_path: str = ""

+  # Whether to shard on batch dimension. i.e. data parallel.
+  shard_on_batch: bool = False
+

 # pylint: disable-next=all
 class JetEngineEnvironment:

@@ -97,9 +100,12 @@ def __init__(self, data: JetEngineEnvironmentData):
     self.x_sharding = jsharding.NamedSharding(self._mesh, P("x"))
     self.replicated = jsharding.NamedSharding(self._mesh, P())

-    cache_sharding_axis = self.attention_kv_axis_names.index(
-        self.kv_cache_shard_axis
-    )
+    if data.shard_on_batch:
+      cache_sharding_axis = 0
+    else:
+      cache_sharding_axis = self.attention_kv_axis_names.index(
+          self.kv_cache_shard_axis
+      )

     if self.cache_shape[cache_sharding_axis] == 1:
       # cannot shard on an axis that is 1

@@ -169,6 +175,9 @@ def make_caches_generate(self):

   def sharding_by_name(self, name):
     """Create sharding specified in the config."""
+    if self.shard_on_batch:
+      return self.shading_by_axis(0)  # batch dimension
+
     if name in self._sharding_config:
       return self.sharding_by_axis(self._sharding_config[name])
