Support syncing params when train_tp divides inference_tp. (#92)
charles9304 authored Sep 26, 2024
1 parent 0557957 commit 87961e2
Showing 12 changed files with 499 additions and 38 deletions.
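For orientation, the sketch below (illustrative variable names only, not ChatLearn APIs) shows the tensor-parallel relationship this commit supports: when the training TP size divides the inference TP size, each trainer rank feeds inference_tp // train_tp inference ranks, and every synced parameter is cut into that many per-destination slices.

```python
# Minimal sketch, assuming illustrative names (not ChatLearn APIs), of the
# mapping enabled by this commit: train_tp divides inference_tp.
train_tp = 2        # tensor parallel size of the training model
inference_tp = 8    # tensor parallel size of the inference (vLLM) model
assert inference_tp % train_tp == 0, "train_tp must divide inference_tp"

# Each trainer rank presumably serves this many inference ranks; compare the
# set_num_mapping() / set_tp_division() helpers added to base_module.py below.
num_mapping = inference_tp // train_tp
print(num_mapping)  # 4
```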
158 changes: 154 additions & 4 deletions chatlearn/models/base_module.py
@@ -33,6 +33,7 @@
from chatlearn.checkpoint.checkpoint_manager import CheckpointManager
from chatlearn.utils import future
from chatlearn.utils.dist_utils import bucket_tensors, coalesced_comm_dense
from chatlearn.utils.dist_utils import bucket_tensors_two_stage, coalesced_comm_dense_two_stage
from chatlearn.utils.global_vars import get_args
from chatlearn.utils.global_vars import set_global_variables
from chatlearn.utils.logger import log_rank_0, debug_rank_0, setup_logger
@@ -106,6 +107,9 @@ def __init__(self, name, args=None, replica_id=0):
self._data_ckpt_manager = None
self._peak_memory = 0
self._parameters_to_sync = defaultdict(list)
self._parameters_to_send = defaultdict(list)
self._parameters_to_recv = defaultdict(list)
self._parameters_shape = []
self._concat_params_dict = None
self._to_fix_act_ordering_dict = None
self._to_fix_qkv_ordering_dict = None
@@ -124,6 +128,32 @@ def __init__(self, name, args=None, replica_id=0):
# parameter sync from src_model
self._src_parameter_model = None
self.profiler = None
self._buffer_num = {}
self._tp_division = {}
self._num_mapping = 1
self._sync_buffer = defaultdict(list)

def get_sync_buffer(self):
return self._sync_buffer

def set_num_mapping(self, _num_mapping):
self._num_mapping = _num_mapping

@property
def num_mapping(self):
return self._num_mapping

def set_buffer_num(self, buffer_num):
self._buffer_num.update(buffer_num)

def get_buffer_num(self, param_names):
return [self._buffer_num[name] for name in param_names]

def set_tp_division(self, tp_division):
self._tp_division.update(tp_division)

def get_tp_division(self, param_names):
return [self._tp_division[name] for name in param_names]

@property
def is_colocate(self):
@@ -621,11 +651,13 @@ def get_to_fix_qkv_ordering_func(self):
def set_to_fix_qkv_ordering_func(self, _to_fix_qkv_ordering_func):
self._to_fix_qkv_ordering_func = _to_fix_qkv_ordering_func

def set_sync_parameters(self, trainable_param_names, pipe_stage=0):
def set_sync_parameters(self, trainable_param_names, pipe_stage=0, parameters_to_sync=None):
"""
:meta private:
"""
if pipe_stage not in self._parameters_to_sync or len(self._parameters_to_sync[pipe_stage]) == 0: # pylint: disable=too-many-nested-blocks
if parameters_to_sync is None:
parameters_to_sync = self._parameters_to_sync
if pipe_stage not in parameters_to_sync or len(parameters_to_sync[pipe_stage]) == 0: # pylint: disable=too-many-nested-blocks
concat = []
set_sync_param_flag = False

@@ -719,8 +751,21 @@ def set_sync_parameters(self, trainable_param_names, pipe_stage=0):
_params_to_sync = _params_to_sync.contiguous()
concat = []
set_sync_param_flag = False
self._parameters_to_sync[pipe_stage].append(_params_to_sync)
parameters_to_sync[pipe_stage].append((name, _params_to_sync))

def set_send_parameters(self, trainable_param_names, pipe_stage=0):
"""
:meta private:
"""
return self.set_sync_parameters(trainable_param_names, pipe_stage, self._parameters_to_send)

def set_recv_parameters(self, rank, trainable_param_names, pipe_stage=0):
"""
:meta private:
"""
parameters_to_recv = defaultdict(list)
self._parameters_to_recv[rank] = parameters_to_recv
return self.set_sync_parameters(trainable_param_names, pipe_stage, parameters_to_recv)

def get_parameter_names(self, requires_grad=True):
"""
@@ -732,6 +777,15 @@ def get_parameter_names(self, requires_grad=True):
else:
return [param_to_name[param] for param in self.parameters]

def get_parameter_shape(self, param_names):
"""
:meta private:
"""
parameters_shape = []
for name in param_names:
parameters_shape.append((name, self.named_parameters[name].shape))
return parameters_shape

def get_parameter(self, name):
"""
:meta private:
@@ -774,7 +828,7 @@ def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0):
"""
:meta private:
"""
tensors = [param.data for param in self._parameters_to_sync[pipe_stage]]
tensors = [param.data for _, param in self._parameters_to_sync[pipe_stage]]
assert len(tensors) > 0
dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb)
debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger)
@@ -786,6 +840,102 @@ def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0):
for param in sparse_bucket:
col.broadcast(param, src_rank, group_name)

def broadcast_parameter_two_stage(self, to_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False):
"""
:meta private:
"""
tensor_changed = rank != src_rank

if stage2:
if tensor_changed:
parameters_to_sync = self._parameters_to_recv[rank]
else:
parameters_to_sync = self._parameters_to_send
else:
del self._sync_buffer
self._sync_buffer = defaultdict(list)
parameters_to_sync = self._parameters_to_sync

tensors = []
buffer_num = []
if stage2 and not tensor_changed and self._sync_buffer:
idx = 0
for name, param in parameters_to_sync[pipe_stage]:
tensors.append(self._sync_buffer[(to_rank + 1) % self.num_mapping][idx])
buffer_num.append(1)
idx += 1
del self._sync_buffer[(to_rank + 1) % self.num_mapping]
else:
for name, param in parameters_to_sync[pipe_stage]:
param_data = param.data
param_data_shape = param_data.shape
if rank and self._buffer_num and not stage2:
assert name in self._buffer_num, f"{name} not in self._buffer_num for rank {rank}"
buffer_num.append(self._buffer_num[name])
elif stage2:
buffer_num.append(1)
else:
if "attention.query_key_value" in name or "self_attention.query_key_value" in name:
tp_size = self.module_args.args_dict["tensor_model_parallel_size"]
heads = self.module_args.args_dict["num_attention_heads"] // tp_size
hidden_size_per_head = self.module_args.args_dict["hidden_size"] // self.module_args.args_dict["num_attention_heads"]
param_shape = (3, heads, hidden_size_per_head) + param_data_shape[1:]
param_data = param_data.view(param_shape)
param_data_list = []
head_offset = heads // self._tp_division[name]
for idx in range(self._tp_division[name]):
start = idx * head_offset
end = start + head_offset
param_data_list.append(param_data[:,start:end])
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list

if "self_attention.dense" in name or "mlp.dense_4h_to_h" in name:
param_data_list = []
col_offset = param_data_shape[1] // self._tp_division[name]
for idx in range(self._tp_division[name]):
start = idx * col_offset
end = start + col_offset
param_data_list.append(param_data[:,start:end])
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list
if "mlp.dense_h_to_4h" in name:
param_data_list = []
row_offset = param_data_shape[0] // self._tp_division[name] // 2
for idx in range(self._tp_division[name]):
w1_start = idx * row_offset
w1_end = w1_start + row_offset
w2_start = (idx + self._tp_division[name]) * row_offset
w2_end = w2_start + row_offset
param_data_list.append(
torch.concat([param_data[w1_start:w1_end,:], param_data[w2_start:w2_end,:]], dim=0))
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list
buffer_num.append(1)
tensors.append(param_data)

assert len(tensors) > 0
dense_buckets, sparse_bucket = bucket_tensors_two_stage(
tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb,
buffer_num=None if stage2 else buffer_num, tensor_changed=tensor_changed and not stage2)
debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger)

for bucket in dense_buckets:
index = 0 if stage2 else (to_rank % self.num_mapping)
all_buffers = coalesced_comm_dense_two_stage(
bucket, col.broadcast, rank,
extra_args=(src_rank, group_name), tensor_changed=tensor_changed,
stage2=stage2, index=index)
if tensor_changed and not stage2:
for key, value in all_buffers.items():
self._sync_buffer[key] += value

for param in sparse_bucket:
col.broadcast(param, src_rank, group_name)

self.empty_cache()
return self._sync_buffer


def send_parameter(self, name, dst_rank, group_name, pipe_stage=0):
"""
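As a reading aid, here is a standalone toy version (an assumption-laden sketch, not code from the commit) of the query_key_value branch in broadcast_parameter_two_stage above: the trainer rank's fused QKV weight is laid out as [Q for all local heads; K; V], and it is regrouped so each of the tp_division destination ranks receives a contiguous [Q_i; K_i; V_i] block covering only its own heads.

```python
import torch

# Toy sketch of the "query_key_value" re-slicing in broadcast_parameter_two_stage
# (standalone illustration; sizes and values are made up).
heads, head_dim, in_dim, tp_division = 4, 1, 1, 2

# Rows 0-3 hold Q per head, rows 4-7 hold K per head, rows 8-11 hold V per head.
qkv = torch.arange(3 * heads * head_dim).float().reshape(3 * heads * head_dim, in_dim)
orig_shape = qkv.shape

viewed = qkv.view(3, heads, head_dim, in_dim)
head_offset = heads // tp_division
slices = [viewed[:, i * head_offset:(i + 1) * head_offset] for i in range(tp_division)]
regrouped = torch.cat(slices, dim=0).view(orig_shape)

# Destination rank i now reads rows [i*6 : (i+1)*6]: the Q, K, V of its own heads.
print(regrouped.squeeze(-1))
# tensor([ 0.,  1.,  4.,  5.,  8.,  9.,  2.,  3.,  6.,  7., 10., 11.])
```

The self_attention.dense / mlp.dense_4h_to_h and mlp.dense_h_to_4h branches in the diff follow the same idea, regrouping along columns and along fused gate/up row pairs respectively.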
11 changes: 6 additions & 5 deletions chatlearn/models/vllm_module.py
@@ -326,9 +326,10 @@ def empty_cache(self):
self.worker.cache_engine.cpu_cache = None
self.worker.cache_engine.gpu_cache = None
elif CURRENT_VLLM_VERSION == VLLMVersion.v_0_5_1.value:
for ele in self.worker.gpu_cache: # pylint: disable=unused-variable
ele = None
self.worker.gpu_cache = None # pylint: disable=access-member-before-definition
if self.worker.gpu_cache is not None:
for ele in self.worker.gpu_cache: # pylint: disable=unused-variable
ele = None
self.worker.gpu_cache = None # pylint: disable=access-member-before-definition

for c_e in self.worker.cache_engine:
c_e.cpu_cache = None
@@ -573,14 +574,14 @@ def data_parallel_size(self):
"""
:meta private:
"""
return None
return 1

@property
def data_parallel_rank(self):
"""
:meta private:
"""
return None
return 0

def tensor_parallel_rank(self):
"""
(The remaining 10 changed files are not shown.)

0 comments on commit 87961e2
