 from jetstream_pt import cache_manager
 from jetstream_pt import quantize
+from jetstream_pt import torchjax
 from jetstream_pt.environment import JetEngineEnvironment, JetEngineEnvironmentData
 from jetstream_pt.third_party.llama import model_exportable, model_args
 from jetstream_pt.third_party.gemma import config as gemma_config, model as gemma_model
@@ -86,8 +87,11 @@ def __init__(
     self.y_sharding = env.sharding_by_axis(1)
     self.x_sharding = env.sharding_by_axis(0)
     self.replicated = env.sharding_by_axis(-1)  # replicated
+
     self.cache_sharding = self.env.cache_sharding

+    jax.config.update("jax_enable_x64", False)
+
     self.prefill = jax.jit(
         self.prefill, out_shardings=self.get_prefix_destination_sharding()
     )
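
Not part of the patch: the jax_enable_x64 update above pins JAX to its default 32-bit behavior, presumably so Python scalars and arrays created on the fly do not widen to 64-bit and clash with the 32-bit/bfloat16 model state. A minimal standalone illustration of the effect:

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", False)
# With x64 disabled (the JAX default), Python numbers lower to 32-bit arrays.
assert jnp.asarray(1.0).dtype == jnp.float32
assert jnp.asarray(1).dtype == jnp.int32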
@@ -147,7 +151,7 @@ def _call_model_generate(
     if self.env.enable_kv_quantization:
       caches_obj = [
           cache_manager.Int8KVCacheGenerate(k, v, ks, vs, input_indexes)
-          for (k, v), (ks, vs) in torch_xla2.tensor.wrap(
+          for (k, v), (ks, vs) in torchjax.to_torch(
               list(zip(caches, cache_scales))
           )
       ]
@@ -156,20 +160,22 @@ def _call_model_generate(
           cache_manager.KVCacheGenerate(
               k, v, input_indexes, self.cache_sharding
           )
-          for k, v in torch_xla2.tensor.wrap(caches)
+          for k, v in torchjax.to_torch(caches)
       ]
     mask = jnp.expand_dims(mask, (1, 2))

     args = (tokens, input_pos, caches_obj, mask)
-    paramst, argst = torch_xla2.tensor.wrap((weights, args))
+    paramst, argst = torchjax.to_torch((weights, args))
     with self._lock:
-      with torch_xla2.tensor.XLADispatchMode():
+      with torchjax.jax_mode:
+        # The mode is needed so that tensors created inside of
+        # the model (such as via torch.ones etc) also have the right type
         res = torch.func.functional_call(self.pt_model, paramst, argst)
     updated_caches = [c.state() for c in caches_obj]
     scales = []
     if self.env.enable_kv_quantization:
       scales = [c.scalers() for c in caches_obj]
-    return torch_xla2.tensor.unwrap((res, updated_caches, scales))
+    return torchjax.from_torch((res, updated_caches, scales))

   @functools.partial(
       jax.jit,
@@ -188,12 +194,12 @@ def _call_model_prefill(self, weights, tokens, input_indexes):
     mask = jnp.triu(mask, k=1)
     args = (tokens, input_indexes, caches, mask)

-    paramst, argst = torch_xla2.tensor.wrap((weights, args))
+    paramst, argst = torchjax.to_torch((weights, args))
     with self._lock:
-      with torch_xla2.tensor.XLADispatchMode():
+      with torchjax.jax_mode:
         res = torch.func.functional_call(self.pt_model, paramst, argst)[0]
     caches_res = [c.state() for c in caches]
-    return torch_xla2.tensor.unwrap((res, caches_res))
+    return torchjax.from_torch((res, caches_res))

   def _sampling(self, logits: Any, batch_size: int) -> jnp.ndarray:
     if len(logits.shape) == 2:
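
Both _call_model_prefill and _call_model_generate now follow the same round trip: torchjax.to_torch turns the JAX weight and argument pytrees into torch views, the functional call runs under torchjax.jax_mode so tensors the model creates internally stay on the same backend, and torchjax.from_torch converts the results back. A condensed sketch of that pattern, assuming only the torchjax helpers visible in this diff (the helper function itself is illustrative, not code from the repo):

import torch
from jetstream_pt import torchjax

def run_functional(pt_model, weights, args):
  """Illustrative only: the to_torch / jax_mode / from_torch round trip."""
  paramst, argst = torchjax.to_torch((weights, args))  # JAX arrays -> torch views
  with torchjax.jax_mode:
    # Tensors created inside the model (e.g. via torch.ones) also get the right type.
    res = torch.func.functional_call(pt_model, paramst, argst)
  return torchjax.from_torch(res)  # torch views -> JAX arrays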
@@ -287,20 +293,20 @@ def insert(cache, new_entry):
     @functools.partial(jax.jit, donate_argnums=(0, 1), inline=True)
     def insert(cache, scaler, new_entry):
       reduce_axis = (1, 3)
-      vals, scales = torch_xla2.extra.call_torch(
+      vals, scales = torch_xla2.interop.call_torch(
           quantize.quantize_torch_int8, new_entry, reduce_axis
       )
       new_scaler = jax.lax.dynamic_update_slice(
           scaler,
-          scales,
+          scales.jax(),
           [slot, 0, pos, 0],
       )
       new_scaler = jax.lax.with_sharding_constraint(
           new_scaler, self.replicated
       )
       res = jax.lax.dynamic_update_slice(
           cache,
-          vals,
+          vals.jax(),
           [slot, 0, pos, 0],
       )
       res = jax.lax.with_sharding_constraint(res, self.cache_sharding)
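
Here quantize.quantize_torch_int8 runs through torch_xla2.interop.call_torch and returns an int8 value tensor plus per-slice scales, which are spliced into the persistent cache with jax.lax.dynamic_update_slice; the .jax() calls appear to pull the underlying JAX arrays out of the returned tensor wrappers. As a rough, self-contained sketch of what a symmetric per-axis int8 quantizer of this shape typically looks like (the real quantize_torch_int8 may differ in details such as clamping and zero handling):

import torch

def quantize_int8_sketch(x: torch.Tensor, reduce_axis):
  """Illustrative stand-in: int8 values plus per-axis scales."""
  scale = x.abs().amax(dim=reduce_axis, keepdim=True) / 127.0
  scale = torch.where(scale == 0, torch.ones_like(scale), scale)  # avoid divide-by-zero
  vals = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
  return vals, scale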
@@ -386,7 +392,7 @@ def insert(cache, new_entry):
     def insert(cache, scaler, new_entry):
       new_entry = jnp.transpose(new_entry.squeeze(0), (1, 0, 2))
       reduce_axis = (1, 2)
-      vals, scales = torch_xla2.extra.call_torch(
+      vals, scales = torch_xla2.interop.call_torch(
           quantize.quantize_torch_int8, new_entry, reduce_axis
       )
       new_scaler = scaler.at[slot, :, update_indexes, :].set(scales)
@@ -559,7 +565,7 @@ def _load_from_state_dict(self, path):
     for key, model_weights in self.pt_model.state_dict().items():
       assert key in state_dict, f"key: {key} not found"
       arr = jax.device_put(
-          torch_xla2.tensor.t2j(state_dict[key]), self.env.sharding_by_name(key)
+          torchjax.from_torch(state_dict[key]), self.env.sharding_by_name(key)
       )
       assert tuple(model_weights.shape) == tuple(
           arr.shape
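
_load_from_state_dict now converts each torch weight to a JAX array with torchjax.from_torch and places it straight onto its target sharding via jax.device_put. The same placement idiom in isolation, using a plain NumPy bridge as a stand-in for torchjax.from_torch and an explicit device instead of the per-name sharding used here:

import jax
import jax.numpy as jnp
import torch

weight = torch.randn(8, 8)
# CPU torch tensor -> NumPy -> JAX array placed on an explicit device.
arr = jax.device_put(jnp.asarray(weight.numpy()), jax.devices()[0])
assert tuple(arr.shape) == tuple(weight.shape)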
@@ -602,14 +608,14 @@ def get_prefix_destination_sharding(self) -> Prefix:
     """Returns the shardings necessary to transfer data between engines."""
     return Prefix(
         self.replicated,
-        self.cache_sharding,
+        self.replicated if self.env.shard_on_batch else self.cache_sharding,
         self.replicated,
     )

   def get_decode_state_sharding(self) -> DecodeState:
     """Gets the shardings corresponding to the decode state."""
     return DecodeState(
-        self.replicated,
+        self.x_sharding if self.env.shard_on_batch else self.replicated,
         self.cache_sharding,
         self.replicated,
         self.replicated,
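
With shard_on_batch enabled, the prefix cache destination becomes replicated and the decode-state tokens pick up x_sharding, i.e. sharding along axis 0, the batch axis. For orientation, an axis-based helper like env.sharding_by_axis can be approximated with a NamedSharding over a one-axis device mesh as below; this is an assumed shape, not the actual JetEngineEnvironment implementation:

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=("x",))

def sharding_by_axis(axis: int, rank: int = 4) -> NamedSharding:
  """Shard one dimension of a rank-`rank` array across the single 'x' mesh axis."""
  if axis == -1:
    return NamedSharding(mesh, P())  # fully replicated
  spec = [None] * rank
  spec[axis] = "x"
  return NamedSharding(mesh, P(*spec))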
@@ -663,6 +669,7 @@ def create_pytorch_engine(
     quantize_kv=False,
     max_cache_length=1024,
     sharding_config=None,
+    shard_on_batch=False,
 ) -> PyTorchEngine:
   """Returns: The pytorch engine."""
@@ -718,8 +725,12 @@ def create_pytorch_engine(
       cache_sequence_length=max_cache_length,
       bf16_enable=bf16_enable,
       sharding_config_path=sharding_config,
+      shard_on_batch=shard_on_batch,
   )

+  if shard_on_batch and sharding_config:
+    print("WARNING: with shard_on_batch the sharding config is ignored.")
+
   if model_name.startswith("llama"):

     args = model_args.get_model_args(