Merge branch 'master' into patch-1
loadams authored Jan 22, 2025
2 parents bc2c9fa + de4596b commit 4f9c2a7
Showing 14 changed files with 108 additions and 85 deletions.
2 changes: 1 addition & 1 deletion accelerator/real_accelerator.py
@@ -178,7 +178,7 @@ def get_accelerator():
if accelerator_name is None:
# borrow this log from PR#5084
if accel_logger is not None:
accel_logger.warn(
accel_logger.warning(
"Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it.")
# cpu added as catch-all when accelerator detection fails
accelerator_name = "cpu"
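Note on the change above: `logging.Logger.warn` is just a deprecated alias of `Logger.warning` and may emit a `DeprecationWarning` on recent Python versions, so the diff only switches to the canonical name. A minimal sketch of the equivalent call:

```python
import logging

accel_logger = logging.getLogger("DeepSpeed")  # illustrative logger name

# Deprecated spelling (alias kept only for backwards compatibility):
#   accel_logger.warn("Setting accelerator to CPU. ...")
# Canonical spelling used after this change:
accel_logger.warning(
    "Setting accelerator to CPU. If you have GPU or other accelerator, we were unable to detect it.")
```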
21 changes: 21 additions & 0 deletions deepspeed/module_inject/layers.py
@@ -91,6 +91,13 @@ def forward(self, input):
output += self.bias
return output

def extra_repr(self):
out_features, in_features = self.weight.shape if self.weight is not None else (None, None)
dtype = self.weight.dtype if self.weight is not None else None
extra_repr_str = "in_features={}, out_features={}, bias={}, dtype={}".format(
in_features, out_features, self.bias is not None, dtype)
return extra_repr_str


class LmHeadLinearAllreduce(nn.Module):

@@ -120,6 +127,13 @@ def forward(self, input):
output += self.bias
return output

def extra_repr(self):
out_features, in_features = self.weight.shape if self.weight is not None else (None, None)
dtype = self.weight.dtype if self.weight is not None else None
extra_repr_str = "in_features={}, out_features={}, bias={}, dtype={}".format(
in_features, out_features, self.bias is not None, dtype)
return extra_repr_str


class LinearLayer(nn.Module):

@@ -144,6 +158,13 @@ def forward(self, input):
output += self.bias
return output

def extra_repr(self):
out_features, in_features = self.weight.shape
dtype = self.weight.dtype
extra_repr_str = "in_features={}, out_features={}, bias={}, dtype={}".format(
in_features, out_features, self.bias is not None, dtype)
return extra_repr_str


class Normalize(nn.Module):

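The `extra_repr` methods added above only change how `LinearAllreduce`, `LmHeadLinearAllreduce` and `LinearLayer` print. As a rough illustration (the `TinyLinear` module below is a stand-in, not DeepSpeed code), PyTorch appends whatever `extra_repr` returns inside the module's printed signature:

```python
import torch
import torch.nn as nn


class TinyLinear(nn.Module):
    """Stand-in module mirroring the extra_repr pattern added in layers.py."""

    def __init__(self, in_features=4, out_features=8):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = None

    def extra_repr(self):
        out_features, in_features = self.weight.shape if self.weight is not None else (None, None)
        dtype = self.weight.dtype if self.weight is not None else None
        return "in_features={}, out_features={}, bias={}, dtype={}".format(
            in_features, out_features, self.bias is not None, dtype)


print(TinyLinear())
# TinyLinear(in_features=4, out_features=8, bias=False, dtype=torch.float32)
```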
2 changes: 1 addition & 1 deletion deepspeed/runtime/base_optimizer.py
@@ -28,7 +28,7 @@ def load_hp_checkpoint_state_from_checkpoint_dir(self, lp_groups_name: str, chec

tp_rank = bwc_tensor_model_parallel_rank(mpu=self.mpu)
if self.mpu is None:
logger.warn("MPU is not provided, setting tp size to 1 in checkpoint loading.")
logger.warning("MPU is not provided, setting tp size to 1 in checkpoint loading.")
tp_world_size = 1
else:
tp_world_size = self.mpu.get_slice_parallel_world_size() if hasattr(self.mpu, "get_slice_parallel_world_size") \
2 changes: 1 addition & 1 deletion deepspeed/runtime/comm/compressed.py
@@ -96,7 +96,7 @@ def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_erro

compensated_server_m.add_(server_error)

server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())

server_error.set_(compensated_server_m -
server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
2 changes: 1 addition & 1 deletion deepspeed/runtime/comm/hccl.py
@@ -83,7 +83,7 @@ def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_erro

compensated_server_m.add_(server_error)

server_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
server_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())

server_error.set_(compensated_server_m -
server_scale * compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))
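Both hunks above migrate from the deprecated `torch.norm` to `torch.linalg.norm`. With no `ord`/`dim` arguments the new call flattens the input and computes the same 2-norm, so the compression scale is numerically unchanged. A quick sanity check on an arbitrary tensor, as a sketch:

```python
import numpy as np
import torch

compensated_server_m = torch.randn(128, 64)  # illustrative error-feedback buffer
old_scale = torch.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
new_scale = torch.linalg.norm(compensated_server_m) / np.sqrt(compensated_server_m.numel())
assert torch.allclose(old_scale, new_scale)
```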
4 changes: 2 additions & 2 deletions deepspeed/runtime/engine.py
@@ -3120,7 +3120,7 @@ def _get_all_zero_checkpoints(self, load_dir, tag):
if bf16_mode is not self.bfloat16_enabled():
checkpoint_bit16 = BFLOAT16 if bf16_mode else FP16
engine_bit16 = BFLOAT16 if self.bfloat16_enabled() else FP16
logger.warn(f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine')
logger.warning(f'Loading {checkpoint_bit16} zero checkpoints into {engine_bit16} training engine')
return self._get_all_zero_checkpoint_state_dicts(zero_ckpt_names)

return None
@@ -3276,7 +3276,7 @@ def _save_moe_checkpoint(self, save_dir, tag, client_state={}, exclude_frozen_pa

local_expert_id = None
if not m:
logger.warn(f'No expert found in key {key}.')
logger.warning(f'No expert found in key {key}.')
else:
local_expert_id = m.group(1)

2 changes: 1 addition & 1 deletion deepspeed/runtime/fp16/onebit/lamb.py
@@ -177,7 +177,7 @@ def step(self, closure=None, grads=None):
# This is used to reduce compression error during compression stage.
momentum_scales = []
for group in self.param_groups:
momentum_scales.append([(torch.linalg.norm(self.state[p]['exp_avg']) /
momentum_scales.append([(torch.linalg.vector_norm(self.state[p]['exp_avg']) /
np.sqrt(torch.numel(self.state[p]['exp_avg']))).item()
for p in group['params']])
united_scale = sum([sum(x) for x in momentum_scales]) / sum([len(x) for x in momentum_scales])
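`torch.linalg.vector_norm` always treats its input as a flat vector, which is what the 1-bit LAMB scale needs for momentum buffers of any shape; `torch.linalg.norm` happens to do the same when called without `ord`/`dim`, but the new spelling states the intent explicitly. A small illustration with a made-up `exp_avg` tensor:

```python
import torch

exp_avg = torch.randn(16, 32, 8)  # hypothetical momentum buffer
scale = (torch.linalg.vector_norm(exp_avg) / torch.numel(exp_avg) ** 0.5).item()
# Same value as the previous spelling on an N-D tensor:
assert abs(scale - (torch.linalg.norm(exp_avg) / exp_avg.numel() ** 0.5).item()) < 1e-5
```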
2 changes: 1 addition & 1 deletion deepspeed/runtime/lr_schedules.py
@@ -508,7 +508,7 @@ def _initialize_lr(self, optimizer, cycle_min_lr, cycle_max_lr, decay_lr_rate, l
def _initialize_momentum(self, optimizer, cycle_min_mom, cycle_max_mom, decay_mom_rate, last_batch_iteration):
if 'betas' not in optimizer.defaults:
optimizer_name = type(optimizer).__name__
logger.warn(
logger.warning(
f"cycle_momentum is disabled because optimizer {optimizer_name} does not support momentum, no betas attribute in defaults"
)
self.cycle_momentum = False
15 changes: 3 additions & 12 deletions deepspeed/runtime/zero/stage3.py
@@ -546,15 +546,10 @@ def _setup_for_real_optimizer(self):
self.grad_partitions_flat_buffer = get_accelerator().pin_memory(self.grad_partitions_flat_buffer)

offset = 0
max_partition_numel = 0
for param in all_params:
self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow(
0, offset, param.partition_numel())
offset += param.partition_numel()
max_partition_numel = max(max_partition_numel, param.partition_numel())
if self.offload_optimizer:
self.pinned_grad_buffer: Tensor = get_accelerator().pin_memory(
torch.empty(max_partition_numel, device=self.device))

def _link_all_hp_params(self):
for p in self.module.parameters():
@@ -1510,13 +1505,9 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L
offload_fp32_gradients[i].append(grad_buffer.float())
offload_fp32_offsets[i].append(dest_offset)
else:
buffer_numel = grad_buffer.numel()
fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow(
0, dest_offset, buffer_numel)
self.pinned_grad_buffer[:buffer_numel].copy_(
grad_buffer.to(dtype=torch.float32, non_blocking=True))
get_accelerator().synchronize()
fp32_grad_tensor.copy_(self.pinned_grad_buffer[:buffer_numel], non_blocking=True)
0, dest_offset, grad_buffer.numel())
fp32_grad_tensor.copy_(grad_buffer.float())

# free the gradient
if not get_accelerator().is_synchronized_device():
@@ -2101,7 +2092,7 @@ def step(self, closure=None):
return

norm_groups = self._get_norm_groups()
scaled_global_grad_norm = torch.linalg.norm(torch.stack(norm_groups))
scaled_global_grad_norm = torch.linalg.vector_norm(torch.stack(norm_groups))

# Stash unscaled gradient norm
self._global_grad_norm = scaled_global_grad_norm / self.loss_scale
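The second hunk above drops the pinned staging buffer and copies each low-precision gradient partition straight into the matching slice of the flat FP32 gradient. A minimal sketch of that narrow-and-copy pattern (buffer names and sizes are illustrative, not the engine's real state):

```python
import torch

flat_fp32_grads = torch.zeros(1024)                   # stand-in for fp32_partitioned_groups_flat[i].grad
grad_buffer = torch.randn(256, dtype=torch.bfloat16)  # stand-in low-precision gradient partition
dest_offset = 128

# View the destination slice without allocating, then cast and copy in one step.
fp32_grad_tensor = flat_fp32_grads.narrow(0, dest_offset, grad_buffer.numel())
fp32_grad_tensor.copy_(grad_buffer.float())
```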
7 changes: 4 additions & 3 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -614,7 +614,7 @@ def _configure_moe_settings(self):
assert self.contiguous_gradients, "Contiguous Gradients in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"
# NOTE: To run ZeRO stage 1 with MoE, we need to set self.contiguous_gradients to True or ignore the assertion
if not self.partition_gradients and not self.contiguous_gradients:
logger.warn(
logger.warning(
"ZeRO Stage 1 has not been thoroughly tested with MoE. This configuration is still experimental.")
assert self.reduce_scatter, "Reduce Scatter in ZeRO Stage 2 must be set to True for MoE. Other code paths are not tested with MoE"

@@ -1691,7 +1691,8 @@ def get_grad_norm_direct(self, gradients, params, norm_type=2):
continue
if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0):
all_norms.append(
torch.norm(g.data.double().detach(), norm_type).to(get_accelerator().current_device_name()))
torch.linalg.vector_norm(g.data.double().detach(),
ord=norm_type).to(get_accelerator().current_device_name()))
if len(all_norms) > 0:
total_norm = torch.stack(all_norms).square().sum().float()
else:
@@ -1795,7 +1796,7 @@ def scaled_global_norm(self, norm_type=2):
self._average_expert_grad_norms(norm_groups)

# calculating L2 norm
return torch.norm(torch.stack(norm_groups), p=norm_type)
return torch.linalg.vector_norm(torch.stack(norm_groups), ord=norm_type)

def get_bit16_param_group(self, group_no):
bit16_partitions = self.parallel_partitioned_bit16_groups[group_no]
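Both call sites above follow the same reduction: take a per-tensor (or per-group) 2-norm with `torch.linalg.vector_norm`, then take the norm of the stacked results to get the global norm. A hedged sketch of why that equals the norm over all elements:

```python
import torch

grads = [torch.randn(10, 10), torch.randn(5)]  # toy gradient tensors
per_tensor = [torch.linalg.vector_norm(g.double(), ord=2) for g in grads]
global_norm = torch.linalg.vector_norm(torch.stack(per_tensor), ord=2)

# The norm of the per-tensor norms equals the norm of everything concatenated.
flat = torch.cat([g.double().flatten() for g in grads])
assert torch.allclose(global_norm, torch.linalg.vector_norm(flat))
```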
129 changes: 70 additions & 59 deletions deepspeed/sequence/layer.py
@@ -16,6 +16,71 @@
from deepspeed.utils import groups


def _generate_layout_params(scatter_idx, batch_dim_idx, seq_world_size, input):
"""
This function generates the parameters required for `permute` and `reshape` operations,
which are used to process data before and after `all2all` communication.
"""
if batch_dim_idx == 0:
if scatter_idx < 2:
bs, global_seq_len, num_local_head, head_dim = input.shape
pre_all2all_inp_shape = [bs, seq_world_size, global_seq_len // seq_world_size, num_local_head, head_dim]
pre_all2all_permute_idx = (1, 0, 2, 3, 4)

post_all2all_permute_idx = (1, 2, 0, 3, 4)
post_all2all_res_shape = [bs, global_seq_len // seq_world_size, seq_world_size * num_local_head, head_dim]
else:
bs, local_seq_len, num_total_head, head_dim = input.shape
assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
pre_all2all_inp_shape = [bs, local_seq_len, seq_world_size, num_total_head // seq_world_size, head_dim]
pre_all2all_permute_idx = (2, 0, 1, 3, 4)

post_all2all_permute_idx = (1, 0, 2, 3, 4)
post_all2all_res_shape = [bs, seq_world_size * local_seq_len, num_total_head // seq_world_size, head_dim]
else:
if scatter_idx < 2:
global_seq_len, bs, num_local_head, head_dim = input.shape
pre_all2all_inp_shape = [seq_world_size, global_seq_len // seq_world_size, bs, num_local_head, head_dim]
pre_all2all_permute_idx = None

post_all2all_permute_idx = (1, 2, 0, 3, 4)
post_all2all_res_shape = [bs, seq_world_size * global_seq_len, num_local_head // seq_world_size, head_dim]
else:
local_seq_len, bs, num_total_head, head_dim = input.shape
assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
pre_all2all_inp_shape = [local_seq_len, bs, seq_world_size, num_total_head // seq_world_size, head_dim]
pre_all2all_permute_idx = (2, 0, 1, 3, 4)
post_all2all_permute_idx = None
post_all2all_res_shape = [local_seq_len * seq_world_size, bs, num_total_head // seq_world_size, head_dim]

return pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape


def post_all2all(permute_idx, res_shape):
"""
Post-processing function for `all2all` communication.
"""

def post_func(input):
if permute_idx is not None:
input = input.permute(permute_idx).contiguous()
output = input.reshape(res_shape).contiguous()

return output

return post_func


def pre_all2all_fun(permute_idx, inp_shape, input):
"""
Pre-processing function for `all2all` communication.
"""
input_t = input.reshape(inp_shape).contiguous()
if permute_idx is not None:
input_t = input_t.permute(permute_idx).contiguous()
return input_t


def _rotate_half(x):
"""
change sign so the last dimension becomes [-odd, +even]
@@ -43,32 +108,6 @@ def apply_rotary_pos_emb(t, freqs_cos, freqs_sin):
return res


def post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, seq_len, num_head, head_dim):

def post_func(input):
if batch_dim_idx == 0:
# b, s, n, h
if scatter_idx < 2:
output = input.permute(1, 2, 0, 3, 4).contiguous()
output = output.reshape(bs, seq_len // seq_world_size, seq_world_size * num_head,
head_dim).contiguous()
else:
output = input.permute(1, 0, 2, 3, 4).contiguous()
output = output.reshape(bs, seq_world_size * seq_len, num_head // seq_world_size,
head_dim).contiguous()
else:
# s, b, n, h
if scatter_idx < 2:
output = input.permute(1, 2, 0, 3, 4).contiguous()
output = output.reshape(seq_len // seq_world_size, bs, seq_world_size * num_head,
head_dim).contiguous()
else:
output = input.reshape(seq_len * seq_world_size, bs, num_head // seq_world_size, head_dim).contiguous()
return output

return post_func


def uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group):
seq_world_size = dist.get_world_size(group)
inp_shape = list(input.shape)
@@ -195,39 +234,12 @@ def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, asyn
assert async_op == False, "uneven head sp does not support async op"
return uneven_heads_all2all(input, scatter_idx, gather_idx, batch_dim_idx, group)

if batch_dim_idx == 0:
# b, s, n, h
if scatter_idx < 2:
bs, global_seq_len, num_local_head, head_dim = input.shape
input_t = input.reshape([bs, seq_world_size, global_seq_len // seq_world_size, num_local_head,
head_dim]).contiguous()
input_t = input_t.permute(1, 0, 2, 3, 4).contiguous()
else:
bs, local_seq_len, num_total_head, head_dim = input.shape
assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
input_t = input.reshape([bs, local_seq_len, seq_world_size, num_total_head // seq_world_size,
head_dim]).contiguous()
input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()
else:
# s, b, n, h
if scatter_idx < 2:
global_seq_len, bs, num_local_head, head_dim = input.shape
input_t = input.reshape([seq_world_size, global_seq_len // seq_world_size, bs, num_local_head,
head_dim]).contiguous()
else:
local_seq_len, bs, num_total_head, head_dim = input.shape
assert num_total_head % seq_world_size == 0, f"Number of heads ({num_total_head}) must be divisible by the sequence parallel size ({seq_world_size})!"
input_t = input.reshape([local_seq_len, bs, seq_world_size, num_total_head // seq_world_size,
head_dim]).contiguous()
input_t = input_t.permute(2, 0, 1, 3, 4).contiguous()
pre_all2all_permute_idx, pre_all2all_inp_shape, post_all2all_permute_idx, post_all2all_res_shape = _generate_layout_params(
scatter_idx, batch_dim_idx, seq_world_size, input)

if scatter_idx < 2:
post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, global_seq_len, num_local_head,
head_dim)
else:
post_all2all_fun = post_all2all(scatter_idx, batch_dim_idx, seq_world_size, bs, local_seq_len, num_total_head,
head_dim)
input_t = pre_all2all_fun(pre_all2all_permute_idx, pre_all2all_inp_shape, input)

post_all2all_fun = post_all2all(post_all2all_permute_idx, post_all2all_res_shape)
output = torch.empty_like(input_t)
work = dist.all_to_all_single(output, input_t, group=group, async_op=async_op)

@@ -236,7 +248,7 @@ def single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, asyn
handle[type + '_work'] = work
handle[type + '_grad'] = output
handle[type + '_post_all2all_func'] = post_all2all_fun
return output
return output.view(post_all2all_res_shape)

res = post_all2all_fun(output)
return res
@@ -271,7 +283,6 @@ def forward(ctx: Any,
assert ctx.stream != None
res = single_all_to_all(input, scatter_idx, gather_idx, batch_dim_idx, group, False)
get_accelerator().current_stream().wait_stream(ctx.stream)
del ctx.stream.activation_buffer_list
# The computation of d o_weight can overlap with the communication of d o_input

elif not is_fwd and type in ('q', 'k'):
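The refactor above splits `single_all_to_all` into three helpers: `_generate_layout_params` picks the reshape/permute layout for the given `scatter_idx`/`batch_dim_idx`, `pre_all2all_fun` applies it before the collective, and `post_all2all` builds the callback that restores the layout afterwards. A rough sketch of how the pieces compose on the synchronous path (assumes an initialized sequence-parallel process group; the wrapper function is illustrative, not part of DeepSpeed):

```python
import torch
import torch.distributed as dist

from deepspeed.sequence.layer import _generate_layout_params, post_all2all, pre_all2all_fun


def head_scatter_all2all(qkv, group, batch_dim_idx=0, scatter_idx=2):
    """Illustrative wrapper mirroring the synchronous path of single_all_to_all."""
    seq_world_size = dist.get_world_size(group)

    # 1) Decide how to reshape/permute for this layout (b,s,n,h vs s,b,n,h).
    pre_idx, pre_shape, post_idx, post_shape = _generate_layout_params(
        scatter_idx, batch_dim_idx, seq_world_size, qkv)

    # 2) Pack the tensor so each rank exchanges one contiguous chunk.
    input_t = pre_all2all_fun(pre_idx, pre_shape, qkv)

    # 3) Exchange and restore the expected output layout.
    output = torch.empty_like(input_t)
    dist.all_to_all_single(output, input_t, group=group)
    return post_all2all(post_idx, post_shape)(output)
```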
1 change: 0 additions & 1 deletion docker/Dockerfile
@@ -132,7 +132,6 @@ RUN pip install psutil \
sentencepiece \
msgpack \
requests \
pandas \
sphinx \
sphinx_rtd_theme \
scipy \
2 changes: 1 addition & 1 deletion tests/unit/alexnet_model.py
@@ -84,7 +84,7 @@ def cast_to_half(x):

def cifar_trainset(fp16=False):
torchvision = pytest.importorskip("torchvision", minversion="0.5.0")
import torchvision.transforms as transforms
from torchvision import transforms

transform_list = [
transforms.ToTensor(),
2 changes: 1 addition & 1 deletion version.txt
@@ -1 +1 @@
0.16.3
0.16.4
