
Commit 7300dd1

talumbau and mbrenon authored
Explicit head_dim configuration (google-ai-edge#120)
* Allow explicitly setting head_dim instead of deducing it. head_dim is also moved down to the attention config: this parameter is related to attention, so it has no reason to live under the main model config. This is also a requirement for the upcoming OpenELM models.
* Add head_dim to the T5 attention config.
* Stable Diffusion attention configs: use head_dim from AttentionConfig.
* Configure head_dim in generative/experimental.
* Fix attention configs for the SD loader utilities.

Co-authored-by: mbrenon <[email protected]>
1 parent 4b01f15 commit 7300dd1

22 files changed: +193 −84 lines changed
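
Note: in practice this means head_dim is now set explicitly on cfg.AttentionConfig rather than deduced from the model-level config, and callers read config.attn_config.head_dim instead of config.head_dim. A minimal sketch of the new pattern, using the Gemma 2B values from the diffs below (the `cfg` alias for the generative layers' model config module is assumed from the example files):

# Sketch only; assumes the examples' usual `cfg` alias for the layers config module.
attn_config = cfg.AttentionConfig(
    num_heads=8,
    head_dim=256,  # explicit now, and lives under the attention config
    num_query_groups=1,
    rotary_percentage=1.0,
)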

ai_edge_torch/generative/examples/experimental/gemma/gemma.py (+4, -1)

@@ -73,7 +73,9 @@ def __init__(self, config: cfg.ModelConfig):
     )
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -125,6 +127,7 @@ def forward(
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=8,
+      head_dim=256,
       num_query_groups=1,
       rotary_percentage=1.0,
   )

ai_edge_torch/generative/examples/experimental/phi/phi2.py (+4, -1)

@@ -68,7 +68,9 @@ def __init__(self, config: cfg.ModelConfig):
     )
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -118,6 +120,7 @@ def forward(
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=32,
+      head_dim=80,
       num_query_groups=32,
       rotary_percentage=0.4,
       qkv_use_bias=True,

ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py (+4, -1)

@@ -70,7 +70,9 @@ def __init__(self, config: cfg.ModelConfig):
     )
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -121,6 +123,7 @@ def forward(
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=32,
+      head_dim=64,
       num_query_groups=4,
       rotary_percentage=1.0,
   )

ai_edge_torch/generative/examples/gemma/gemma.py (+4, -1)

@@ -68,7 +68,9 @@ def __init__(self, config: cfg.ModelConfig):
     )
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -113,6 +115,7 @@ def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
 def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=8,
+      head_dim=256,
       num_query_groups=1,
       rotary_percentage=1.0,
   )

ai_edge_torch/generative/examples/phi2/phi2.py (+4, -1)

@@ -63,7 +63,9 @@ def __init__(self, config: cfg.ModelConfig):
     )
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -107,6 +109,7 @@ def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=32,
+      head_dim=80,
       num_query_groups=32,
       rotary_percentage=0.4,
       qkv_use_bias=True,
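
As a quick check of what the rewritten rope_cache line computes for the Phi-2 values above (not part of the diff itself):

rotary_percentage = 0.4  # from attn_config above
head_dim = 80            # from attn_config above
rope_dim = int(rotary_percentage * head_dim)  # 32, passed as `dim` to build_rope_cache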

ai_edge_torch/generative/examples/stable_diffusion/clip.py (+1)

@@ -92,6 +92,7 @@ def get_model_config() -> cfg.ModelConfig:
 
   attn_config = cfg.AttentionConfig(
       num_heads=num_heads,
+      head_dim=embedding_dim // num_heads,
       num_query_groups=num_query_groups,
       rotary_percentage=0.0,
       qkv_use_bias=True,

ai_edge_torch/generative/examples/stable_diffusion/decoder.py (+1)

@@ -288,6 +288,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
           normalization_config=norm_config,
           attention_config=layers_cfg.AttentionConfig(
               num_heads=1,
+              head_dim=block_out_channels[-1],
               num_query_groups=1,
               qkv_use_bias=True,
               output_proj_use_bias=True,

ai_edge_torch/generative/examples/stable_diffusion/diffusion.py (+55, -17)

@@ -195,6 +195,31 @@
 )
 
 
+def build_attention_config(
+    num_heads,
+    dim,
+    num_query_groups,
+    rotary_percentage=0.0,
+    qkv_transpose_before_split=True,
+    qkv_use_bias=False,
+    output_proj_use_bias=True,
+    enable_kv_cache=False,
+    qkv_fused_interleaved=False,
+):
+
+  return layers_cfg.AttentionConfig(
+      num_heads=num_heads,
+      head_dim=dim // num_heads,
+      num_query_groups=num_query_groups,
+      rotary_percentage=rotary_percentage,
+      qkv_transpose_before_split=qkv_transpose_before_split,
+      qkv_use_bias=qkv_use_bias,
+      output_proj_use_bias=output_proj_use_bias,
+      enable_kv_cache=enable_kv_cache,
+      qkv_fused_interleaved=qkv_fused_interleaved,
+  )
+
+
 class TimeEmbedding(nn.Module):
 
   def __init__(self, in_dim, out_dim):
@@ -267,17 +292,6 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
         config.in_channels, block_out_channels[0], kernel_size=3, padding=1
     )
 
-    attention_config = layers_cfg.AttentionConfig(
-        num_heads=config.transformer_num_attention_heads,
-        num_query_groups=config.transformer_num_attention_heads,
-        rotary_percentage=0.0,
-        qkv_transpose_before_split=True,
-        qkv_use_bias=False,
-        output_proj_use_bias=True,
-        enable_kv_cache=False,
-        qkv_fused_interleaved=False,
-    )
-
     # Down encoders.
     down_encoders = []
     output_channel = block_out_channels[0]
@@ -312,15 +326,23 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
             dim=output_channel,
             attention_batch_size=config.transformer_batch_size,
             normalization_config=config.transformer_norm_config,
-            attention_config=attention_config,
+            attention_config=build_attention_config(
+                num_heads=config.transformer_num_attention_heads,
+                dim=output_channel,
+                num_query_groups=config.transformer_num_attention_heads,
+            ),
             enable_hlfb=False,
         ),
         cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
             query_dim=output_channel,
             cross_dim=config.transformer_cross_attention_dim,
             attention_batch_size=config.transformer_batch_size,
             normalization_config=config.transformer_norm_config,
-            attention_config=attention_config,
+            attention_config=build_attention_config(
+                num_heads=config.transformer_num_attention_heads,
+                dim=output_channel,
+                num_query_groups=config.transformer_num_attention_heads,
+            ),
             enable_hlfb=False,
         ),
         pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
@@ -374,15 +396,23 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
             dim=mid_block_channels,
             attention_batch_size=config.transformer_batch_size,
             normalization_config=config.transformer_norm_config,
-            attention_config=attention_config,
+            attention_config=build_attention_config(
+                num_heads=config.transformer_num_attention_heads,
+                dim=mid_block_channels,
+                num_query_groups=config.transformer_num_attention_heads,
+            ),
             enable_hlfb=False,
         ),
         cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
             query_dim=mid_block_channels,
             cross_dim=config.transformer_cross_attention_dim,
             attention_batch_size=config.transformer_batch_size,
             normalization_config=config.transformer_norm_config,
-            attention_config=attention_config,
+            attention_config=build_attention_config(
+                num_heads=config.transformer_num_attention_heads,
+                dim=mid_block_channels,
+                num_query_groups=config.transformer_num_attention_heads,
+            ),
             enable_hlfb=False,
         ),
         pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
@@ -437,15 +467,23 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
             dim=output_channel,
             attention_batch_size=config.transformer_batch_size,
             normalization_config=config.transformer_norm_config,
-            attention_config=attention_config,
+            attention_config=build_attention_config(
+                num_heads=config.transformer_num_attention_heads,
+                dim=output_channel,
+                num_query_groups=config.transformer_num_attention_heads,
+            ),
             enable_hlfb=False,
         ),
         cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
             query_dim=output_channel,
             cross_dim=config.transformer_cross_attention_dim,
             attention_batch_size=config.transformer_batch_size,
             normalization_config=config.transformer_norm_config,
-            attention_config=attention_config,
+            attention_config=build_attention_config(
+                num_heads=config.transformer_num_attention_heads,
+                dim=output_channel,
+                num_query_groups=config.transformer_num_attention_heads,
+            ),
             enable_hlfb=False,
         ),
         pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
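
To make the new build_attention_config helper concrete, a hypothetical call with illustrative values (the real head counts and channel sizes come from unet_cfg.DiffusionModelConfig and are not fixed by this diff; this assumes the file's existing imports):

# Illustrative values only.
attn = build_attention_config(
    num_heads=8,
    dim=320,  # e.g. a block's output_channel
    num_query_groups=8,
)
# head_dim is now derived per call site as dim // num_heads, i.e. 320 // 8 == 40.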

ai_edge_torch/generative/examples/t5/t5.py (+1)

@@ -371,6 +371,7 @@ def forward(
 def get_model_config_t5() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=12,
+      head_dim=64,
       num_query_groups=12,
       qkv_use_bias=False,
       relative_attention_num_buckets=32,

ai_edge_torch/generative/examples/t5/t5_attention.py (+7, -7)

@@ -185,7 +185,7 @@ def forward(
     )  # batch size, sequence length, embedding dimensionality (n_embd)
     query_states = self.q_projection(x)
     query_states = query_states.reshape(
-        B, T, -1, self.head_dim
+        B, T, -1, self.config.head_dim
     )  # (B, T, nh_q, hs)
 
     if key_value_states is not None:
@@ -198,13 +198,13 @@ def forward(
       )  # batch size, sequence length, embedding dimensionality (n_embd)
       key_states = self.k_projection(key_value_states)
       value_states = self.v_projection(key_value_states)
-      key_states = key_states.reshape(kvB, kvT, -1, self.head_dim)
-      value_states = value_states.reshape(kvB, kvT, -1, self.head_dim)
+      key_states = key_states.reshape(kvB, kvT, -1, self.config.head_dim)
+      value_states = value_states.reshape(kvB, kvT, -1, self.config.head_dim)
     else:
       key_states = self.k_projection(x)
       value_states = self.v_projection(x)
-      key_states = key_states.reshape(B, T, -1, self.head_dim)
-      value_states = value_states.reshape(B, T, -1, self.head_dim)
+      key_states = key_states.reshape(B, T, -1, self.config.head_dim)
+      value_states = value_states.reshape(B, T, -1, self.config.head_dim)
 
     if key_value_states is None and self.kv_cache is not None:
       key_states, value_states = self.kv_cache.update_cache(
@@ -221,15 +221,15 @@ def forward(
           0
       )  # shape (1, num_heads, query_length, key_length)
     else:
-      # position_bias = torch.zeros(B, self.n_heads, T, self.head_dim, dtype=torch.float32)
+      # position_bias = torch.zeros(B, self.n_heads, T, self.config.head_dim, dtype=torch.float32)
       position_bias = torch.zeros_like(mask, dtype=torch.float32)
 
     mask = mask + position_bias
     y = self.sdpa_func(
         query_states,
         key_states,
         value_states,
-        self.head_dim,
+        self.config.head_dim,
         mask=mask,
         scale=1.0,
     )
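
A quick shape check of the reshape pattern above, using illustrative sizes rather than values from the diff:

import torch

B, T, num_heads, head_dim = 2, 16, 12, 64  # illustrative sizes
projected = torch.zeros(B, T, num_heads * head_dim)  # output of a q/k/v projection
states = projected.reshape(B, T, -1, head_dim)  # -> (B, T, num_heads, head_dim)
assert states.shape == (B, T, num_heads, head_dim)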

ai_edge_torch/generative/examples/test_models/toy_model.py (+4, -1)

@@ -43,7 +43,9 @@ def __init__(self, config: cfg.ModelConfig) -> None:
     )
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
-        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -72,6 +74,7 @@ def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
 def get_model_config() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=32,
+      head_dim=4,
       num_query_groups=4,
       rotary_percentage=1.0,
       enable_kv_cache=False,

ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py (+4, -2)

@@ -46,7 +46,9 @@ def __init__(self, config: cfg.ModelConfig) -> None:
     )
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
-        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -90,7 +92,7 @@ def _export_stablehlo_mlir(model, args):
 
 def get_model_config() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
-      num_heads=32, num_query_groups=4, rotary_percentage=1.0
+      num_heads=32, head_dim=4, num_query_groups=4, rotary_percentage=1.0
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,

ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py (+4, -2)

@@ -45,7 +45,9 @@ def __init__(self, config: cfg.ModelConfig) -> None:
     )
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.max_seq_len,
-        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -78,7 +80,7 @@ def _export_stablehlo_mlir(model, args):
 
 def get_model_config() -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
-      num_heads=32, num_query_groups=4, rotary_percentage=1.0
+      num_heads=32, head_dim=4, num_query_groups=4, rotary_percentage=1.0
   )
   ff_config = cfg.FeedForwardConfig(
       type=cfg.FeedForwardType.GATED,

ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py (+4, -1)

@@ -64,7 +64,9 @@ def __init__(self, config: cfg.ModelConfig):
     )
     self.rope_cache = attn_utils.build_rope_cache(
         size=config.kv_cache_max,
-        dim=int(config.attn_config.rotary_percentage * config.head_dim),
+        dim=int(
+            config.attn_config.rotary_percentage * config.attn_config.head_dim
+        ),
         base=10_000,
         condense_ratio=1,
         dtype=torch.float32,
@@ -109,6 +111,7 @@ def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
 def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   attn_config = cfg.AttentionConfig(
       num_heads=32,
+      head_dim=64,
       num_query_groups=4,
       rotary_percentage=1.0,
   )

ai_edge_torch/generative/fx_passes/test/test_remove_sdpa_zero_mask_pass.py (+1)

@@ -99,6 +99,7 @@ def get_model_config() -> unet_cfg.AttentionBlock2DConfig:
       normalization_config=norm_config,
       attention_config=layers_cfg.AttentionConfig(
           num_heads=1,
+          head_dim=block_out_channels[-1],
           num_query_groups=1,
           qkv_use_bias=True,
           output_proj_use_bias=True,
