Skip to content

Commit 0b41c60

Browse files
committed
Reverted unnecessary changes
Signed-off-by: Asmita Goswami <[email protected]>
1 parent c615eeb commit 0b41c60

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

tests/transformers/spd/test_pld_inference.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -145,9 +145,9 @@ def get_padded_input_len(input_len: int, prefill_seq_len: int, ctx_len: int):
145145
"""
146146
num_chunks = -(input_len // -prefill_seq_len) # ceil divide without float
147147
input_len_padded = num_chunks * prefill_seq_len # Convert input_len to a multiple of prefill_seq_len
148-
assert input_len_padded <= ctx_len, (
149-
"input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len"
150-
)
148+
assert (
149+
input_len_padded <= ctx_len
150+
), "input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len"
151151
return input_len_padded
152152

153153

tests/transformers/spd/test_spd_inference.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,9 @@ def get_padded_input_len(input_len: int, prefill_seq_len: int, ctx_len: int):
7575
"""
7676
num_chunks = -(input_len // -prefill_seq_len) # ceil divide without float
7777
input_len_padded = num_chunks * prefill_seq_len # Convert input_len to a multiple of prefill_seq_len
78-
assert input_len_padded <= ctx_len, (
79-
"input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len"
80-
)
78+
assert (
79+
input_len_padded <= ctx_len
80+
), "input_len rounded to nearest prefill_seq_len multiple should be less than ctx_len"
8181
return input_len_padded
8282

8383

@@ -320,9 +320,9 @@ def test_spec_decode_inference(
320320
for prompt, generation in zip(prompts, batch_decode):
321321
print(f"{prompt=} {generation=}")
322322
# validation check
323-
assert mean_num_accepted_tokens == float(num_speculative_tokens + 1), (
324-
f"mean number of accepted tokens is {mean_num_accepted_tokens} but should be {num_speculative_tokens + 1}"
325-
)
323+
assert mean_num_accepted_tokens == float(
324+
num_speculative_tokens + 1
325+
), f"mean number of accepted tokens is {mean_num_accepted_tokens} but should be {num_speculative_tokens + 1}"
326326
del target_model_session
327327
del draft_model_session
328328
generated_ids = np.asarray(generated_ids[0]).flatten()

0 commit comments

Comments
 (0)