 layernorm_position_map = {i.name: i.value for i in LayerNormPositionType}
 mlp_type_map = {i.name: i.value for i in MLPType}
 
+# Constants for specific model configurations
+ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS = 20000
+
 
 def copy_args_to_component_config(component_config, args):
     for arg in vars(args):
@@ -773,6 +776,10 @@ def parse_bart_config_by_component(config, component, args):
     encoder_config = parse_bart_config_by_component(config, "encoder", args)
     decoder_config = parse_bart_config_by_component(config, "decoder", args)
 
+    # Override n_positions for eclair_radio model
+    if args.eclair_radio:
+        decoder_config.n_positions = ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS
+
     return encoder_config, decoder_config
 
 
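A minimal, self-contained sketch of the pattern this hunk introduces: a named constant replacing the magic number `20000`, applied once via a CLI flag that overrides the parsed decoder config. The `--eclair_radio` flag name comes from the diff; the `n_positions` default and the `SimpleNamespace` stand-in for `parse_bart_config_by_component` are assumptions for illustration.

```python
import argparse
from types import SimpleNamespace

ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS = 20000  # named constant from the diff

parser = argparse.ArgumentParser()
parser.add_argument("--eclair_radio", action="store_true")  # assumed to be a bool switch
args = parser.parse_args(["--eclair_radio"])

# Hypothetical stand-in for parse_bart_config_by_component(config, "decoder", args).
decoder_config = SimpleNamespace(n_positions=1024)

# Override n_positions once, up front, so every downstream consumer (the
# TRT-LLM config dict, the positional-embedding resize) sees the same value.
if args.eclair_radio:
    decoder_config.n_positions = ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS
```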
@@ -1599,7 +1606,7 @@ def get_processor():
             "model.safetensors"),
         strict=False)
     model.decoder.model.decoder.embed_positions = MBartLearnedPositionalEmbedding(
-        20_000, d_model)
+        ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS, d_model)
     model.decoder.model.decoder.embed_positions.weight.data.zero_()
     model.decoder.model.decoder.embed_positions.weight.requires_grad_(
         True)
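For context, a hedged sketch of what the rewritten call does: it replaces the decoder's learned positional table with a larger, zero-initialized, trainable one. `MBartLearnedPositionalEmbedding(num_embeddings, embedding_dim)` is the `transformers` class used in the hunk; `d_model` is a placeholder here, whereas the script takes it from the model config.

```python
from transformers.models.mbart.modeling_mbart import MBartLearnedPositionalEmbedding

ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS = 20000
d_model = 1024  # placeholder; the real script reads this from the model config

# Fresh, larger positional table: zero-initialized and trainable, as in the hunk.
embed_positions = MBartLearnedPositionalEmbedding(
    ECLAIR_RADIO_MAX_POSITION_EMBEDDINGS, d_model)
embed_positions.weight.data.zero_()
embed_positions.weight.requires_grad_(True)
```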
@@ -1717,26 +1724,16 @@ def convert_checkpoint(args):
     encoder_convert_args = dict(params=model.state_dict(),
                                 component="encoder")
     tllm_decoder_config = {
-        'architecture':
-        "DecoderModel",
-        'dtype':
-        args.dtype,
-        'logits_dtype':
-        decoder_config.logits_dtype,
-        'num_hidden_layers':
-        decoder_config.n_layer,
-        'num_attention_heads':
-        decoder_config.n_head,
-        'hidden_size':
-        decoder_config.hidden_size,
-        'norm_epsilon':
-        decoder_config.layernorm_eps,
-        'vocab_size':
-        decoder_config.vocab_size,
-        'position_embedding_type':
-        decoder_config.position_embedding_type,
-        'hidden_act':
-        decoder_config.hidden_act,
+        'architecture': "DecoderModel",
+        'dtype': args.dtype,
+        'logits_dtype': decoder_config.logits_dtype,
+        'num_hidden_layers': decoder_config.n_layer,
+        'num_attention_heads': decoder_config.n_head,
+        'hidden_size': decoder_config.hidden_size,
+        'norm_epsilon': decoder_config.layernorm_eps,
+        'vocab_size': decoder_config.vocab_size,
+        'position_embedding_type': decoder_config.position_embedding_type,
+        'hidden_act': decoder_config.hidden_act,
         'quantization': {
             'quant_algo': quant_algo,
             'kv_cache_quant_algo': kv_cache_quant_algo,
@@ -1746,64 +1743,35 @@ def convert_checkpoint(args):
             'tp_size': args.tp_size,
             'pp_size': args.pp_size,
         },
-        'use_parallel_embedding':
-        args.use_parallel_embedding,
-        'embedding_sharding_dim':
-        args.embedding_sharding_dim,
-        'max_position_embeddings':
-        decoder_config.n_positions if not args.eclair_radio else 20000,
-        'head_size':
-        decoder_config.head_size,
-        'has_position_embedding':
-        decoder_config.has_position_embedding,
-        'layernorm_type':
-        decoder_config.layernorm_type,
-        'has_attention_qkvo_bias':
-        decoder_config.has_attention_qkvo_bias,
-        'has_mlp_bias':
-        decoder_config.has_mlp_bias,
-        'has_model_final_layernorm':
-        decoder_config.has_model_final_layernorm,
-        'has_embedding_layernorm':
-        decoder_config.has_embedding_layernorm,
-        'has_embedding_scale':
-        decoder_config.has_embedding_scale,
-        'intermediate_size':
-        decoder_config.ffn_hidden_size,
-        'q_scaling':
-        decoder_config.q_scaling,
-        'layernorm_position':
-        decoder_config.layernorm_position,
-        'mlp_type':
-        decoder_config.mlp_type,
-        'relative_attention':
-        decoder_config.relative_attention,
-        'max_distance':
-        decoder_config.max_distance,
-        'num_buckets':
-        decoder_config.num_buckets,
-        'model_type':
-        decoder_config.model_type,
-        'rescale_before_lm_head':
-        decoder_config.rescale_before_lm_head,
-        'encoder_hidden_size':
-        decoder_config.encoder_hidden_size,
-        'encoder_num_heads':
-        decoder_config.encoder_num_heads,
-        'encoder_head_size':
-        decoder_config.encoder_head_size,
-        'skip_cross_kv':
-        args.skip_cross_kv,
-        'use_implicit_relative_attention':
-        args.use_implicit_relative_attention,
-        'decoder_start_token_id':
-        decoder_config.decoder_start_token_id,
-        'eos_token_id':
-        decoder_config.eos_token_id,
-        'bos_token_id':
-        decoder_config.bos_token_id,
-        'pad_token_id':
-        decoder_config.pad_token_id,
+        'use_parallel_embedding': args.use_parallel_embedding,
+        'embedding_sharding_dim': args.embedding_sharding_dim,
+        'max_position_embeddings': decoder_config.n_positions,
+        'head_size': decoder_config.head_size,
+        'has_position_embedding': decoder_config.has_position_embedding,
+        'layernorm_type': decoder_config.layernorm_type,
+        'has_attention_qkvo_bias': decoder_config.has_attention_qkvo_bias,
+        'has_mlp_bias': decoder_config.has_mlp_bias,
+        'has_model_final_layernorm': decoder_config.has_model_final_layernorm,
+        'has_embedding_layernorm': decoder_config.has_embedding_layernorm,
+        'has_embedding_scale': decoder_config.has_embedding_scale,
+        'intermediate_size': decoder_config.ffn_hidden_size,
+        'q_scaling': decoder_config.q_scaling,
+        'layernorm_position': decoder_config.layernorm_position,
+        'mlp_type': decoder_config.mlp_type,
+        'relative_attention': decoder_config.relative_attention,
+        'max_distance': decoder_config.max_distance,
+        'num_buckets': decoder_config.num_buckets,
+        'model_type': decoder_config.model_type,
+        'rescale_before_lm_head': decoder_config.rescale_before_lm_head,
+        'encoder_hidden_size': decoder_config.encoder_hidden_size,
+        'encoder_num_heads': decoder_config.encoder_num_heads,
+        'encoder_head_size': decoder_config.encoder_head_size,
+        'skip_cross_kv': args.skip_cross_kv,
+        'use_implicit_relative_attention': args.use_implicit_relative_attention,
+        'decoder_start_token_id': decoder_config.decoder_start_token_id,
+        'eos_token_id': decoder_config.eos_token_id,
+        'bos_token_id': decoder_config.bos_token_id,
+        'pad_token_id': decoder_config.pad_token_id,
     }
     for additional_setting in additional_settings:
         if hasattr(decoder_config, additional_setting):
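The hunk ends inside the loop that folds optional settings into `tllm_decoder_config`; its body lies outside the diff. A hedged sketch of the usual copy-if-present pattern, with hypothetical attribute names and stand-in objects:

```python
from types import SimpleNamespace

# Hypothetical stand-ins for the script's objects and setting names.
decoder_config = SimpleNamespace(type_vocab_size=2)
tllm_decoder_config = {}
additional_settings = ["residual_scaling", "type_vocab_size"]

# Copy each optional setting only if the parsed config actually defines it.
for additional_setting in additional_settings:
    if hasattr(decoder_config, additional_setting):
        tllm_decoder_config[additional_setting] = getattr(
            decoder_config, additional_setting)

print(tllm_decoder_config)  # -> {'type_vocab_size': 2}
```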