Skip to content

Commit 0f4e7ae

Browse files
authored
Update version of jetstream; misc fixes (#88)
misc fixes
1 parent a371a5e commit 0f4e7ae

File tree

5 files changed

+23
-14
lines changed

5 files changed

+23
-14
lines changed

convert_checkpoints.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,9 +324,12 @@ def _get_llama_state_dict(input_ckpt_dir):
324324
print(f"Loading checkpoints takes {end - start} seconds")
325325

326326
start = time.perf_counter()
327-
state_dict = _merge_llama_weights(
328-
checkpoints, _MINIMIZE_MEMORY_FOOTPRINT.value, _ENABLE_FLOAT32.value
329-
)
327+
if len(checkpoints) > 1:
328+
state_dict = _merge_llama_weights(
329+
checkpoints, _MINIMIZE_MEMORY_FOOTPRINT.value, _ENABLE_FLOAT32.value
330+
)
331+
else:
332+
state_dict = checkpoints[0]
330333
end = time.perf_counter()
331334
print(f"Merging weights takes {end - start} seconds")
332335
return state_dict, params

install_everything.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
TORCHXLA_TAG=f26c35c2fa5eb1d22d042a2a8a8dc34f11b99f60 # updated May 14, 2024
16-
JETSTREAM_TAG=v0.2.1
16+
JETSTREAM_TAG=e4952fbb12e0ab3c33bc7c1eef3839b7c2ad0dd4 # updated May 16, 2024
1717

1818
# Uninstall existing jax
1919
pip show jax && pip uninstall -y jax
@@ -39,4 +39,4 @@ git checkout $JETSTREAM_TAG
3939
pip install .
4040
popd # now at the folder deps
4141
popd # now at the folder current file
42-
pip install -e .
42+
pip install -e .

jetstream_pt/engine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,15 +298,15 @@ def insert(cache, scaler, new_entry):
298298
)
299299
new_scaler = jax.lax.dynamic_update_slice(
300300
scaler,
301-
scales.jax(),
301+
scales,
302302
[slot, 0, pos, 0],
303303
)
304304
new_scaler = jax.lax.with_sharding_constraint(
305305
new_scaler, self.replicated
306306
)
307307
res = jax.lax.dynamic_update_slice(
308308
cache,
309-
vals.jax(),
309+
vals,
310310
[slot, 0, pos, 0],
311311
)
312312
res = jax.lax.with_sharding_constraint(res, self.cache_sharding)

jetstream_pt/third_party/llama/model_exportable.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,9 @@ def forward(
188188
freqs_cis = self.freqs_cis[input_pos]
189189
freqs_cis = freqs_cis.reshape(bsz, seqlen, -1)
190190

191+
assert len(caches) == len(
192+
self.layers
193+
), f"Number of caches ({len(caches)}) and layers ({len(self.layers)}) dont match"
191194
for layer, cache in zip(self.layers, caches):
192195
with jax.named_scope("TransformerBlock"):
193196
h = layer(h, freqs_cis, mask, cache)

run_interactive.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def main(argv):
140140
]
141141
for prompt in prompts:
142142
slot = random.randint(0, _BATCH_SIZE.value - 1)
143-
tokens, true_length = tokenizer.encode(prompt, is_bos=True)
143+
tokens, true_length = tokenizer.encode(prompt)
144144

145145
print(f"---- Input prompts are: {prompt}")
146146
print(f"---- Encoded tokens are: {tokens}")
@@ -157,12 +157,15 @@ def main(argv):
157157
while True:
158158
decode_state, result_tokens = engine.generate(params, decode_state)
159159
result_tokens = result_tokens.convert_to_numpy()
160-
output, complete = tokenizer.decode(
161-
slot, max_output_length, result_tokens, complete
162-
)
163-
if complete[0]:
160+
res = result_tokens.get_result_at_slot(slot)
161+
stop_tokens = set(tokenizer.tokenizer.stop_tokens)
162+
stop_tokens.add(tokenizer.pad_id)
163+
if (
164+
res.tokens[0][0] in stop_tokens
165+
or len(sampled_tokens_list) > max_output_length
166+
):
164167
break
165-
token_id = output[0][0]
168+
token_id = res.tokens[0][0]
166169
sampled_tokens_list.append(token_id)
167170
# output_str = tokenizer.decode_str([token_id])
168171
# print(Fore.GREEN + output_str, end="", flush=True)
@@ -173,7 +176,7 @@ def main(argv):
173176
print("---- All output tokens.")
174177
print(sampled_tokens_list)
175178
print("---- All output text.")
176-
print(tokenizer.decode_str(sampled_tokens_list))
179+
print(tokenizer.decode(sampled_tokens_list))
177180

178181
if _PROFILING_OUTPUT.value:
179182
jax.profiler.stop_trace()

0 commit comments

Comments
 (0)