From f9f354813a985d00c8817297d675ca302432690f Mon Sep 17 00:00:00 2001
From: Ronald Rogers
Date: Sun, 29 Dec 2024 12:51:24 -0500
Subject: [PATCH] adds deepspeed cross entropy for sequence parallel

---
 src/transformers/integrations/deepspeed.py | 27 +++++++++++++++++++++-
 src/transformers/loss/loss_utils.py        | 16 ++++++++++++-
 2 files changed, 41 insertions(+), 2 deletions(-)

diff --git a/src/transformers/integrations/deepspeed.py b/src/transformers/integrations/deepspeed.py
index a721a79cab3f..17aa5141a48b 100644
--- a/src/transformers/integrations/deepspeed.py
+++ b/src/transformers/integrations/deepspeed.py
@@ -57,7 +57,6 @@ def is_deepspeed_available():
 
 
 if is_deepspeed_available():
-    import deepspeed.comm as deepspeed_comm
     from deepspeed.sequence.cross_entropy import vocab_sequence_parallel_cross_entropy
     from deepspeed.sequence.layer import _SeqAllToAll
     from deepspeed.utils import groups as deepspeed_groups
@@ -505,3 +504,29 @@ def sp_size(self):
     module.sp_size = sp_size
 
     return module
+
+
+def deepspeed_ulysses_cross_entropy(
+    input,
+    target,
+    ignore_index=-100,
+    reduction="mean",
+):
+    sp_group = deepspeed_groups._get_sequence_parallel_group()
+
+    if ignore_index != -100:
+        raise ValueError("ignore_index not currently supported with DeepSpeed Ulysses")
+
+    loss = vocab_sequence_parallel_cross_entropy(
+        input.unsqueeze(1),
+        target.unsqueeze(1),
+        sp_group=sp_group,
+    ).squeeze(1)
+
+    if reduction == "mean":
+        loss = loss.nanmean()
+
+    if reduction == "sum":
+        loss = loss.sum()
+
+    return loss
diff --git a/src/transformers/loss/loss_utils.py b/src/transformers/loss/loss_utils.py
index 820daf8c3613..0e8f718611f8 100644
--- a/src/transformers/loss/loss_utils.py
+++ b/src/transformers/loss/loss_utils.py
@@ -16,14 +16,28 @@
 import torch.nn as nn
 from torch.nn import BCEWithLogitsLoss, MSELoss
 
+from transformers.integrations import is_deepspeed_available, is_deepspeed_ulysses_enabled
+
 from .loss_deformable_detr import DeformableDetrForObjectDetectionLoss, DeformableDetrForSegmentationLoss
 from .loss_for_object_detection import ForObjectDetectionLoss, ForSegmentationLoss
 from .loss_rt_detr import RTDetrForObjectDetectionLoss
 
 
+if is_deepspeed_available():
+    from ..integrations.deepspeed import deepspeed_ulysses_cross_entropy
+
+
 def fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
     reduction = "sum" if num_items_in_batch is not None else "mean"
-    loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
+    if is_deepspeed_ulysses_enabled():
+        loss = deepspeed_ulysses_cross_entropy(
+            source,
+            target,
+            ignore_index=ignore_index,
+            reduction=reduction,
+        )
+    else:
+        loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
     if reduction == "sum":
         loss = loss / num_items_in_batch
     return loss
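
Usage note (not part of the patch): the new branch in fixed_cross_entropy is
only taken when is_deepspeed_ulysses_enabled() reports an active Ulysses
sequence-parallel group; is_deepspeed_ulysses_enabled is assumed to come from
the accompanying integration changes in this series, not from this diff.
Below is a minimal single-process sketch of the fallback path, with vocab
size and shapes chosen for illustration; under an active Ulysses group, each
rank would pass its local sequence shard instead and the loss would be
computed collectively via vocab_sequence_parallel_cross_entropy.

    import torch

    from transformers.loss.loss_utils import fixed_cross_entropy

    # Logits flattened to (num_tokens, vocab_size) and targets to
    # (num_tokens,), matching how causal-LM losses call this helper.
    vocab_size = 128
    logits = torch.randn(8, vocab_size, requires_grad=True)
    targets = torch.randint(0, vocab_size, (8,))

    # Passing num_items_in_batch switches the reduction to "sum", and the
    # summed loss is then normalized by num_items_in_batch. With no Ulysses
    # group active this goes through nn.functional.cross_entropy.
    loss = fixed_cross_entropy(logits, targets, num_items_in_batch=8)
    loss.backward()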