Support syncing params when train_tp divides inference_tp #92

Merged
merged 13 commits on Sep 26, 2024
164 changes: 160 additions & 4 deletions chatlearn/models/base_module.py
@@ -33,6 +33,7 @@
from chatlearn.checkpoint.checkpoint_manager import CheckpointManager
from chatlearn.utils import future
from chatlearn.utils.dist_utils import bucket_tensors, coalesced_comm_dense
from chatlearn.utils.dist_utils import bucket_tensors_two_stage, coalesced_comm_dense_two_stage
from chatlearn.utils.global_vars import get_args
from chatlearn.utils.global_vars import set_global_variables
from chatlearn.utils.logger import log_rank_0, debug_rank_0, setup_logger
@@ -106,6 +107,9 @@ def __init__(self, name, args=None, replica_id=0):
self._data_ckpt_manager = None
self._peak_memory = 0
self._parameters_to_sync = defaultdict(list)
self._parameters_to_send = defaultdict(list)
self._parameters_to_recv = defaultdict(list)
self._parameters_shape = []
self._concat_params_dict = None
self._to_fix_act_ordering_dict = None
self._to_fix_qkv_ordering_dict = None
@@ -124,6 +128,32 @@ def __init__(self, name, args=None, replica_id=0):
# parameter sync from src_model
self._src_parameter_model = None
self.profiler = None
self._buffer_num = {}
self._tp_division = {}
self._num_mapping = 1
self._sync_buffer = defaultdict(list)

def get_sync_buffer(self):
return self._sync_buffer

def set_num_mapping(self, _num_mapping):
self._num_mapping = _num_mapping

@property
def num_mapping(self):
return self._num_mapping

def set_buffer_num(self, buffer_num):
self._buffer_num.update(buffer_num)

def get_buffer_num(self, param_names):
return [self._buffer_num[name] for name in param_names]

def set_tp_division(self, tp_division):
self._tp_division.update(tp_division)

def get_tp_division(self, param_names):
return [self._tp_division[name] for name in param_names]
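
These setters are presumably populated by the parameter-sync coordinator before a sync round. A minimal sketch of that configuration, with illustrative names and values rather than the real call sites:

# Illustrative only: configure a BaseModule for syncing when train_tp=4 and
# inference_tp=8, so each training shard presumably maps to 2 inference shards.
param_names = ["self_attention.query_key_value.weight", "mlp.dense_h_to_4h.weight"]
module.set_num_mapping(2)                                  # presumably inference_tp // train_tp
module.set_buffer_num({name: 1 for name in param_names})   # destination buffers per parameter
module.set_tp_division({name: 2 for name in param_names})  # slices each training shard is cut into
assert module.get_tp_division(param_names) == [2, 2]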

@property
def is_colocate(self):
@@ -621,11 +651,13 @@ def get_to_fix_qkv_ordering_func(self):
def set_to_fix_qkv_ordering_func(self, _to_fix_qkv_ordering_func):
self._to_fix_qkv_ordering_func = _to_fix_qkv_ordering_func

def set_sync_parameters(self, trainable_param_names, pipe_stage=0):
def set_sync_parameters(self, trainable_param_names, pipe_stage=0, parameters_to_sync=None):
"""
:meta private:
"""
if pipe_stage not in self._parameters_to_sync or len(self._parameters_to_sync[pipe_stage]) == 0: # pylint: disable=too-many-nested-blocks
if parameters_to_sync is None:
parameters_to_sync = self._parameters_to_sync
if pipe_stage not in parameters_to_sync or len(parameters_to_sync[pipe_stage]) == 0: # pylint: disable=too-many-nested-blocks
concat = []
set_sync_param_flag = False

@@ -719,8 +751,21 @@ def set_sync_parameters(self, trainable_param_names, pipe_stage=0):
_params_to_sync = _params_to_sync.contiguous()
concat = []
set_sync_param_flag = False
self._parameters_to_sync[pipe_stage].append(_params_to_sync)
parameters_to_sync[pipe_stage].append((name, _params_to_sync))

def set_send_parameters(self, trainable_param_names, pipe_stage=0):
"""
:meta private:
"""
return self.set_sync_parameters(trainable_param_names, pipe_stage, self._parameters_to_send)

def set_recv_parameters(self, rank, trainable_param_names, pipe_stage=0):
"""
:meta private:
"""
parameters_to_recv = defaultdict(list)
self._parameters_to_recv[rank] = parameters_to_recv
return self.set_sync_parameters(trainable_param_names, pipe_stage, parameters_to_recv)
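
A hedged sketch of how the send/recv registration might be driven: one shared send list, plus a separate recv list per destination rank (the loop and rank values are hypothetical):

# Hypothetical driver code, not the actual sync-group implementation.
names = module.get_parameter_names(requires_grad=True)
module.set_send_parameters(names, pipe_stage=0)
for dst_rank in (0, 1):   # one destination per num_mapping slot, illustrative
    module.set_recv_parameters(dst_rank, names, pipe_stage=0)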

def get_parameter_names(self, requires_grad=True):
"""
@@ -732,6 +777,15 @@
else:
return [param_to_name[param] for param in self.parameters]

def get_parameter_shape(self, param_names):
"""
:meta private:
"""
parameters_shape = []
for name in param_names:
parameters_shape.append((name, self.named_parameters[name].shape))
return parameters_shape
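
For instance, the receiving side could use the (name, shape) pairs to pre-allocate buffers; the parameter name and shape here are made up:

# Illustrative use of get_parameter_shape for buffer pre-allocation.
shapes = module.get_parameter_shape(["mlp.dense_h_to_4h.weight"])
recv_buffers = {name: torch.empty(shape) for name, shape in shapes}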

def get_parameter(self, name):
"""
:meta private:
@@ -774,7 +828,7 @@ def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0):
"""
:meta private:
"""
tensors = [param.data for param in self._parameters_to_sync[pipe_stage]]
tensors = [param.data for _, param in self._parameters_to_sync[pipe_stage]]
assert len(tensors) > 0
dense_buckets, sparse_bucket = bucket_tensors(tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb)
debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger)
@@ -786,6 +840,108 @@ def broadcast_parameter(self, rank, src_rank, group_name, pipe_stage=0):
for param in sparse_bucket:
col.broadcast(param, src_rank, group_name)

def broadcast_parameter_two_stage(self, to_rank, rank, src_rank, group_name, pipe_stage=0, stage2=False):
"""
:meta private:
"""
tensor_changed = rank != src_rank

if stage2:
if tensor_changed:
parameters_to_sync = self._parameters_to_recv[rank]
else:
parameters_to_sync = self._parameters_to_send
else:
del self._sync_buffer
self._sync_buffer = defaultdict(list)
parameters_to_sync = self._parameters_to_sync

tensors = []
buffer_num = []
if stage2 and not tensor_changed and self._sync_buffer:
idx = 0
for name, param in parameters_to_sync[pipe_stage]:
assert self.num_mapping == 2
tensors.append(self._sync_buffer[(to_rank + 1) % self.num_mapping][idx])
buffer_num.append(1)
idx += 1
del self._sync_buffer[(to_rank + 1) % self.num_mapping]
else:
for name, param in parameters_to_sync[pipe_stage]:
param_data = param.data
param_data_shape = param_data.shape
if rank and self._buffer_num and not stage2:
assert name in self._buffer_num, f"{name} not in self._buffer_num for rank {rank}"
buffer_num.append(self._buffer_num[name])
elif stage2:
buffer_num.append(1)
else:
if "attention.query_key_value" in name or "self_attention.query_key_value" in name:
tp_size = self.module_args.args_dict["tensor_model_parallel_size"]
heads = self.module_args.args_dict["num_attention_heads"] // tp_size
hidden_size_per_head = self.module_args.args_dict["hidden_size"] // self.module_args.args_dict["num_attention_heads"]
param_shape = (3, heads, hidden_size_per_head) + param_data_shape[1:]
param_data = param_data.view(param_shape)
param_data_list = []
head_offset = heads // self._tp_division[name]
for idx in range(self._tp_division[name]):
start = idx * head_offset
end = start + head_offset
param_data_list.append(param_data[:,start:end])
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list

if "self_attention.dense" in name or "mlp.dense_4h_to_h" in name:
param_data_list = []
col_offset = param_data_shape[1] // self._tp_division[name]
for idx in range(self._tp_division[name]):
start = idx * col_offset
end = start + col_offset
param_data_list.append(param_data[:,start:end])
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list
if "mlp.dense_h_to_4h" in name:
param_data_list = []
row_offset = param_data_shape[0] // self._tp_division[name] // 2
for idx in range(self._tp_division[name]):
w1_start = idx * row_offset
w1_end = w1_start + row_offset
w2_start = (idx + self._tp_division[name]) * row_offset
w2_end = w2_start + row_offset
param_data_list.append(
torch.concat([param_data[w1_start:w1_end,:], param_data[w2_start:w2_end,:]], dim=0))
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list
buffer_num.append(1)
tensors.append(param_data)

assert len(tensors) > 0
dense_buckets, sparse_bucket = bucket_tensors_two_stage(
tensors, bucket_size_mb=self.runtime_args.coalesced_buffer_mb,
buffer_num=None if stage2 else buffer_num, tensor_changed=tensor_changed and not stage2)
debug_rank_0(f"{self.name} Got dense_buckets {len(dense_buckets)}, spase_bucket {len(sparse_bucket)}", self._logger)

for bucket in dense_buckets:
index = 0 if stage2 else (to_rank % self.num_mapping)
# if not stage2:
# raise RuntimeError(f"self.num_mapping: {self.num_mapping}")
all_buffers = coalesced_comm_dense_two_stage(
bucket, col.broadcast, rank,
extra_args=(src_rank, group_name), tensor_changed=tensor_changed,
stage2=stage2, index=index)
if tensor_changed and not stage2:
for key, value in all_buffers.items():
self._sync_buffer[key] += value

for param in sparse_bucket:
col.broadcast(param, src_rank, group_name)

try:
self.empty_cache()
except Exception as e:
return {"error_message": e}
return self._sync_buffer
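
The per-name slicing above regroups each fused training shard so that contiguous chunk i of the broadcast buffer is exactly what inference rank i needs. A standalone sketch of the same regrouping on dummy tensors, assuming tp_division = 2; the helper names are illustrative, not ChatLearn APIs:

import torch

def regroup_qkv(weight, num_heads, head_dim, tp_division):
    # weight: fused [3 * num_heads * head_dim, hidden] query_key_value shard.
    # View as (3, heads, head_dim, hidden) and gather each destination's head
    # range so chunk i of the flattened result belongs to inference rank i.
    out_shape = weight.shape
    w = weight.view(3, num_heads, head_dim, *weight.shape[1:])
    per_dst = num_heads // tp_division
    chunks = [w[:, i * per_dst:(i + 1) * per_dst] for i in range(tp_division)]
    return torch.cat(chunks, dim=0).view(out_shape)

def regroup_dense_h_to_4h(weight, tp_division):
    # weight: fused gated-MLP shard laid out as [w1; w2] along dim 0. Interleave
    # matching w1/w2 slices so each destination receives its (w1_i, w2_i) pair.
    out_shape = weight.shape
    rows = weight.shape[0] // tp_division // 2
    parts = []
    for i in range(tp_division):
        w1 = weight[i * rows:(i + 1) * rows]
        w2 = weight[(i + tp_division) * rows:(i + tp_division + 1) * rows]
        parts.append(torch.cat([w1, w2], dim=0))
    return torch.cat(parts, dim=0).view(out_shape)

The self_attention.dense / mlp.dense_4h_to_h branch is the column-wise analogue of the row-wise regrouping sketched here.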


def send_parameter(self, name, dst_rank, group_name, pipe_stage=0):
"""
4 changes: 2 additions & 2 deletions chatlearn/models/vllm_module.py
@@ -573,14 +573,14 @@ def data_parallel_size(self):
"""
:meta private:
"""
return None
return 1

@property
def data_parallel_rank(self):
"""
:meta private:
"""
return None
return 0
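
Returning concrete values instead of None keeps downstream arithmetic valid; a hypothetical consumer:

# Hypothetical consumer: with None these expressions would raise a TypeError.
global_batch_size = per_replica_batch_size * module.data_parallel_size   # size is 1
sample_offset = module.data_parallel_rank * per_replica_batch_size       # rank is 0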

def tensor_parallel_rank(self):
"""