Skip to content

Commit 47a9b76

Browse files
committed
Merge branch 'main' into release/3.2
2 parents a18500b + 9508674 commit 47a9b76

File tree

15 files changed

+148
-58
lines changed

15 files changed

+148
-58
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ Running Environment:
125125
| peft | >=0.11,<0.16 | ||
126126
| trl | >=0.13,<0.17 | 0.16 |RLHF|
127127
| deepspeed | >=0.14 | 0.14.5 | Training |
128-
| vllm | >=0.5.1 | 0.7.3 | Inference/Deployment/Evaluation |
128+
| vllm | >=0.5.1,<0.8 | 0.7.3 | Inference/Deployment/Evaluation |
129129
| lmdeploy | >=0.5 | 0.7.2.post1 | Inference/Deployment/Evaluation |
130130
| evalscope | >=0.11 | | Evaluation |
131131

README_CN.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ pip install -e .
120120
| peft | >=0.11,<0.16 | ||
121121
| trl | >=0.13,<0.17 | 0.16 |RLHF|
122122
| deepspeed | >=0.14 | 0.14.5 |训练|
123-
| vllm | >=0.5.1 | 0.7.3 |推理/部署/评测|
123+
| vllm | >=0.5.1,<0.8 | 0.7.3 |推理/部署/评测|
124124
| lmdeploy | >=0.5 | 0.7.2.post1 |推理/部署/评测|
125125
| evalscope | >=0.11 | |评测|
126126

