195 | 195 | )
196 | 196 |
197 | 197 |
| 198 | +def build_attention_config(
| 199 | + num_heads,
| 200 | + dim,
| 201 | + num_query_groups,
| 202 | + rotary_percentage=0.0,
| 203 | + qkv_transpose_before_split=True,
| 204 | + qkv_use_bias=False,
| 205 | + output_proj_use_bias=True,
| 206 | + enable_kv_cache=False,
| 207 | + qkv_fused_interleaved=False,
| 208 | +):
| 209 | +
| 210 | + return layers_cfg.AttentionConfig(
| 211 | + num_heads=num_heads,
| 212 | + head_dim=dim // num_heads,
| 213 | + num_query_groups=num_query_groups,
| 214 | + rotary_percentage=rotary_percentage,
| 215 | + qkv_transpose_before_split=qkv_transpose_before_split,
| 216 | + qkv_use_bias=qkv_use_bias,
| 217 | + output_proj_use_bias=output_proj_use_bias,
| 218 | + enable_kv_cache=enable_kv_cache,
| 219 | + qkv_fused_interleaved=qkv_fused_interleaved,
| 220 | + )
| 221 | +
| 222 | +
198 | 223 | class TimeEmbedding(nn.Module):
199 | 224 |
200 | 225 | def __init__(self, in_dim, out_dim):
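Note on the new helper: it centralizes the AttentionConfig defaults that the later hunks previously spelled out inline, and it derives head_dim from the block width. Below is a minimal sketch of a call, not taken from this diff: the values (8 heads, a 320-channel block) are illustrative stand-ins for config.transformer_num_attention_heads and the current output_channel, and it assumes layers_cfg.AttentionConfig exposes its fields as attributes.

# Illustrative values only, not from the diff.
attn_cfg = build_attention_config(
    num_heads=8,          # stands in for config.transformer_num_attention_heads
    dim=320,              # stands in for output_channel of the current block
    num_query_groups=8,   # plain multi-head attention: one query group per head
)
# head_dim is computed inside the helper as dim // num_heads.
assert attn_cfg.head_dim == 320 // 8  # == 40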
@@ -267,17 +292,6 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
267 | 292 | config.in_channels, block_out_channels[0], kernel_size=3, padding=1
268 | 293 | )
269 | 294 |
270 | | - attention_config = layers_cfg.AttentionConfig(
271 | | - num_heads=config.transformer_num_attention_heads,
272 | | - num_query_groups=config.transformer_num_attention_heads,
273 | | - rotary_percentage=0.0,
274 | | - qkv_transpose_before_split=True,
275 | | - qkv_use_bias=False,
276 | | - output_proj_use_bias=True,
277 | | - enable_kv_cache=False,
278 | | - qkv_fused_interleaved=False,
279 | | - )
280 | | -
281 | 295 | # Down encoders.
282 | 296 | down_encoders = []
283 | 297 | output_channel = block_out_channels[0]
@@ -312,15 +326,23 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
312 | 326 | dim=output_channel,
313 | 327 | attention_batch_size=config.transformer_batch_size,
314 | 328 | normalization_config=config.transformer_norm_config,
315 | | - attention_config=attention_config,
| 329 | + attention_config=build_attention_config(
| 330 | + num_heads=config.transformer_num_attention_heads,
| 331 | + dim=output_channel,
| 332 | + num_query_groups=config.transformer_num_attention_heads,
| 333 | + ),
316 | 334 | enable_hlfb=False,
317 | 335 | ),
318 | 336 | cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
319 | 337 | query_dim=output_channel,
320 | 338 | cross_dim=config.transformer_cross_attention_dim,
321 | 339 | attention_batch_size=config.transformer_batch_size,
322 | 340 | normalization_config=config.transformer_norm_config,
323 | | - attention_config=attention_config,
| 341 | + attention_config=build_attention_config(
| 342 | + num_heads=config.transformer_num_attention_heads,
| 343 | + dim=output_channel,
| 344 | + num_query_groups=config.transformer_num_attention_heads,
| 345 | + ),
324 | 346 | enable_hlfb=False,
325 | 347 | ),
326 | 348 | pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
@@ -374,15 +396,23 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
374 | 396 | dim=mid_block_channels,
375 | 397 | attention_batch_size=config.transformer_batch_size,
376 | 398 | normalization_config=config.transformer_norm_config,
377 | | - attention_config=attention_config,
| 399 | + attention_config=build_attention_config(
| 400 | + num_heads=config.transformer_num_attention_heads,
| 401 | + dim=mid_block_channels,
| 402 | + num_query_groups=config.transformer_num_attention_heads,
| 403 | + ),
378 | 404 | enable_hlfb=False,
379 | 405 | ),
380 | 406 | cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
381 | 407 | query_dim=mid_block_channels,
382 | 408 | cross_dim=config.transformer_cross_attention_dim,
383 | 409 | attention_batch_size=config.transformer_batch_size,
384 | 410 | normalization_config=config.transformer_norm_config,
385 | | - attention_config=attention_config,
| 411 | + attention_config=build_attention_config(
| 412 | + num_heads=config.transformer_num_attention_heads,
| 413 | + dim=mid_block_channels,
| 414 | + num_query_groups=config.transformer_num_attention_heads,
| 415 | + ),
386 | 416 | enable_hlfb=False,
387 | 417 | ),
388 | 418 | pre_conv_normalization_config=config.transformer_pre_conv_norm_config,
@@ -437,15 +467,23 @@ def __init__(self, config: unet_cfg.DiffusionModelConfig):
437 | 467 | dim=output_channel,
438 | 468 | attention_batch_size=config.transformer_batch_size,
439 | 469 | normalization_config=config.transformer_norm_config,
440 | | - attention_config=attention_config,
| 470 | + attention_config=build_attention_config(
| 471 | + num_heads=config.transformer_num_attention_heads,
| 472 | + dim=output_channel,
| 473 | + num_query_groups=config.transformer_num_attention_heads,
| 474 | + ),
441 | 475 | enable_hlfb=False,
442 | 476 | ),
443 | 477 | cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
444 | 478 | query_dim=output_channel,
445 | 479 | cross_dim=config.transformer_cross_attention_dim,
446 | 480 | attention_batch_size=config.transformer_batch_size,
447 | 481 | normalization_config=config.transformer_norm_config,
448 | | - attention_config=attention_config,
| 482 | + attention_config=build_attention_config(
| 483 | + num_heads=config.transformer_num_attention_heads,
| 484 | + dim=output_channel,
| 485 | + num_query_groups=config.transformer_num_attention_heads,
| 486 | + ),
449 | 487 | enable_hlfb=False,
450 | 488 | ),
451 | 489 | pre_conv_normalization_config=config.transformer_pre_conv_norm_config,