@@ -158,29 +158,31 @@ def test_auto_peft_model_for_causal_lm_activate_invalid(base_config, adapter_con
158
158
qeff_model .set_adapter ("invalid" )
159
159
160
160
161
@pytest.mark.parametrize("batch_size", [1, 4], ids=["bs1", "bs4"])
@pytest.mark.parametrize("base_config,adapter_config", configs)
def test_auto_peft_model_for_causal_lm_compile_generate(base_config, adapter_config, batch_size, tmp_path):
    """End-to-end compile + generate for a QEff PEFT model, and verify compile caching.

    The model is exported and compiled once (timed), run through ``generate``
    with a batched dummy prompt, then compiled a second time with identical
    arguments. The second compile must hit the cached artifact and therefore
    be dramatically (>100x) faster than the first.
    """
    _, lora_model = create_peft_model(base_config, adapter_config)
    qeff_model = QEffAutoPeftModelForCausalLM(lora_model)
    qeff_model.export(tmp_path)

    # First compile: cold, pays the full compilation cost.
    t0 = perf_counter()
    qeff_model.compile(batch_size=batch_size, prefill_seq_len=32, ctx_len=128)
    cold_compile_time = perf_counter() - t0

    # Dummy batched prompt: 10 "real" tokens followed by 22 padding positions
    # per row, matching the prefill_seq_len of 32 used at compile time.
    real_tokens = np.ones((batch_size, 10), dtype="int64")
    pad_tokens = np.zeros((batch_size, 22), dtype="int64")
    qeff_model.generate(
        input_ids=np.zeros((batch_size, 32), dtype="int64"),
        attention_mask=np.concatenate([real_tokens, pad_tokens], axis=1),
        max_new_tokens=10,
    )

    # Second compile with identical options: must be served from cache.
    t1 = perf_counter()
    qeff_model.compile(batch_size=batch_size, prefill_seq_len=32, ctx_len=128)
    warm_compile_time = perf_counter() - t1

    # NOTE(review): timing-ratio assertion — intentionally loose (100x margin)
    # to prove a cache hit rather than benchmark exact compile speed.
    assert warm_compile_time < 0.01 * cold_compile_time
0 commit comments