@@ -242,97 +242,85 @@ def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300,
 
     def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
         freqs_cis = []
-        # Use float32 for MPS compatibility
-        dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
         for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
-            emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=dtype)
+            emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=freqs_dtype)
             freqs_cis.append(emb)
         return freqs_cis
 
     def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
+        device = ids.device
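+        # MPS compatibility: perform the gather on CPU, then move the result back to the original device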
+        if ids.device.type == "mps":
+            ids = ids.to("cpu")
+
         result = []
         for i in range(len(self.axes_dim)):
             freqs = self.freqs_cis[i].to(ids.device)
             index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
             result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
-        return torch.cat(result, dim=-1)
+        return torch.cat(result, dim=-1).to(device)
 
     def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
-        batch_size = len(hidden_states)
-        p_h = p_w = self.patch_size
-        device = hidden_states[0].device
+        batch_size, channels, height, width = hidden_states.shape
+        p = self.patch_size
+        post_patch_height, post_patch_width = height // p, width // p
+        image_seq_len = post_patch_height * post_patch_width
+        device = hidden_states.device
 
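+        # padded caption length shared across the batch vs. per-sample effective caption lengths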
+        encoder_seq_len = attention_mask.shape[1]
         l_effective_cap_len = attention_mask.sum(dim=1).tolist()
-        # TODO: this should probably be refactored because all subtensors of hidden_states will be of same shape
-        img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
-        l_effective_img_len = [(H // p_h) * (W // p_w) for (H, W) in img_sizes]
-
-        max_seq_len = max((cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)))
-        max_img_len = max(l_effective_img_len)
+        seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
+        max_seq_len = max(seq_lengths)
 
+        # Create position IDs
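+        # axis 0: caption token index (held at cap_seq_len for image tokens); axes 1 and 2: patch row and column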
         position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
 
-        for i in range(batch_size):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
-            H, W = img_sizes[i]
-            H_tokens, W_tokens = H // p_h, W // p_w
-            assert H_tokens * W_tokens == img_len
+        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+            # add caption position ids
+            position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device)
+            position_ids[i, cap_seq_len:seq_len, 0] = cap_seq_len
 
-            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
-            position_ids[i, cap_len : cap_len + img_len, 0] = cap_len
+            # add image position ids
             row_ids = (
-                torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
+                torch.arange(post_patch_height, dtype=torch.int32, device=device)
+                .view(-1, 1)
+                .repeat(1, post_patch_width)
+                .flatten()
             )
             col_ids = (
-                torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
+                torch.arange(post_patch_width, dtype=torch.int32, device=device)
+                .view(1, -1)
+                .repeat(post_patch_height, 1)
+                .flatten()
             )
-            position_ids[i, cap_len : cap_len + img_len, 1] = row_ids
-            position_ids[i, cap_len : cap_len + img_len, 2] = col_ids
+            position_ids[i, cap_seq_len:seq_len, 1] = row_ids
+            position_ids[i, cap_seq_len:seq_len, 2] = col_ids
 
+        # Get combined rotary embeddings
         freqs_cis = self._get_freqs_cis(position_ids)
 
-        cap_freqs_cis_shape = list(freqs_cis.shape)
-        cap_freqs_cis_shape[1] = attention_mask.shape[1]
-        cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
-        img_freqs_cis_shape = list(freqs_cis.shape)
-        img_freqs_cis_shape[1] = max_img_len
-        img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
-        for i in range(batch_size):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
-            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
-            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len]
-
-        flat_hidden_states = []
-        for i in range(batch_size):
-            img = hidden_states[i]
-            C, H, W = img.size()
-            img = img.view(C, H // p_h, p_h, W // p_w, p_w).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
-            flat_hidden_states.append(img)
-        hidden_states = flat_hidden_states
-        padded_img_embed = torch.zeros(
-            batch_size, max_img_len, hidden_states[0].shape[-1], device=device, dtype=hidden_states[0].dtype
+        # create separate rotary embeddings for captions and images
+        cap_freqs_cis = torch.zeros(
+            batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
         )
-        padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
-        for i in range(batch_size):
-            padded_img_embed[i, : l_effective_img_len[i]] = hidden_states[i]
-            padded_img_mask[i, : l_effective_img_len[i]] = True
-
-        return (
-            padded_img_embed,
-            padded_img_mask,
-            img_sizes,
-            l_effective_cap_len,
-            l_effective_img_len,
-            freqs_cis,
-            cap_freqs_cis,
-            img_freqs_cis,
-            max_seq_len,
+        img_freqs_cis = torch.zeros(
+            batch_size, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
+        )
+
+        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
+            img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_seq_len:seq_len]
+
+        # image patch embeddings
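+        # (B, C, H, W) -> (B, H/p * W/p, p*p*C)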
+        hidden_states = (
+            hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p)
+            .permute(0, 2, 4, 3, 5, 1)
+            .flatten(3)
+            .flatten(1, 2)
         )
 
+        return hidden_states, cap_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
+
 
 class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     r"""
@@ -472,75 +460,63 @@ def forward(
         hidden_states: torch.Tensor,
         timestep: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        use_mask_in_transformer: bool = True,
+        encoder_attention_mask: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
-        batch_size = hidden_states.size(0)
-
         # 1. Condition, positional & patch embedding
+        batch_size, _, height, width = hidden_states.shape
+
         temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
 
         (
             hidden_states,
-            hidden_mask,
-            hidden_sizes,
-            encoder_hidden_len,
-            hidden_len,
-            joint_rotary_emb,
-            encoder_rotary_emb,
-            hidden_rotary_emb,
-            max_seq_len,
-        ) = self.rope_embedder(hidden_states, attention_mask)
+            context_rotary_emb,
+            noise_rotary_emb,
+            rotary_emb,
+            encoder_seq_lengths,
+            seq_lengths,
+        ) = self.rope_embedder(hidden_states, encoder_attention_mask)
 
         hidden_states = self.x_embedder(hidden_states)
 
         # 2. Context & noise refinement
         for layer in self.context_refiner:
-            # NOTE: mask not used for performance
-            encoder_hidden_states = layer(
-                encoder_hidden_states, attention_mask if use_mask_in_transformer else None, encoder_rotary_emb
-            )
+            encoder_hidden_states = layer(encoder_hidden_states, encoder_attention_mask, context_rotary_emb)
 
         for layer in self.noise_refiner:
-            # NOTE: mask not used for performance
-            hidden_states = layer(
-                hidden_states, hidden_mask if use_mask_in_transformer else None, hidden_rotary_emb, temb
-            )
+            hidden_states = layer(hidden_states, None, noise_rotary_emb, temb)
+
+        # 3. Joint Transformer blocks
+        max_seq_len = max(seq_lengths)
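+        # padding (and a mask) is only needed when caption lengths differ within the batch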
+        use_mask = len(set(seq_lengths)) > 1
+
+        attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
+        joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
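+        # pack caption tokens followed by image tokens into a single padded joint sequence per sample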
+        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
+            attention_mask[i, :seq_len] = True
+            joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len]
+            joint_hidden_states[i, encoder_seq_len:seq_len] = hidden_states[i]
+
+        hidden_states = joint_hidden_states
 
-        # 3. Attention mask preparation
-        mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
-        padded_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
-        for i in range(batch_size):
-            cap_len = encoder_hidden_len[i]
-            img_len = hidden_len[i]
-            mask[i, : cap_len + img_len] = True
-            padded_hidden_states[i, :cap_len] = encoder_hidden_states[i, :cap_len]
-            padded_hidden_states[i, cap_len : cap_len + img_len] = hidden_states[i, :img_len]
-        hidden_states = padded_hidden_states
-
-        # 4. Transformer blocks
         for layer in self.layers:
-            # NOTE: mask not used for performance
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = self._gradient_checkpointing_func(
-                    layer, hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb
+                    layer, hidden_states, attention_mask if use_mask else None, rotary_emb, temb
                 )
             else:
-                hidden_states = layer(hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb)
+                hidden_states = layer(hidden_states, attention_mask if use_mask else None, rotary_emb, temb)
 
-        # 5. Output norm & projection & unpatchify
+        # 4. Output norm & projection
         hidden_states = self.norm_out(hidden_states, temb)
 
-        height_tokens = width_tokens = self.config.patch_size
+        # 5. Unpatchify
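+        # per sample: (image_seq_len, p*p*out_channels) -> (out_channels, height, width)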
+        p = self.config.patch_size
         output = []
-        for i in range(len(hidden_sizes)):
-            height, width = hidden_sizes[i]
-            begin = encoder_hidden_len[i]
-            end = begin + (height // height_tokens) * (width // width_tokens)
+        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
             output.append(
-                hidden_states[i][begin:end]
-                .view(height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels)
+                hidden_states[i][encoder_seq_len:seq_len]
+                .view(height // p, width // p, p, p, self.out_channels)
                 .permute(4, 0, 2, 1, 3)
                 .flatten(3, 4)
                 .flatten(1, 2)
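
For reference, a minimal round-trip sketch (plain PyTorch, arbitrary shapes, not part of the commit) of the patchify added to the rope embedder and the unpatchify at the end of `Lumina2Transformer2DModel.forward`; it only checks that the two reshape chains invert each other when the channel count is unchanged:

```python
import torch

# arbitrary example shapes (hypothetical; real values come from the latents and model config)
batch_size, channels, height, width = 2, 4, 8, 12
p = 2
post_patch_height, post_patch_width = height // p, width // p

x = torch.randn(batch_size, channels, height, width)

# patchify, as in the rope embedder's forward: (B, C, H, W) -> (B, H/p * W/p, p*p*C)
tokens = (
    x.view(batch_size, channels, post_patch_height, p, post_patch_width, p)
    .permute(0, 2, 4, 3, 5, 1)
    .flatten(3)
    .flatten(1, 2)
)

# unpatchify one sample, as in the final loop above: (H/p * W/p, p*p*C) -> (C, H, W)
restored = (
    tokens[0]
    .view(post_patch_height, post_patch_width, p, p, channels)
    .permute(4, 0, 2, 1, 3)
    .flatten(3, 4)
    .flatten(1, 2)
)

assert torch.equal(restored, x[0])
```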