
Commit e1e5468

aw632 authored and yibinl-nvidia committed
chore: clean up 20k reference
Signed-off-by: Andrew Wang <[email protected]>
1 parent 7e7e36b commit e1e5468

File tree

1 file changed: +47 -79 lines changed


examples/models/core/enc_dec/convert_checkpoint.py

Lines changed: 47 additions & 79 deletions
@@ -31,6 +31,9 @@
 layernorm_position_map = {i.name: i.value for i in LayerNormPositionType}
 mlp_type_map = {i.name: i.value for i in MLPType}
 
+# Constants for specific model configurations
+ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS = 20000
+
 
 def copy_args_to_component_config(component_config, args):
     for arg in vars(args):
@@ -773,6 +776,10 @@ def parse_bart_config_by_component(config, component, args):
     encoder_config = parse_bart_config_by_component(config, "encoder", args)
     decoder_config = parse_bart_config_by_component(config, "decoder", args)
 
+    # Override n_positions for eclair_radio model
+    if args.eclair_radio:
+        decoder_config.n_positions = ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS
+
     return encoder_config, decoder_config
 
 
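Taken together with the first hunk, this replaces a bare magic number with the module-level constant ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS and applies the override in one place. A minimal standalone sketch of the pattern (the --eclair_radio flag and the constant name come from this file; the argparse wiring and the SimpleNamespace stand-in for decoder_config are illustrative only):

    import argparse
    from types import SimpleNamespace

    # Named once at module level, so the value cannot drift between call sites.
    ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS = 20000

    parser = argparse.ArgumentParser()
    parser.add_argument("--eclair_radio", action="store_true")
    args = parser.parse_args(["--eclair_radio"])  # simulate passing the flag

    # In the real file, decoder_config comes from parse_bart_config_by_component().
    decoder_config = SimpleNamespace(n_positions=1024)
    if args.eclair_radio:
        decoder_config.n_positions = ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS

    print(decoder_config.n_positions)  # -> 20000
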
@@ -1599,7 +1606,7 @@ def get_processor():
                 "model.safetensors"),
             strict=False)
         model.decoder.model.decoder.embed_positions = MBartLearnedPositionalEmbedding(
-            20_000, d_model)
+            ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS, d_model)
         model.decoder.model.decoder.embed_positions.weight.data.zero_()
         model.decoder.model.decoder.embed_positions.weight.requires_grad_(
             True)
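
This hunk only swaps the 20_000 literal for the new constant; the surrounding logic is untouched: install a larger learned positional-embedding table, zero-initialize it, and mark it trainable. A minimal sketch of that pattern using a plain torch.nn.Embedding (MBartLearnedPositionalEmbedding is the transformers class the real code uses; d_model here is a stand-in value):

    import torch

    ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS = 20000
    d_model = 1024  # stand-in; the real value comes from the loaded model config

    # Larger table, zero-initialized, left trainable so it can be fine-tuned.
    embed_positions = torch.nn.Embedding(ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS, d_model)
    embed_positions.weight.data.zero_()
    embed_positions.weight.requires_grad_(True)

    print(embed_positions.weight.shape)  # -> torch.Size([20000, 1024])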
@@ -1717,26 +1724,16 @@ def convert_checkpoint(args):
     encoder_convert_args = dict(params=model.state_dict(),
                                 component="encoder")
     tllm_decoder_config = {
-        'architecture':
-        "DecoderModel",
-        'dtype':
-        args.dtype,
-        'logits_dtype':
-        decoder_config.logits_dtype,
-        'num_hidden_layers':
-        decoder_config.n_layer,
-        'num_attention_heads':
-        decoder_config.n_head,
-        'hidden_size':
-        decoder_config.hidden_size,
-        'norm_epsilon':
-        decoder_config.layernorm_eps,
-        'vocab_size':
-        decoder_config.vocab_size,
-        'position_embedding_type':
-        decoder_config.position_embedding_type,
-        'hidden_act':
-        decoder_config.hidden_act,
+        'architecture': "DecoderModel",
+        'dtype': args.dtype,
+        'logits_dtype': decoder_config.logits_dtype,
+        'num_hidden_layers': decoder_config.n_layer,
+        'num_attention_heads': decoder_config.n_head,
+        'hidden_size': decoder_config.hidden_size,
+        'norm_epsilon': decoder_config.layernorm_eps,
+        'vocab_size': decoder_config.vocab_size,
+        'position_embedding_type': decoder_config.position_embedding_type,
+        'hidden_act': decoder_config.hidden_act,
         'quantization': {
             'quant_algo': quant_algo,
             'kv_cache_quant_algo': kv_cache_quant_algo,
@@ -1746,64 +1743,35 @@
             'tp_size': args.tp_size,
             'pp_size': args.pp_size,
         },
-        'use_parallel_embedding':
-        args.use_parallel_embedding,
-        'embedding_sharding_dim':
-        args.embedding_sharding_dim,
-        'max_position_embeddings':
-        decoder_config.n_positions if not args.eclair_radio else 20000,
-        'head_size':
-        decoder_config.head_size,
-        'has_position_embedding':
-        decoder_config.has_position_embedding,
-        'layernorm_type':
-        decoder_config.layernorm_type,
-        'has_attention_qkvo_bias':
-        decoder_config.has_attention_qkvo_bias,
-        'has_mlp_bias':
-        decoder_config.has_mlp_bias,
-        'has_model_final_layernorm':
-        decoder_config.has_model_final_layernorm,
-        'has_embedding_layernorm':
-        decoder_config.has_embedding_layernorm,
-        'has_embedding_scale':
-        decoder_config.has_embedding_scale,
-        'intermediate_size':
-        decoder_config.ffn_hidden_size,
-        'q_scaling':
-        decoder_config.q_scaling,
-        'layernorm_position':
-        decoder_config.layernorm_position,
-        'mlp_type':
-        decoder_config.mlp_type,
-        'relative_attention':
-        decoder_config.relative_attention,
-        'max_distance':
-        decoder_config.max_distance,
-        'num_buckets':
-        decoder_config.num_buckets,
-        'model_type':
-        decoder_config.model_type,
-        'rescale_before_lm_head':
-        decoder_config.rescale_before_lm_head,
-        'encoder_hidden_size':
-        decoder_config.encoder_hidden_size,
-        'encoder_num_heads':
-        decoder_config.encoder_num_heads,
-        'encoder_head_size':
-        decoder_config.encoder_head_size,
-        'skip_cross_kv':
-        args.skip_cross_kv,
-        'use_implicit_relative_attention':
-        args.use_implicit_relative_attention,
-        'decoder_start_token_id':
-        decoder_config.decoder_start_token_id,
-        'eos_token_id':
-        decoder_config.eos_token_id,
-        'bos_token_id':
-        decoder_config.bos_token_id,
-        'pad_token_id':
-        decoder_config.pad_token_id,
+        'use_parallel_embedding': args.use_parallel_embedding,
+        'embedding_sharding_dim': args.embedding_sharding_dim,
+        'max_position_embeddings': decoder_config.n_positions,
+        'head_size': decoder_config.head_size,
+        'has_position_embedding': decoder_config.has_position_embedding,
+        'layernorm_type': decoder_config.layernorm_type,
+        'has_attention_qkvo_bias': decoder_config.has_attention_qkvo_bias,
+        'has_mlp_bias': decoder_config.has_mlp_bias,
+        'has_model_final_layernorm': decoder_config.has_model_final_layernorm,
+        'has_embedding_layernorm': decoder_config.has_embedding_layernorm,
+        'has_embedding_scale': decoder_config.has_embedding_scale,
+        'intermediate_size': decoder_config.ffn_hidden_size,
+        'q_scaling': decoder_config.q_scaling,
+        'layernorm_position': decoder_config.layernorm_position,
+        'mlp_type': decoder_config.mlp_type,
+        'relative_attention': decoder_config.relative_attention,
+        'max_distance': decoder_config.max_distance,
+        'num_buckets': decoder_config.num_buckets,
+        'model_type': decoder_config.model_type,
+        'rescale_before_lm_head': decoder_config.rescale_before_lm_head,
+        'encoder_hidden_size': decoder_config.encoder_hidden_size,
+        'encoder_num_heads': decoder_config.encoder_num_heads,
+        'encoder_head_size': decoder_config.encoder_head_size,
+        'skip_cross_kv': args.skip_cross_kv,
+        'use_implicit_relative_attention': args.use_implicit_relative_attention,
+        'decoder_start_token_id': decoder_config.decoder_start_token_id,
+        'eos_token_id': decoder_config.eos_token_id,
+        'bos_token_id': decoder_config.bos_token_id,
+        'pad_token_id': decoder_config.pad_token_id,
     }
     for additional_setting in additional_settings:
         if hasattr(decoder_config, additional_setting):
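
The one behavioral-looking line in this hunk is 'max_position_embeddings', which now reads decoder_config.n_positions directly: the eclair_radio conditional moved upstream into parse_bart_config_by_component (second hunk), so the dict literal no longer needs its own branch. A small sketch of why the two forms agree (SimpleNamespace stand-ins; values illustrative):

    from types import SimpleNamespace

    ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS = 20000
    args = SimpleNamespace(eclair_radio=True)
    decoder_config = SimpleNamespace(n_positions=1024)

    # Old form: the branch lived inside the dict literal.
    old_value = decoder_config.n_positions if not args.eclair_radio else 20000

    # New form: n_positions is overridden once, upstream, then read directly.
    if args.eclair_radio:
        decoder_config.n_positions = ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS
    new_value = decoder_config.n_positions

    assert old_value == new_value == 20000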
