Skip to content

Commit b33197f

Browse files
authored
fix: bs>1 bug for peft models (#164)
Signed-off-by: Ilango Rajagopal <[email protected]>
1 parent 205f1d7 commit b33197f

File tree

2 files changed

+11
-9
lines changed

2 files changed

+11
-9
lines changed

QEfficient/peft/auto.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ def generate(
542542

543543
# Decode loop
544544
for num_token in range(1, generation_config.max_new_tokens):
545-
if stopping_criteria(torch.from_numpy(inputs["input_ids"]), torch.from_numpy(outputs["logits"])):
545+
if stopping_criteria(torch.from_numpy(inputs["input_ids"]), torch.from_numpy(outputs["logits"])).all():
546546
break
547547

548548
outputs = self.qpc_session.run(inputs)

tests/peft/test_peft_model.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -158,29 +158,31 @@ def test_auto_peft_model_for_causal_lm_activate_invalid(base_config, adapter_con
158158
qeff_model.set_adapter("invalid")
159159

160160

161+
@pytest.mark.parametrize("batch_size", [1, 4], ids=["bs1", "bs4"])
161162
@pytest.mark.parametrize("base_config,adapter_config", configs)
162-
def test_auto_peft_model_for_causal_lm_compile_generate(base_config, adapter_config, tmp_path):
163+
def test_auto_peft_model_for_causal_lm_compile_generate(base_config, adapter_config, batch_size, tmp_path):
163164
_, lora_model = create_peft_model(base_config, adapter_config)
164165
qeff_model = QEffAutoPeftModelForCausalLM(lora_model)
165166
qeff_model.export(tmp_path)
166167
start = perf_counter()
167-
qeff_model.compile(prefill_seq_len=32, ctx_len=128)
168+
qeff_model.compile(batch_size=batch_size, prefill_seq_len=32, ctx_len=128)
168169
end = perf_counter()
169170
compile_time_0 = end - start
170171

171172
qeff_model.generate(
172-
input_ids=np.zeros((1, 32), dtype="int64"),
173+
input_ids=np.zeros((batch_size, 32), dtype="int64"),
173174
attention_mask=np.concatenate(
174175
[
175-
np.ones(10, dtype="int64"),
176-
np.zeros(22, dtype="int64"),
177-
]
178-
).reshape(1, 32),
176+
np.ones((batch_size, 10), dtype="int64"),
177+
np.zeros((batch_size, 22), dtype="int64"),
178+
],
179+
axis=1,
180+
),
179181
max_new_tokens=10,
180182
)
181183

182184
start = perf_counter()
183-
qeff_model.compile(prefill_seq_len=32, ctx_len=128)
185+
qeff_model.compile(batch_size=batch_size, prefill_seq_len=32, ctx_len=128)
184186
end = perf_counter()
185187
compile_time_1 = end - start
186188
assert compile_time_1 < 0.01 * compile_time_0

0 commit comments

Comments
 (0)