
Commit aedd3a9

Merge branch 'main' into kvcache3
2 parents: 5fb7ba9 + db6b08d

File tree: 3 files changed, +101 −5 lines


Diff for: .github/workflows/cpu-tests.yml

+4 −4

@@ -88,11 +88,11 @@ jobs:
         continue-on-error: true
         with:
           path: .cache-HF
-          key: hf-cache-${{ runner.os }}-${{ matrix.python-version }}
+          key: hf-cache_${{ runner.os }}-py${{ matrix.python-version }}
           restore-keys: |
-            hf-cache-${{ runner.os }}-${{ matrix.python-version }}
-            hf-cache-${{ runner.os }}-
-            hf-cache-
+            hf-cache_${{ runner.os }}-py${{ matrix.python-version }}
+            hf-cache_${{ runner.os }}-
+            hf-cache_

       - name: Install dependencies
         run: |
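
The three restore-keys form fallback tiers: actions/cache tries the exact key first, then each restore key in order as a prefix match against previously saved caches, so an OS- and Python-matched cache beats a same-OS cache, which beats any HF cache. A minimal Python sketch of that fallback under assumed names (resolve_cache and the sample keys are illustrative, not the action's real implementation, which also prefers the newest matching cache):

# Sketch only: prefix-based fallback through restore-keys.
def resolve_cache(key, restore_keys, saved_keys):
    if key in saved_keys:
        return key  # exact hit
    for prefix in restore_keys:
        for saved in saved_keys:
            if saved.startswith(prefix):
                return saved  # partial hit: restore, then re-save under `key`
    return None  # full miss: the step rebuilds the cache from scratch

saved = ["hf-cache_Linux-py3.9", "hf-cache_macOS-py3.10"]
restore = ["hf-cache_Linux-py3.11", "hf-cache_Linux-", "hf-cache_"]
print(resolve_cache("hf-cache_Linux-py3.11", restore, saved))
# -> hf-cache_Linux-py3.9 (fell through to the same-OS tier)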

Diff for: litgpt/config.py

+96 −0

@@ -1098,6 +1098,102 @@ def check_indicator_and_length(
 # Google Gemma 3
 ##################
 gemma3 = [
+    # https://huggingface.co/google/gemma-3-1b-it/blob/main/config.json
+    dict(
+        name="Gemma-3-1b-it",
+        hf_config=dict(org="google", name="gemma-3-1b-it"),
+        scale_embeddings=True,
+        attention_scores_scalar=256,
+        vocab_size=262144,
+        block_size=131072,
+        sliding_window_size=512,
+        # 5 local layers for every global layer
+        sliding_window_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(26)],
+        intermediate_size=21504,
+        n_embd=1152,
+        n_layer=26,
+        n_head=4,
+        n_query_groups=1,
+        head_size=256,
+        rotary_percentage=1.0,
+        rope_adjustments=None,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="GemmaMLP",
+        gelu_approximate="tanh",
+        post_attention_norm=True,
+        post_mlp_norm=True,
+        norm_qk=True,
+        rope_base=1000000,
+        rope_local_base_freq=10000,
+        # 5 local layers for every global layer
+        rope_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(26)],
+    ),
+    # https://huggingface.co/google/gemma-3-4b-it/blob/main/config.json
+    dict(
+        name="Gemma-3-4b-it",
+        hf_config=dict(org="google", name="gemma-3-4b-it"),
+        scale_embeddings=True,
+        attention_scores_scalar=256,
+        vocab_size=262144,
+        block_size=131072,
+        sliding_window_size=1024,
+        # 5 local layers for every global layer
+        sliding_window_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(34)],
+        intermediate_size=10240,
+        n_embd=2560,
+        n_layer=34,
+        n_head=8,
+        n_query_groups=4,
+        head_size=256,
+        rotary_percentage=1.0,
+        rope_adjustments=dict(factor=8.0),
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="GemmaMLP",
+        gelu_approximate="tanh",
+        post_attention_norm=True,
+        post_mlp_norm=True,
+        norm_qk=True,
+        rope_base=1000000,
+        rope_local_base_freq=10000,
+        # 5 local layers for every global layer
+        rope_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(34)],
+    ),
+    # https://huggingface.co/google/gemma-3-12b-it/blob/main/config.json
+    dict(
+        name="Gemma-3-12b-it",
+        hf_config=dict(org="google", name="gemma-3-12b-it"),
+        scale_embeddings=True,
+        attention_scores_scalar=256,
+        vocab_size=262144,
+        block_size=131072,
+        sliding_window_size=1024,
+        # 5 local layers for every global layer
+        sliding_window_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(48)],
+        intermediate_size=15360,
+        n_embd=3840,
+        n_layer=48,
+        n_head=16,
+        n_query_groups=8,
+        head_size=256,
+        rotary_percentage=1.0,
+        rope_adjustments=dict(factor=8.0),
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="GemmaMLP",
+        gelu_approximate="tanh",
+        post_attention_norm=True,
+        post_mlp_norm=True,
+        norm_qk=True,
+        rope_base=1000000,
+        rope_local_base_freq=10000,
+        # 5 local layers for every global layer
+        rope_indices=[0 if (i + 1) % 6 == 0 else 1 for i in range(48)],
+    ),
     # https://huggingface.co/google/gemma-3-27b-it/blob/main/config.json
     dict(
         name="Gemma-3-27b-it",

Diff for: tests/test_model.py

+1 −1

@@ -812,7 +812,7 @@ def test_against_original_gemma_2(model_name, device, dtype):


 @torch.inference_mode()
-@pytest.mark.parametrize("model_name", ["gemma-3-27b-it"])
+@pytest.mark.parametrize("model_name", ["gemma-3-1b-it", "gemma-3-4b-it", "gemma-3-12b-it", "gemma-3-27b-it"])
 @pytest.mark.parametrize(
     ("device", "dtype"),
     [
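
Stacked pytest.mark.parametrize decorators expand to their cross product, so this one-line change runs the test below the hunk against all four Gemma 3 sizes for every (device, dtype) pair. A self-contained sketch of the expansion (the test name, body, and the "float32" dtype string are placeholders, not the suite's real values):

import pytest

@pytest.mark.parametrize(
    "model_name",
    ["gemma-3-1b-it", "gemma-3-4b-it", "gemma-3-12b-it", "gemma-3-27b-it"],
)
@pytest.mark.parametrize(("device", "dtype"), [("cpu", "float32")])
def test_gemma3_matrix(model_name, device, dtype):
    # 4 model names x 1 (device, dtype) pair -> 4 collected test cases
    assert model_name.startswith("gemma-3")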
