@@ -255,7 +255,6 @@ def export(self, export_dir: Optional[str] = None) -> str:
                 fbs if self.continuous_batching else bs, self.model.config.vocab_size, dtype=torch.int32)
             dynamic_axes["repetition_penalty_retain_state"] = {
                 0: "full_batch_size" if self.continuous_batching else "batch_size",
-                1: "vocab_size",
             }
             output_names.append("repetition_penalty_retain_state_RetainedState")
@@ -266,7 +265,6 @@ def export(self, export_dir: Optional[str] = None) -> str:
                 fbs if self.continuous_batching else bs, self.model.config.vocab_size, dtype=torch.int32)
             dynamic_axes["presence_penalty_retain_state"] = {
                 0: "full_batch_size" if self.continuous_batching else "batch_size",
-                1: "vocab_size",
             }
             output_names.append("presence_penalty_retain_state_RetainedState")
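
Note on the export-side hunks above: with torch.onnx.export, any tensor dimension that is not listed in dynamic_axes is frozen to the size of the example input, so dropping the 1: "vocab_size" entries leaves the retain states with a fixed vocabulary dimension and only a dynamic batch dimension. A minimal, self-contained sketch of that behaviour (the module, file name, and sizes below are illustrative placeholders, not code from this repository):

import torch

# Toy module with a (batch, vocab)-shaped retained state, for illustration only.
class ToySampler(torch.nn.Module):
    def forward(self, retain_state: torch.Tensor) -> torch.Tensor:
        return retain_state + 1

bs, vocab_size = 2, 32000  # placeholder sizes
example_state = torch.ones(bs, vocab_size, dtype=torch.int32)

# Only dim 0 is declared dynamic; dim 1 (the vocab axis) is baked into the graph,
# mirroring the removal of the "vocab_size" axis entries in the diff above.
torch.onnx.export(
    ToySampler(),
    (example_state,),
    "toy_sampler.onnx",
    input_names=["repetition_penalty_retain_state"],
    output_names=["repetition_penalty_retain_state_RetainedState"],
    dynamic_axes={"repetition_penalty_retain_state": {0: "batch_size"}},
)
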
@@ -374,7 +372,6 @@ def compile(
         }
         if self.include_sampler:
             prefill_specialization.update({
-                "vocab_size": self.model.config.vocab_size,
                 "max_top_k_ids": constants.Constants.MAX_TOP_K_IDS,
             })
         prefill_specialization.update({"num_logits_to_keep": 1})
@@ -396,7 +393,6 @@ def compile(
         }
         if self.include_sampler:
             decode_specialization.update({
-                "vocab_size": self.model.config.vocab_size,
                 "max_top_k_ids": constants.Constants.MAX_TOP_K_IDS,
             })
         if self.continuous_batching:
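
On the compile side, the specializations are plain key/value maps handed to the compiler; the change simply stops emitting vocab_size as a sampler-specific entry while keeping max_top_k_ids. A rough sketch of the resulting structure, with placeholder keys and values standing in for the real constants and sequence settings:

# Illustrative reconstruction only; values and the base keys are placeholders.
MAX_TOP_K_IDS = 512  # stand-in for constants.Constants.MAX_TOP_K_IDS
include_sampler = True

prefill_specialization = {"batch_size": 1, "seq_len": 128, "ctx_len": 2048}
decode_specialization = {"batch_size": 1, "seq_len": 1, "ctx_len": 2048}

if include_sampler:
    # vocab_size is no longer a compile-time specialization; only the
    # top-k bound stays sampler-specific.
    prefill_specialization.update({"max_top_k_ids": MAX_TOP_K_IDS})
    decode_specialization.update({"max_top_k_ids": MAX_TOP_K_IDS})

prefill_specialization.update({"num_logits_to_keep": 1})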