Add support for DeepSpeed sequence parallelism (Ulysses) #35301

Closed
wants to merge 66 commits into from

Changes from all commits (66 commits)
6d6fb4b
Add original deepspeed
zeyugao Jun 20, 2024
b5f054c
Support override the seed_worker in Trainer
zeyugao Jun 20, 2024
6f72b8a
Add some necessary check on sequence parallel argument
zeyugao Jun 20, 2024
e311c0b
Add DistributedAttention
zeyugao Jun 20, 2024
f2a6cc9
Add starcoder2 as sequence parallel supported
zeyugao Jun 20, 2024
42fd905
Add llama, mistral
zeyugao Jun 20, 2024
1c1eed2
Move DistributedSampler initialization into trainer
zeyugao Jun 21, 2024
d3b0ce0
Fix llama query shape when _upad_input
zeyugao Jun 21, 2024
8d565f8
Use all_to_all for flexiablity
zeyugao Jun 21, 2024
2727691
Support sdpa for llama and mistral
zeyugao Jun 21, 2024
fbb7e0b
Fix miss understood train_batch_size calcuation
zeyugao Jul 16, 2024
cf29d6d
Fix args.world_size calcuation in model parallel
zeyugao Jul 16, 2024
01a4cdd
Merge remote-tracking branch 'origin/main' into support-deepspeed-seq…
zeyugao Jul 16, 2024
57488b8
Run ruff check
zeyugao Jul 16, 2024
278873c
Run ruff format
zeyugao Jul 16, 2024
d94a598
DeepSpeed sequence parallelism (aka Ulysses) integration with HF tran…
xiaoxiawu-microsoft Jul 29, 2024
2cd494e
Add deepspeed sp unit test
samadejacobs Aug 7, 2024
ed1b2c7
Add deepspeed sp unit test
samadejacobs Aug 7, 2024
d660f11
Properly document args to DS SeqAllToAll
samadejacobs Aug 9, 2024
2805b7a
Add DS seq parallelism doc
samadejacobs Aug 9, 2024
c0cce19
Formatting
samadejacobs Aug 11, 2024
82ab867
isort fix
samadejacobs Aug 19, 2024
0918a67
Quality fix
samadejacobs Aug 20, 2024
9ea2571
Update test_deepspeed.py
samadejacobs Aug 21, 2024
8766b91
Update deepspeed.md
samadejacobs Oct 2, 2024
76389dd
Respond to PR comments (wrap_deepspeed)
Nov 22, 2024
f54f288
Merge branch 'wrap-ds-uly' of https://github.com/ronald-d-rogers/tran…
Nov 27, 2024
2ca042f
Merge branch 'support-deepspeed-sequence-parallel' of https://github.…
Nov 27, 2024
ae143cd
Merges support for deepspeed ulysses
Nov 27, 2024
a2a26be
fix forgot to remove some imports
Nov 28, 2024
8171941
fix accessing local variable error
Nov 28, 2024
3b22600
sharding sequences after collating data (for trl)
Dec 6, 2024
fdb3579
forgot to remove debug logging
Dec 8, 2024
f1e2284
forgot to remove debug logging
Dec 8, 2024
224c86c
sharding labels as well
Dec 10, 2024
31630d1
better naming of vars, etc.
Dec 10, 2024
3fcab88
sharding tensors before moving to gpu
Dec 11, 2024
a487ff1
adds trainer method to finalize inputs
Dec 15, 2024
9766691
merge main
Dec 16, 2024
27127a6
fix code quality issues
Dec 17, 2024
27e6810
fix code quality issues
Dec 17, 2024
308acf9
fix most tests
Dec 18, 2024
bbff3b8
inject property instead of wrapping forward
Dec 19, 2024
60e2225
remove unnecessary code
Dec 19, 2024
74f4798
using _finalize_inputs in prediction as well
Dec 20, 2024
4445da9
adds custom sp loss and just guessing rank for dist sampling
Dec 26, 2024
dc7a61a
better names
Dec 26, 2024
99f4d52
fix passing in wrong dataset to eval sampler
Dec 27, 2024
4d54fb3
back to deepspeed's ulysses cross entropy
Dec 27, 2024
0c6c561
back to just using transformer's loss to fix nan eval loss
Dec 27, 2024
a9d7c6c
fix accidentally removed import
Dec 28, 2024
428e35d
fix syntax error
Dec 28, 2024
f9f3548
adds deepspeed cross entropy for sequence parallel
Dec 29, 2024
0a4c7d2
remove no longer needed nan means from evalution loop
Jan 2, 2025
1f2490a
fix skipped tokens and not ignoring ignore indices on mean reduction …
Jan 24, 2025
660fdac
copy over ds loss and ignore ignore indices on backprop
Jan 26, 2025
6c56cb2
forgot to ignore more ignore indices
Jan 27, 2025
94572cb
back to just using hf loss but with proper backprop
Jan 28, 2025
3a80c16
fix ruff error
Jan 29, 2025
f58c44d
fix syntax error...
Jan 29, 2025
b3b981d
fix forgot to pass seq parallel group to torch dist
Jan 29, 2025
6c561a9
just leaving last token clipped for now...
Jan 30, 2025
e3db6bd
fix final sequence shard going to varlen attn impl
Jan 31, 2025
2ff42f6
add fix to other supported models
Jan 31, 2025
18f2a09
try deepspeed's loss again w/ varlen fix
Feb 1, 2025
c817d18
fix ignore mask in customized ds-loss backprop
Feb 1, 2025
46 changes: 46 additions & 0 deletions docs/source/en/deepspeed.md
@@ -1141,6 +1141,52 @@ Using multiple GPUs with ZeRO-3 for generation requires synchronizing the GPUs b

For Transformers>=4.28, `synced_gpus` is automatically set to `True` if multiple GPUs are detected during generation.

### Non-Trainer Sequence Parallelism
DeepSpeed sequence parallelism, also known as [DeepSpeed Ulysses](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-ulysses/README.md), is a distributed training technique for long-context LLMs. It lets the sequence length (and the model size) scale with the number of GPUs instead of being limited by the memory of a single GPU. DeepSpeed sequence parallelism works with HuggingFace Transformers by adding `sequence_parallel_size` and `data_parallel_size` to the DeepSpeed configuration. In addition, the user's script must shard the input data along the sequence dimension, as shown below.

```py
import deepspeed
import torch.distributed as dist
import transformers
from transformers import AutoModelForCausalLM

# `model_name` and `data_loader` are defined elsewhere in the user's script
ds_config = {
    "sequence_parallel_size": 2,
    "data_parallel_size": 1,
    # ...
}

config = transformers.AutoConfig.from_pretrained(model_name)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    config=config,
    attn_implementation="flash_attention_2",
)

model, _, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config=ds_config,
    dist_init_required=True,
)

spg = model.get_sequence_parallel_group()
seq_parallel_world_size = dist.get_world_size(spg)
seq_parallel_rank = dist.get_rank(spg)

for n, batch in enumerate(data_loader):
    # Each rank in the sequence-parallel group keeps its own contiguous
    # sub-sequence of every batch.
    seq_length = batch["input_ids"].size(1)
    assert seq_length % seq_parallel_world_size == 0
    sub_seq_length = seq_length // seq_parallel_world_size
    sub_seq_start = seq_parallel_rank * sub_seq_length
    sub_seq_end = (seq_parallel_rank + 1) * sub_seq_length

    batch["input_ids"] = batch["input_ids"][:, sub_seq_start:sub_seq_end]
    batch["labels"] = batch["labels"][:, sub_seq_start:sub_seq_end]

    # ...
```

HuggingFace Transformers internally invokes DeepSpeed Ulysses to take advantage of multi-GPU optimization during pretraining, post-training, and fine-tuning of long-context LLMs. DeepSpeed sequence parallelism is fully compatible with FlashAttention. A detailed example script is available [here](https://github.com/microsoft/DeepSpeedExamples/blob/uly-hf/post_training/sequence_parallelism/test_ulysses.py).
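
Conceptually, Ulysses wraps the attention call in two all-to-all exchanges over the sequence-parallel group (`_SeqAllToAll` in `deepspeed.sequence.layer`): the first scatters the head dimension and gathers the sequence dimension, so attention sees the full sequence with `H/P` heads; the second reverses the exchange afterwards. The single-process sketch below is illustrative only (sizes and variable names are made up, and no real communication happens); it mimics the data movement of the first exchange.

```py
# Illustrative only: a single-process simulation of the data movement in the
# first Ulysses all-to-all. All sizes and variable names here are hypothetical;
# the real implementation communicates across GPUs in the sequence-parallel group.
import torch

P, B, S, H, D = 2, 1, 8, 4, 16  # sp world size, batch, full seq length, heads, head dim
shards = [torch.randn(B, S // P, H, D) for _ in range(P)]  # per-rank input: [B, S/P, H, D]

exchanged = []
for rank in range(P):
    heads = slice(rank * (H // P), (rank + 1) * (H // P))
    # Rank `rank` receives head slice `rank` from every rank's sequence shard
    # and concatenates the pieces along the sequence dimension.
    exchanged.append(torch.cat([shard[:, :, heads, :] for shard in shards], dim=1))

for rank, full_seq in enumerate(exchanged):
    # (1, 4, 4, 16) -> (1, 8, 2, 16): full sequence, H/P heads
    print(f"rank {rank}: {tuple(shards[rank].shape)} -> {tuple(full_seq.shape)}")
```

After the attention computation, a second all-to-all performs the inverse exchange, returning each rank to the `[B, S/P, H, D]` layout so the rest of the model continues to operate on its sequence shard.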

Integration with the [`Trainer`] is underway; this documentation will be updated once the [`Trainer`] integration is available.
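
Note that next-token prediction makes label sharding slightly subtle: a position's target is the token at the following position, so the last position of each sequence shard needs the first label of the next shard, and the final position of the last shard has no target at all. The sketch below (hypothetical toy sizes, single process, assuming access to the full unsharded label sequence) shows the alignment; when Ulysses is enabled, the `ForCausalLMLoss` changes in this PR perform the corresponding realignment and pad the final position with `ignore_index`.

```py
# Hypothetical illustration (single process, toy sizes) of which label positions
# each sequence-parallel rank needs once the usual shift-by-one is applied.
import torch

S, P, IGNORE_INDEX = 8, 2, -100
labels = torch.arange(S)  # stand-in labels for one full sequence

for rank in range(P):
    lo, hi = rank * S // P, (rank + 1) * S // P  # positions whose logits this rank holds
    if rank == P - 1:
        # The last position of the last shard has no next token to predict.
        shard_labels = torch.cat([labels[lo + 1 :], torch.tensor([IGNORE_INDEX])])
    else:
        shard_labels = labels[lo + 1 : hi + 1]
    print(f"rank {rank}: logits for positions {lo}..{hi - 1}, targets {shard_labels.tolist()}")
```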


## Troubleshoot

When you encounter an issue, you should consider whether DeepSpeed is the cause of the problem because often it isn't (unless it's super obvious and you can see DeepSpeed modules in the exception)! The first step should be to retry your setup without DeepSpeed, and if the problem persists, then you can report the issue. If the issue is a core DeepSpeed problem and unrelated to the Transformers integration, open an Issue on the [DeepSpeed repository](https://github.com/microsoft/DeepSpeed).
2 changes: 2 additions & 0 deletions src/transformers/integrations/__init__.py
@@ -48,6 +48,7 @@
"deepspeed_load_checkpoint",
"deepspeed_optim_sched",
"is_deepspeed_available",
"is_deepspeed_ulysses_enabled",
"is_deepspeed_zero3_enabled",
"set_hf_deepspeed_config",
"unset_hf_deepspeed_config",
@@ -150,6 +151,7 @@
deepspeed_load_checkpoint,
deepspeed_optim_sched,
is_deepspeed_available,
is_deepspeed_ulysses_enabled,
is_deepspeed_zero3_enabled,
set_hf_deepspeed_config,
unset_hf_deepspeed_config,
146 changes: 144 additions & 2 deletions src/transformers/integrations/deepspeed.py
@@ -56,6 +56,12 @@ def is_deepspeed_available():
from builtins import object as DeepSpeedConfig


if is_deepspeed_available():
import deepspeed.comm as dist
from deepspeed.sequence.layer import _SeqAllToAll
from deepspeed.utils import groups as deepspeed_groups


class HfDeepSpeedConfig(DeepSpeedConfig):
"""
This object contains a DeepSpeed configuration dictionary and can be quickly queried for things like zero stage.
@@ -135,11 +141,15 @@ def fill_match(self, ds_key_long, hf_val, hf_key=None, must_match=True):
def trainer_config_process(self, args, auto_find_batch_size=False):
"""
Adjust the config with `TrainingArguments` values. This stage is run during `TrainingArguments` object
creation.
creation., sequence_parallel_size, sequence_parallel_rank)
"""
if getattr(self, "sequence_parallel_size") and self.sequence_parallel_size() > 1:
world_size = getattr(self, "data_parallel_size", args.world_size // self.sequence_parallel_size())()
else:
world_size = args.world_size
# DeepSpeed does:
# train_batch_size = world_size * train_micro_batch_size_per_gpu * gradient_accumulation_steps
train_batch_size = args.world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
train_batch_size = world_size * args.per_device_train_batch_size * args.gradient_accumulation_steps
self.fill_match(
"train_micro_batch_size_per_gpu",
args.per_device_train_batch_size,
@@ -298,6 +308,13 @@ def is_deepspeed_zero3_enabled():
return False


def is_deepspeed_ulysses_enabled():
if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
return _hf_deepspeed_config_weak_ref().is_sequence_parallel()
else:
return False


def deepspeed_config():
if _hf_deepspeed_config_weak_ref is not None and _hf_deepspeed_config_weak_ref() is not None:
return _hf_deepspeed_config_weak_ref().config
@@ -445,3 +462,128 @@ def deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path, load_module_str
raise ValueError(f"[deepspeed] failed to resume from checkpoint {checkpoint_path}")
else:
raise ValueError(f"Can't find a valid checkpoint at {checkpoint_path}")


def deepspeed_ulysses_attention(attn_func, seq_dim=1, head_dim=2):
def wrapped(*args, **kwargs):
if is_deepspeed_ulysses_enabled():
sp_group = deepspeed_groups._get_sequence_parallel_group()
scatter_idx = head_dim # Scatter on num_heads dimension
gather_idx = seq_dim # Gather on seq_len dimension
batch_dim_idx = 0 # Synonymous with the batch_first==true
args = list(args)
args[0] = _SeqAllToAll.apply(sp_group, args[0], scatter_idx, gather_idx, batch_dim_idx)
args[1] = _SeqAllToAll.apply(sp_group, args[1], scatter_idx, gather_idx, batch_dim_idx)
args[2] = _SeqAllToAll.apply(sp_group, args[2], scatter_idx, gather_idx, batch_dim_idx)
args = tuple(args)

attn_output = attn_func(*args, **kwargs)

if is_deepspeed_ulysses_enabled():
scatter_idx = seq_dim # Scatter back on seq_len dimension
gather_idx = head_dim # Gather on num_heads dimension
batch_dim_idx = 0
attn_output = _SeqAllToAll.apply(sp_group, attn_output, scatter_idx, gather_idx, batch_dim_idx)

return attn_output

return wrapped


def support_deepspeed_ulysses(module):
module._sp_size = None

@property
def sp_size(self):
if self._sp_size is None:
self._sp_size = 1
if is_deepspeed_ulysses_enabled():
self._sp_size = deepspeed_groups._get_sequence_parallel_group().size()
return self._sp_size

module.sp_size = sp_size

return module


def deepspeed_ulysses_cross_entropy(
input,
target,
ignore_index=-100,
reduction="mean",
):
sp_group = deepspeed_groups._get_sequence_parallel_group()

if ignore_index != -100:
raise ValueError("ignore_index not currently supported with DeepSpeed Ulysses")

loss = vocab_sequence_parallel_cross_entropy(
input.unsqueeze(1),
target.unsqueeze(1),
sp_group=sp_group,
).squeeze(1)

if reduction == "mean":
loss = loss[torch.nonzero(loss)].mean()

if reduction == "sum":
loss = loss.sum()

return loss


# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
class _VocabSequenceParallelCrossEntropy(torch.autograd.Function):
@staticmethod
def forward(ctx, vocab_seq_parallel_logits, target, sp_group, ignore_index):
# vocab_seq_parallel_logits: [S/P, B, V]
# target: [S/P, B]
# return: [S, B]

# Need softmax for backward
ctx.ignore_index = ignore_index
ctx.vocab_size = vocab_seq_parallel_logits.size(2)
softmax = torch.nn.functional.softmax(vocab_seq_parallel_logits, dim=-1)
loss = torch.nn.functional.nll_loss(
softmax.log().view(-1, ctx.vocab_size), target.view(-1), ignore_index=ignore_index, reduction="none"
)

sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
ctx.sp_size = sp_size
ctx.sp_rank = sp_rank
ctx.seqlen = vocab_seq_parallel_logits.size(0) * sp_size
batch_size = vocab_seq_parallel_logits.size(1)

loss_all = torch.empty(
ctx.seqlen, batch_size, dtype=vocab_seq_parallel_logits.dtype, device=vocab_seq_parallel_logits.device
)

dist.all_gather_into_tensor(loss_all, loss, group=sp_group)
ctx.save_for_backward(softmax, target)

return loss_all

@staticmethod
def backward(ctx, grad_output):
softmax, target = ctx.saved_tensors

step_seqlen = ctx.seqlen // ctx.sp_size
sp_rank = ctx.sp_rank
grad_output_part = grad_output[step_seqlen * sp_rank : step_seqlen * (sp_rank + 1), :].unsqueeze(dim=-1)

grad_input = softmax
grad_2d = grad_input.view(-1, ctx.vocab_size)
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)

ignore_mask = target.view(-1) == ctx.ignore_index
grad_2d[arange_1d[~ignore_mask], target.view(-1)[~ignore_mask]] -= 1
grad_input.mul_(grad_output_part)
grad_2d[arange_1d[ignore_mask], :] = 0
return grad_input, None, None, None, None


def vocab_sequence_parallel_cross_entropy(vocab_parallel_logits, target, sp_group, ignore_index=-100):
return _VocabSequenceParallelCrossEntropy.apply(vocab_parallel_logits, target, sp_group, ignore_index)
39 changes: 35 additions & 4 deletions src/transformers/loss/loss_utils.py
@@ -9,21 +9,37 @@
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See theLicense for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, MSELoss

from transformers.integrations import is_deepspeed_available, is_deepspeed_ulysses_enabled

from .loss_deformable_detr import DeformableDetrForObjectDetectionLoss, DeformableDetrForSegmentationLoss
from .loss_for_object_detection import ForObjectDetectionLoss, ForSegmentationLoss
from .loss_rt_detr import RTDetrForObjectDetectionLoss


if is_deepspeed_available():
from deepspeed.utils import groups as deepspeed_groups

from ..integrations.deepspeed import deepspeed_ulysses_cross_entropy


def fixed_cross_entropy(source, target, num_items_in_batch: int = None, ignore_index: int = -100, **kwargs):
reduction = "sum" if num_items_in_batch is not None else "mean"
loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
if is_deepspeed_ulysses_enabled():
loss = deepspeed_ulysses_cross_entropy(
source,
target,
ignore_index=ignore_index,
reduction=reduction,
)
else:
loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction=reduction)
if reduction == "sum":
loss = loss / num_items_in_batch
return loss
@@ -35,8 +51,23 @@ def ForCausalLMLoss(
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
if is_deepspeed_ulysses_enabled():
sp_group = deepspeed_groups._get_sequence_parallel_group()
sp_size = sp_group.size()
sp_rank = sp_group.rank()
sp_seqlen = logits.size(1)

shift_logits = logits.contiguous()
if sp_rank == sp_size - 1:
# add an ignore_index to the end of the labels
shift_labels = torch.cat(
(labels[..., -(sp_seqlen - 1) :], torch.full_like(labels[:, :1], ignore_index)), dim=-1
).contiguous()
else:
shift_labels = labels[..., (sp_seqlen * sp_rank) + 1 : (sp_seqlen * (sp_rank + 1)) + 1].contiguous()
else:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

# Flatten the tokens
shift_logits = shift_logits.view(-1, vocab_size)
3 changes: 3 additions & 0 deletions src/transformers/modeling_flash_attention_utils.py
@@ -20,6 +20,8 @@
import torch
import torch.nn.functional as F

from transformers.integrations.deepspeed import deepspeed_ulysses_attention

from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal, logging


@@ -228,6 +230,7 @@ def fa_peft_integration_check(
deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"


@deepspeed_ulysses_attention
def _flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
32 changes: 28 additions & 4 deletions src/transformers/modeling_utils.py
@@ -44,7 +44,13 @@
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig, GenerationMixin
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .integrations import (
PeftAdapterMixin,
deepspeed_config,
is_deepspeed_available,
is_deepspeed_ulysses_enabled,
is_deepspeed_zero3_enabled,
)
from .loss.loss_utils import LOSS_MAPPING
from .pytorch_utils import ( # noqa: F401
Conv1D,
@@ -130,6 +136,11 @@
if accelerate_version >= version.parse("0.31"):
from accelerate.utils.modeling import get_state_dict_from_offload

if is_deepspeed_available():
import deepspeed
from deepspeed.utils import groups as deepspeed_groups


if is_safetensors_available():
from safetensors import safe_open
from safetensors.torch import load_file as safe_load_file
@@ -1284,6 +1295,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Has support for a `QuantoQuantizedCache` instance as `past_key_values`
_supports_quantized_cache = False

# Model supports sequence parallelism (DeepSpeed only)
_supports_sequence_parallel = False

# A tensor parallel plan to be applied to the model when TP is enabled. For
# top-level models, this attribute is currently defined in respective model
# code. For base models, this attribute comes from
@@ -4036,11 +4050,17 @@ def from_pretrained(
tp_device = None

if is_deepspeed_zero3_enabled() and not is_quantized and not _is_ds_init_called:
import deepspeed

logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
sequence_data_parallel_group = None
if is_deepspeed_ulysses_enabled() and deepspeed_groups._zero_param_parallel_is_initialized():
sequence_data_parallel_group = deepspeed_groups._get_sequence_data_parallel_group()

init_contexts = [
deepspeed.zero.Init(config_dict_or_path=deepspeed_config()),
deepspeed.zero.Init(
config_dict_or_path=deepspeed_config(),
sequence_data_parallel_group=sequence_data_parallel_group,
mpu=deepspeed_groups.mpu,
),
set_zero3_state(),
] + init_contexts
elif low_cpu_mem_usage:
@@ -4257,6 +4277,10 @@ def from_pretrained(
)
pass

if is_deepspeed_ulysses_enabled():
if not getattr(model, "_supports_sequence_parallel", False):
raise ValueError(f"{model.__class__} does not support sequence parallelism.")

# Dispatch model with hooks on all devices if necessary
if device_map is not None:
device_map_kwargs = {