@@ -510,7 +510,7 @@ def from_pretrained(
510
510
model .print_trainable_parameters ()
511
511
512
512
if lora_path is not None :
513
- logger .info ('Load pretrained with LoRA adapter' )
513
+ logger .info (f 'Load pretrained with LoRA adapter { lora_path } ' )
514
514
from peft import LoraConfig , PeftModel
515
515
516
516
model = PeftModel .from_pretrained (model , lora_path )
@@ -689,9 +689,38 @@ def forward(
689
689
690
690
def _unsorted_segment_mean (self , data : torch .Tensor , segment_ids : torch .Tensor , num_segments : int ) -> torch .Tensor :
691
691
result_shape = (num_segments , data .size (1 ))
692
- segment_ids = segment_ids .unsqueeze (- 1 ).expand (- 1 , data .size (1 ))
692
+ segment_ids = segment_ids .unsqueeze (- 1 ).expand (- 1 , data .size (1 )) # (batch, num_embedding)
693
693
result = data .new_full (result_shape , 0 ) # init empty result tensor
694
694
count = data .new_full (result_shape , 0 )
695
- result .scatter_add_ (0 , segment_ids , data )
695
+ result .scatter_add_ (0 , segment_ids , data ) # fill the result from data to organized segment result
696
696
count .scatter_add_ (0 , segment_ids , torch .ones_like (data ))
697
697
return result / count .clamp (min = 1 )
698
+
699
+ def _sorted_segment_mean (self , data : torch .Tensor , segment_ids : torch .Tensor , num_segments : int ) -> torch .Tensor :
700
+ """
701
+ Compute the mean of each segment in data based on sorted segment_ids.
702
+
703
+ Args:
704
+ data (torch.Tensor): Input data tensor of shape (batch_size, num_embedding).
705
+ segment_ids (torch.Tensor): Sorted segment IDs tensor of shape (batch_size,).
706
+ num_segments (int): Number of unique segments.
707
+
708
+ Returns:
709
+ torch.Tensor: Tensor of shape (num_segments, num_embedding) containing the mean of each segment.
710
+ """
711
+ result = torch .zeros ((num_segments , data .size (1 )), dtype = data .dtype , device = data .device )
712
+ count = torch .zeros ((num_segments ,), dtype = torch .int32 , device = data .device )
713
+
714
+ start_idx = 0
715
+ for i in range (num_segments ):
716
+ # Find the range of indices corresponding to the current segment
717
+ while start_idx < segment_ids .size (0 ) and segment_ids [start_idx ] == i :
718
+ start_idx += 1
719
+
720
+ if start_idx > 0 and segment_ids [start_idx - 1 ] == i :
721
+ segment_slice = slice (start_idx - (start_idx - segment_ids [start_idx :].tolist ().count (i )), start_idx )
722
+ result [i ] = data [segment_slice ].sum (dim = 0 )
723
+ count [i ] = segment_slice .stop - segment_slice .start
724
+
725
+ result /= count .clamp (min = 1 ).unsqueeze (- 1 )
726
+ return result
0 commit comments