docs/source/BestPractices/Embedding训练.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ loss的源代码可以在[这里](https://github.com/modelscope/ms-swift/blob/ma
6363
{"query": "<image>sentence1", "response": "sentence2", "images": "/some/images.jpg", "label": 0}
6464
```
6565

66+
评测的指标分别是两个embedding的欧式距离、点积等的pearson系数以及spearman系数,共八个指标。
67+
6668
### infonce 格式
6769

6870
```json lines
@@ -82,6 +84,11 @@ infonce loss支持几个环境变量:
8284
> 也可以在数据集中将hard negatives数量设置为数量相等,这样即使不设置也不会使用for循环方式,加快计算速度
8385
> rejected_response也可以没有,这种情况下INFONCE_USE_BATCH保持为True,会使用一个batch内部的其他samples作为rejected responses
8486
87+
infonce loss的评测会有下面几个指标:
88+
- mean_neg 所有hard_negative的平均值
89+
- mean_pos 所有positive的平均值
90+
- margin positive-max_hard_negative的平均值
91+
8592
## 脚手架
8693

8794
SWIFT提供了两个脚手架训练脚本:

docs/source/GetStarted/SWIFT安装.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ pip install ms-swift==2.*
6565
| peft | >=0.11,<0.16 | ||
6666
| trl | >=0.13,<0.17 | 0.16 |RLHF|
6767
| deepspeed | >=0.14 | 0.14.5 |训练|
68-
| vllm | >=0.5.1 | 0.7.3 |推理/部署/评测|
68+
| vllm | >=0.5.1,<0.8 | 0.7.3 |推理/部署/评测|
6969
| lmdeploy | >=0.5 | 0.7.2.post1 |推理/部署/评测|
7070
| evalscope | >=0.11 | |评测|
7171

docs/source/Instruction/GRPO.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ A conversation between User and Assistant. The user asks a question, and the Ass
116116
- offload_optimizer: 是否在vLLM/LMDeploy推理时offload optimizer参数,默认为False
117117
- offload_model: 是否在vLLM/LMDeploy推理时offload 模型本身,默认为False
118118
- gc_collect_after_offload: 是否在offload结束时进行gc(python gc和GPU gc),默认为False
119-
- mini_batch_size:用于将每个设备上的批次大小(per_device_batch)进一步切分为更小的子批次。为确保切分有效,per_device_batch 需要能够被 mini_batch_size 整除。
120-
119+
- multi_turn_func: 多轮GRPO参数, 传入对应的plugin名称, 同时在plugin/multi_turn.py中添加好对应的实现
120+
- mini_batch_size:用于将每个设备上的批次大小(per_device_batch)进一步切分为更小的子批次。为确保切分有效,per_device_batch 需要能够被 mini_batch_size 整除
121121
122122
奖励函数超参,见[内置奖励函数](#内置奖励函数)
123123

docs/source/Instruction/命令行参数.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,8 @@ reward模型参数将在PPO、GRPO中使用。
411411
- offload_optimizer: 是否在vLLM/LMDeploy推理时offload optimizer参数,默认为False
412412
- offload_model: 是否在vLLM/LMDeploy推理时offload 模型本身,默认为False
413413
- gc_collect_after_offload: 是否在offload结束时进行gc(python gc和GPU gc),默认为False
414-
- mini_batch_size:用于将每个设备上的批次大小(per_device_batch)进一步切分为更小的子批次。为确保切分有效,per_device_train_batch_size 需要能够被 mini_batch_size 整除。
414+
- multi_turn_func: 多轮GRPO参数, 传入对应的plugin名称, 同时在plugin/multi_turn.py中添加好对应的实现
415+
- mini_batch_size:用于将每个设备上的批次大小(per_device_batch)进一步切分为更小的子批次。为确保切分有效,per_device_train_batch_size 需要能够被 mini_batch_size 整除
415416

416417
cosine 奖励参数
417418
- cosine_min_len_value_wrong:cosine 奖励函数参数,生成错误答案时,最小长度对应的奖励值。默认值为0.0

docs/source_en/BestPractices/Embedding.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ The source code for the loss functions can be found [here](https://github.com/mo
5252
{"query": "sentence1", "response": "<image>sentence2", "images": ["/some/images1.jpg"], "label": 0.7}
5353
```
5454

55+
The eval metrics are the Pearson and Spearman's Rank Correlation Coefficient of the embeddings' euclidean distance/dot production and so on, totally 8 values.
56+
5557
### Format for Contrastive/Online Contrastive Loss
5658

5759
```json lines
@@ -82,6 +84,10 @@ InfoNCE loss supports the following environment variables:
8284
>
8385
> `rejected_response` can also be omitted. In this case, `INFONCE_USE_BATCH` remains `True` and will use other samples within the batch as rejected responses.
8486
87+
The evaluation of InfoNCE loss includes the following metrics:
88+
- mean_neg: The average of all hard negatives
89+
- mean_pos: The average of all positives
90+
- margin: The average of (positive - max hard negative)
8591

8692
## Scaffolding
8793

docs/source_en/GetStarted/SWIFT-installation.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ You can view the image [here](https://modelscope.cn/docs/intro/environment-setup
6666
| peft | >=0.11,<0.16 | | |
6767
| trl | >=0.13,<0.17 | 0.16 | RLHF |
6868
| deepspeed | >=0.14 | 0.14.5 | Training |
69-
| vllm | >=0.5.1 | 0.7.3 | Inference/Deployment/Evaluation |
69+
| vllm | >=0.5.1,<0.8 | 0.7.3 | Inference/Deployment/Evaluation |
7070
| lmdeploy | >=0.5 | 0.7.2.post1 | Inference/Deployment/Evaluation |
7171
| evalscope | >=0.11 | | Evaluation |
7272

docs/source_en/Instruction/Command-line-parameters.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,9 @@ The meanings of the following parameters can be referenced [here](https://huggin
422422
- offload_optimizer: Whether to offload optimizer parameters during inference with vLLM/LMDeploy. The default is `False`.
423423
- offload_model: Whether to offload the model itself during inference with vLLM/LMDeploy. The default is `False`.
424424
- gc_collect_after_offload: Whether to perform garbage collection (both Python GC and GPU GC) after offloading. The default is `False`.
425-
- mini_batch_size: Used to further split the batch size on each device (per_device_batch) into smaller sub-batches. To ensure the split is valid, per_device_train_batch_size needs be divisible by mini_batch_size.
425+
- multi_turn_func: The multi turn GRPO plugin name. Add your multi-turn implementation in plugin/multi_turn.py
426+
- mini_batch_size: Used to further split the batch size on each device (per_device_batch) into smaller sub-batches. To ensure the split is valid, per_device_train_batch_size needs be divisible by mini_batch_size
427+
426428
cosine reward function arguments
427429
- `cosine_min_len_value_wrong` (default: 0.0): Reward value corresponding to the minimum length when the answer is incorrect. Default is 0.0
428430
- `cosine_max_len_value_wrong` (default: -0.5): Reward value corresponding to the maximum length when the answer is incorrect. Default is -0.5

docs/source_en/Instruction/GRPO.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ Hyperparameters
118118
- offload_optimizer: Whether to offload optimizer parameters during inference with vLLM/LMDeploy. The default is `False`.
119119
- offload_model: Whether to offload the model itself during inference with vLLM/LMDeploy. The default is `False`.
120120
- gc_collect_after_offload: Whether to perform garbage collection (both Python GC and GPU GC) after offloading. The default is `False`.
121-
- mini_batch_size: Used to further split the batch size on each device (per_device_batch) into smaller sub-batches. To ensure the split is valid, per_device_train_batch_size needs be divisible by mini_batch_size.
122-
121+
- multi_turn_func: The multi turn GRPO plugin name. Add your multi-turn implementation in plugin/multi_turn.py
122+
- mini_batch_size: Used to further split the batch size on each device (per_device_batch) into smaller sub-batches. To ensure the split is valid, per_device_train_batch_size needs be divisible by mini_batch_size
123123

124124
The hyperparameters for the reward function can be found in the [Built-in Reward Functions section](#built-in-reward-functions).
125125

requirements/install_all.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# please use python=3.10, cuda12.*
22
# sh requirements/install_all.sh
3-
pip install "vllm>=0.5.1" -U
3+
pip install "vllm>=0.5.1,<0.8" -U
44
pip install "lmdeploy>=0.5" -U --no-deps
55
pip install autoawq -U --no-deps
66
pip install auto_gptq optimum bitsandbytes -U

swift/llm/infer/infer_engine/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -470,12 +470,12 @@ def new_group_context():
470470

471471
@contextmanager
472472
def set_device_context(device: Union[str, int]):
473-
original_device = torch.cuda.current_device()
474-
torch.cuda.set_device(device)
473+
origin_device = get_current_device()
474+
set_device(device)
475475
try:
476476
yield
477477
finally:
478-
torch.cuda.set_device(original_device)
478+
set_device(origin_device)
479479

480480

481481
@contextmanager

swift/plugin/loss.py

Lines changed: 112 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,10 @@ def loss_scale_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -
8383

8484

8585
def _parse_pair_sentence(outputs):
86-
last_hidden_state = outputs['last_hidden_state']
86+
if isinstance(outputs, dict):
87+
last_hidden_state = outputs['last_hidden_state']
88+
else:
89+
last_hidden_state = outputs
8790
batch_size = last_hidden_state.shape[0]
8891
shape_len = len(last_hidden_state.shape)
8992
first_sentence = list(range(0, batch_size, 2))
@@ -126,6 +129,114 @@ def contrastive_loss(outputs, labels, loss_scale=None, num_items_in_batch=None)
126129
return losses.mean()
127130

128131

132+
def calculate_paired_metrics(embeddings, labels):
133+
from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, \
134+
paired_manhattan_distances
135+
from scipy.stats import pearsonr, spearmanr
136+
137+
embeddings1, embeddings2 = _parse_pair_sentence(embeddings)
138+
cosine_scores = 1 - (paired_cosine_distances(embeddings1, embeddings2))
139+
manhattan_distances = -paired_manhattan_distances(embeddings1, embeddings2)
140+
euclidean_distances = -paired_euclidean_distances(embeddings1, embeddings2)
141+
dot_products = [np.dot(emb1, emb2) for emb1, emb2 in zip(embeddings1, embeddings2)]
142+
143+
eval_pearson_cosine, _ = pearsonr(labels, cosine_scores)
144+
eval_spearman_cosine, _ = spearmanr(labels, cosine_scores)
145+
146+
eval_pearson_manhattan, _ = pearsonr(labels, manhattan_distances)
147+
eval_spearman_manhattan, _ = spearmanr(labels, manhattan_distances)
148+
149+
eval_pearson_euclidean, _ = pearsonr(labels, euclidean_distances)
150+
eval_spearman_euclidean, _ = spearmanr(labels, euclidean_distances)
151+
152+
eval_pearson_dot, _ = pearsonr(labels, dot_products)
153+
eval_spearman_dot, _ = spearmanr(labels, dot_products)
154+
155+
return {
156+
'pearson_cosine': eval_pearson_cosine,
157+
'pearson_euclidean': eval_pearson_manhattan,
158+
'pearson_manhattan': eval_pearson_euclidean,
159+
'pearson_dot_product': eval_pearson_dot,
160+
'spearman_cosine': eval_spearman_cosine,
161+
'spearman_euclidean': eval_spearman_manhattan,
162+
'spearman_manhattan': eval_spearman_euclidean,
163+
'spearman_dot_product': eval_spearman_dot,
164+
}
165+
166+
167+
def calculate_infonce_metrics(embeddings, labels):
168+
from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, \
169+
paired_manhattan_distances
170+
from scipy.stats import pearsonr, spearmanr
171+
hard_negatives = os.environ.get('INFONCE_HARD_NEGATIVES', None)
172+
use_batch = strtobool(os.environ.get('INFONCE_USE_BATCH', 'True'))
173+
split_tensors = _parse_multi_negative_sentences(torch.tensor(embeddings), torch.tensor(labels), hard_negatives)
174+
split_tensors = [t.numpy() for t in split_tensors]
175+
can_batched = hard_negatives is not None
176+
if hard_negatives is None and len(set([s.shape[0] for s in split_tensors])) == 1:
177+
can_batched = True
178+
all_similarity_matrix = []
179+
all_labels = []
180+
pos_neg_margins = []
181+
if not use_batch:
182+
if can_batched:
183+
sentences = np.stack(split_tensors, axis=0)
184+
similarity_matrix = np.matmul(sentences[:, 0:1], sentences[:, 1:].transpose((0, 2, 1))).squeeze(1)
185+
all_similarity_matrix.append(similarity_matrix)
186+
labels = np.zeros_like(similarity_matrix)
187+
labels[:, 0] = 1
188+
all_labels.append(labels)
189+
else:
190+
for tensor in split_tensors:
191+
similarity_matrix = np.matmul(tensor[0], tensor[1:].T)
192+
all_similarity_matrix.append(similarity_matrix)
193+
labels = np.zeros_like(similarity_matrix)
194+
labels[0] = 1
195+
all_labels.append(labels)
196+
max_neg_scores = np.max(similarity_matrix[labels == 0], axis=-1)
197+
pos_neg_margins.append(np.mean(similarity_matrix[labels == 1] - max_neg_scores).item())
198+
else:
199+
if can_batched:
200+
sentences = np.stack(split_tensors, axis=0)
201+
similarity_matrix = np.matmul(sentences[:, 0], sentences[:, 1:].reshape(-1, sentences.shape[2]).T)
202+
all_similarity_matrix.append(similarity_matrix)
203+
labels = np.zeros_like(similarity_matrix)
204+
for row, col in enumerate(range(0, sentences.shape[0] * (sentences.shape[1] - 1), sentences.shape[1] - 1)):
205+
labels[row, col] = 1
206+
all_labels.append(labels)
207+
else:
208+
all_tensors = []
209+
for tensor in split_tensors:
210+
all_tensors.append(tensor[1:])
211+
sentences = np.concatenate(all_tensors, axis=0)
212+
length = 0
213+
for idx, tensor in enumerate(split_tensors):
214+
similarity_matrix = np.matmul(tensor[0], sentences.T)
215+
all_similarity_matrix.append(similarity_matrix)
216+
labels = np.zeros_like(similarity_matrix)
217+
labels[length] = 1
218+
all_labels.append(labels)
219+
length += tensor.shape[0] - 1
220+
max_neg_scores = np.max(similarity_matrix[labels == 0], axis=-1)
221+
pos_neg_margins.append(np.mean(similarity_matrix[labels == 1] - max_neg_scores).item())
222+
223+
similarity_matrix = np.concatenate(all_similarity_matrix, axis=0)
224+
labels = np.concatenate(all_labels, axis=0)
225+
if can_batched:
226+
pos_scores = similarity_matrix[labels == 1].reshape(similarity_matrix.shape[0], -1)
227+
neg_scores = similarity_matrix[labels == 0].reshape(similarity_matrix.shape[0], -1)
228+
max_neg_scores = np.max(neg_scores, axis=-1)
229+
pos_neg_margin = np.mean(pos_scores - max_neg_scores).item()
230+
else:
231+
pos_scores = similarity_matrix[labels == 1]
232+
neg_scores = similarity_matrix[labels == 0]
233+
pos_neg_margin = np.mean(pos_neg_margins)
234+
235+
mean_neg = np.mean(neg_scores)
236+
mean_pos = np.mean(pos_scores)
237+
return {'margin': pos_neg_margin, 'mean_neg': mean_neg, 'mean_pos': mean_pos}
238+
239+
129240
def _parse_multi_negative_sentences(sentences, labels, hard_negatives=None):
130241
split_indices = torch.nonzero(labels, as_tuple=False).squeeze().tolist()
131242
if isinstance(split_indices, int):

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ def reorder_outputs(outputs, distributed_idx):
630630

631631
return [index_to_output[idx] for idx in sorted(index_to_output.keys())]
632632

633-
def _infer_multi_turn(self, inputs_slice, request_config) -> List[List[Dict[str, Any]]]:
633+
def _infer_multi_turn(self, inputs_slice, request_config) -> List[List[List[Dict[str, Any]]]]:
634634
from swift.llm.infer.protocol import ChatCompletionResponse
635635
rank, _, _, _ = get_dist_setting()
636636
request_config = copy(request_config)

swift/trainers/trainers.py

Lines changed: 5 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -71,48 +71,11 @@ def __init__(self, *args, **kwargs):
7171
self.label_names = ['labels']
7272

7373
def calculate_metric(self, eval_prediction: EvalPrediction) -> Dict[str, float]:
74-
from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, \
75-
paired_manhattan_distances
76-
from scipy.stats import pearsonr, spearmanr
77-
78-
embeddings = eval_prediction.predictions
79-
labels = eval_prediction.label_ids
80-
batch_size = 2 * self.args.per_device_eval_batch_size
81-
half_batch_size = self.args.per_device_eval_batch_size
82-
embeddings1 = []
83-
embeddings2 = []
84-
for i in range(embeddings.shape[0] // batch_size):
85-
embeddings1.append(embeddings[i * batch_size:i * batch_size + half_batch_size])
86-
embeddings2.append(embeddings[i * batch_size + half_batch_size:(i + 1) * batch_size])
87-
88-
embeddings1 = np.concatenate(embeddings1)
89-
embeddings2 = np.concatenate(embeddings2)
90-
if len(embeddings1.shape) == 3:
91-
embeddings1 = embeddings1[:, 0]
92-
embeddings2 = embeddings2[:, 0]
93-
cosine_scores = 1 - (paired_cosine_distances(embeddings1, embeddings2))
94-
manhattan_distances = -paired_manhattan_distances(embeddings1, embeddings2)
95-
euclidean_distances = -paired_euclidean_distances(embeddings1, embeddings2)
96-
dot_products = [np.dot(emb1, emb2) for emb1, emb2 in zip(embeddings1, embeddings2)]
97-
98-
eval_pearson_cosine, _ = pearsonr(labels, cosine_scores)
99-
eval_spearman_cosine, _ = spearmanr(labels, cosine_scores)
100-
101-
eval_pearson_manhattan, _ = pearsonr(labels, manhattan_distances)
102-
eval_spearman_manhattan, _ = spearmanr(labels, manhattan_distances)
103-
104-
eval_pearson_euclidean, _ = pearsonr(labels, euclidean_distances)
105-
eval_spearman_euclidean, _ = spearmanr(labels, euclidean_distances)
106-
107-
eval_pearson_dot, _ = pearsonr(labels, dot_products)
108-
eval_spearman_dot, _ = spearmanr(labels, dot_products)
109-
110-
return {
111-
'cosine': eval_spearman_cosine,
112-
'euclidean': eval_pearson_euclidean,
113-
'manhattan': eval_pearson_manhattan,
114-
'dot_product': eval_spearman_dot,
115-
}
74+
from swift.plugin.loss import infonce_loss, calculate_paired_metrics, calculate_infonce_metrics
75+
if self.compute_loss_func is infonce_loss:
76+
return calculate_infonce_metrics(eval_prediction.predictions, eval_prediction.label_ids)
77+
else:
78+
return calculate_paired_metrics(eval_prediction.predictions, eval_prediction.label_ids)
11679

11780

11881
class Seq2SeqTrainer(SwiftMixin, HfSeq2SeqTrainer):

0 commit comments

Comments
 (0)