Skip to content

Commit

Permalink
adds deepspeed cross entropy for sequence parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
Ronald Rogers committed Jan 2, 2025
1 parent 428e35d commit f9f3548
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 2 deletions.
27 changes: 26 additions & 1 deletion src/transformers/integrations/deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def is_deepspeed_available():


if is_deepspeed_available():
import deepspeed.comm as deepspeed_comm
from deepspeed.sequence.cross_entropy import vocab_sequence_parallel_cross_entropy
from deepspeed.sequence.layer import _SeqAllToAll
from deepspeed.utils import groups as deepspeed_groups
Expand Down Expand Up @@ -505,3 +504,29 @@ def sp_size(self):
module.sp_size = sp_size

return module


def deepspeed_ulysses_cross_entropy(
    input,
    target,
    ignore_index=-100,
    reduction="mean",
):
    """Cross-entropy loss computed across a DeepSpeed Ulysses sequence-parallel group.

    Delegates the per-token loss to DeepSpeed's
    ``vocab_sequence_parallel_cross_entropy`` so that each rank only holds its
    shard of the sequence, then applies the requested reduction locally.

    Args:
        input: Logits tensor for this rank's sequence shard.
        target: Target token ids aligned with ``input``.
        ignore_index: Must be the default ``-100``; any other value raises,
            since DeepSpeed Ulysses does not currently support it.
        reduction: ``"mean"``, ``"sum"``, or anything else for no reduction
            (per-token losses are returned unchanged).

    Returns:
        The reduced (or per-token) loss tensor.

    Raises:
        ValueError: If ``ignore_index`` is not ``-100``.
    """
    # Fail fast on the unsupported option before touching process groups.
    if ignore_index != -100:
        raise ValueError("ignore_index not currently supported with DeepSpeed Ulysses")

    sequence_parallel_group = deepspeed_groups._get_sequence_parallel_group()

    # NOTE(review): the unsqueeze(1)/squeeze(1) pair presumably adapts to the
    # (seq, batch, vocab)-style layout the DeepSpeed kernel expects — confirm
    # against the DeepSpeed API.
    per_token_loss = vocab_sequence_parallel_cross_entropy(
        input.unsqueeze(1),
        target.unsqueeze(1),
        sp_group=sequence_parallel_group,
    ).squeeze(1)

    # nanmean skips NaN entries produced for tokens without a valid loss.
    if reduction == "mean":
        return per_token_loss.nanmean()
    if reduction == "sum":
        return per_token_loss.sum()
    return per_token_loss
16 changes: 15 additions & 1 deletion src/transformers/loss/loss_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,28 @@
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, MSELoss

from transformers.integrations import is_deepspeed_available, is_deepspeed_ulysses_enabled

from .loss_deformable_detr import DeformableDetrForObjectDetectionLoss, DeformableDetrForSegmentationLoss
from .loss_for_object_detection import ForObjectDetectionLoss, ForSegmentationLoss
from .loss_rt_detr import RTDetrForObjectDetectionLoss


if is_deepspeed_available():
from ..integrations.deepspeed import deepspeed_ulysses_cross_entropy


def fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
    """Cross-entropy with optional normalization by a global token count.

    When ``num_items_in_batch`` is given, the loss is summed and then divided
    by that count (useful under gradient accumulation, where the per-step
    batch does not reflect the true number of items). Otherwise a plain mean
    is used. When DeepSpeed Ulysses sequence parallelism is enabled, the loss
    is computed with the sequence-parallel-aware DeepSpeed kernel instead of
    ``nn.functional.cross_entropy``.

    Args:
        source: Logits of shape (N, vocab).
        target: Target token ids of shape (N,).
        num_items_in_batch: Optional global item count used as the divisor
            for the summed loss.
        ignore_index: Target value to ignore (default ``-100``).
        **kwargs: Ignored; kept for call-site compatibility.

    Returns:
        Scalar loss tensor.
    """
    reduction = "sum" if num_items_in_batch is not None else "mean"
    # Compute the loss exactly once: the sequence-parallel path and the
    # standard path are mutually exclusive (the original text duplicated the
    # standard cross_entropy call unconditionally before the branch, wasting
    # a full loss computation whose result was immediately overwritten).
    if is_deepspeed_ulysses_enabled():
        loss = deepspeed_ulysses_cross_entropy(
            source,
            target,
            ignore_index=ignore_index,
            reduction=reduction,
        )
    else:
        loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
    if reduction == "sum":
        loss = loss / num_items_in_batch
    return loss
Expand Down

0 comments on commit f9f3548

Please sign in to comment.