@@ -1098,6 +1098,102 @@ def check_indicator_and_length(
1098
1098
# Google Gemma 3
1099
1099
##################
1100
1100
gemma3 = [
1101
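+ # https://huggingface.co/google/gemma-3-1b-it/blob/main/config.json
+ # NOTE(review): intermediate_size = 21504 below is larger than in the 4b (10240) and 12b (15360) entries even though n_embd is smaller (1152); confirm it, and block_size, against the linked config.json.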
1102
+ dict (
1103
+ name = "Gemma-3-1b-it" ,
1104
+ hf_config = dict (org = "google" , name = "gemma-3-1b-it" ),
1105
+ scale_embeddings = True ,
1106
+ attention_scores_scalar = 256 ,
1107
+ vocab_size = 262144 ,
1108
+ block_size = 131072 ,
1109
+ sliding_window_size = 512 ,
1110
+ # 5 local layers for every global layer: every 6th entry is 0 (global), all others are 1 (local)
1111
+ sliding_window_indices = [0 if (i + 1 ) % 6 == 0 else 1 for i in range (26 )],
1112
+ intermediate_size = 21504 ,
1113
+ n_embd = 1152 ,
1114
+ n_layer = 26 ,
1115
+ n_head = 4 ,
1116
+ n_query_groups = 1 ,
1117
+ head_size = 256 ,
1118
+ rotary_percentage = 1.0 ,
1119
+ rope_adjustments = None ,
1120
+ parallel_residual = False ,
1121
+ bias = False ,
1122
+ norm_class_name = "RMSNorm" ,
1123
+ mlp_class_name = "GemmaMLP" ,
1124
+ gelu_approximate = "tanh" ,
1125
+ post_attention_norm = True ,
1126
+ post_mlp_norm = True ,
1127
+ norm_qk = True ,
1128
+ rope_base = 1000000 ,
1129
+ rope_local_base_freq = 10000 ,
1130
+ # 5 local layers for every global layer: every 6th entry is 0 (global), all others are 1 (local)
1131
+ rope_indices = [0 if (i + 1 ) % 6 == 0 else 1 for i in range (26 )],
1132
+ ),
1133
+ # https://huggingface.co/google/gemma-3-4b-it/blob/main/config.json
1134
+ dict (
1135
+ name = "Gemma-3-4b-it" ,
1136
+ hf_config = dict (org = "google" , name = "gemma-3-4b-it" ),
1137
+ scale_embeddings = True ,
1138
+ attention_scores_scalar = 256 ,
1139
+ vocab_size = 262144 ,
1140
+ block_size = 131072 ,
1141
+ sliding_window_size = 1024 ,
1142
+ # 5 local layers for every global layer: every 6th entry is 0 (global), all others are 1 (local)
1143
+ sliding_window_indices = [0 if (i + 1 ) % 6 == 0 else 1 for i in range (34 )],
1144
+ intermediate_size = 10240 ,
1145
+ n_embd = 2560 ,
1146
+ n_layer = 34 ,
1147
+ n_head = 8 ,
1148
+ n_query_groups = 4 ,
1149
+ head_size = 256 ,
1150
+ rotary_percentage = 1.0 ,
1151
+ rope_adjustments = dict (factor = 8.0 ),
1152
+ parallel_residual = False ,
1153
+ bias = False ,
1154
+ norm_class_name = "RMSNorm" ,
1155
+ mlp_class_name = "GemmaMLP" ,
1156
+ gelu_approximate = "tanh" ,
1157
+ post_attention_norm = True ,
1158
+ post_mlp_norm = True ,
1159
+ norm_qk = True ,
1160
+ rope_base = 1000000 ,
1161
+ rope_local_base_freq = 10000 ,
1162
+ # 5 local layers for every global layer: every 6th entry is 0 (global), all others are 1 (local)
1163
+ rope_indices = [0 if (i + 1 ) % 6 == 0 else 1 for i in range (34 )],
1164
+ ),
1165
+ # https://huggingface.co/google/gemma-3-12b-it/blob/main/config.json
1166
+ dict (
1167
+ name = "Gemma-3-12b-it" ,
1168
+ hf_config = dict (org = "google" , name = "gemma-3-12b-it" ),
1169
+ scale_embeddings = True ,
1170
+ attention_scores_scalar = 256 ,
1171
+ vocab_size = 262144 ,
1172
+ block_size = 131072 ,
1173
+ sliding_window_size = 1024 ,
1174
+ # 5 local layers for every global layer: every 6th entry is 0 (global), all others are 1 (local)
1175
+ sliding_window_indices = [0 if (i + 1 ) % 6 == 0 else 1 for i in range (48 )],
1176
+ intermediate_size = 15360 ,
1177
+ n_embd = 3840 ,
1178
+ n_layer = 48 ,
1179
+ n_head = 16 ,
1180
+ n_query_groups = 8 ,
1181
+ head_size = 256 ,
1182
+ rotary_percentage = 1.0 ,
1183
+ rope_adjustments = dict (factor = 8.0 ),
1184
+ parallel_residual = False ,
1185
+ bias = False ,
1186
+ norm_class_name = "RMSNorm" ,
1187
+ mlp_class_name = "GemmaMLP" ,
1188
+ gelu_approximate = "tanh" ,
1189
+ post_attention_norm = True ,
1190
+ post_mlp_norm = True ,
1191
+ norm_qk = True ,
1192
+ rope_base = 1000000 ,
1193
+ rope_local_base_freq = 10000 ,
1194
+ # 5 local layers for every global layer: every 6th entry is 0 (global), all others are 1 (local)
1195
+ rope_indices = [0 if (i + 1 ) % 6 == 0 else 1 for i in range (48 )],
1196
+ ),
1101
1197
# https://huggingface.co/google/gemma-3-27b-it/blob/main/config.json
1102
1198
dict (
1103
1199
name = "Gemma-3-27b-it" ,
0 commit comments