Commit aceef3a

fix: example of scifact llm embedding
1 parent 7747579

17 files changed, +177 -46 lines changed

README.md (+1 -1)

@@ -43,7 +43,7 @@
 | Exp | Model | Original | Finetuned | Demo |
 |-------------------------------|-------------------------|----------|-----------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------|
 | **embed** pairwise finetune | bge-base-zh-v1.5 | 0.657 | **0.703** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/17KXe2lnNRID-HiVvMtzQnONiO74oGs91?usp=sharing) |
-| **embed** LLM finetune (LoRA) | Qwen2-1.5B-Instruct | 0.546 | **0.694** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jj1kBQWFcuQ3a7P9ttnl1hgX7H8WA_Za?usp=sharing) |
+| **embed** LLM finetune (LoRA) | Qwen2-1.5B-Instruct | 0.546 | **0.695** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jj1kBQWFcuQ3a7P9ttnl1hgX7H8WA_Za?usp=sharing) |
 | **rerank** cross encoder | bge-reranker-base | 0.666 | **0.706** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QvbUkZtG56SXomGYidwI4RQzwODQrWNm?usp=sharing) |
 | **rerank** colbert | chinese-roberta-wwm-ext | 0.643 | **0.687** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QVtqhQ080ZMltXoJyODMmvEQYI6oo5kO?usp=sharing) |
 | **rerank** LLM (LoRA) | Qwen2-1.5B-Instruct | 0.531 | **0.699** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1fzq1iV7-f8hNKFnjMmpVhVxadqPb9IXk?usp=sharing) |

README_ja-JP.md (+1 -1)

@@ -42,7 +42,7 @@
 | Exp | Model | Size | Original | Finetuned | Demo |
 |-------------------------------|-------------------------|------|----------|-----------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------|
 | **embed** pairwise finetune | bge-base-zh-v1.5 | - | 0.657 | **0.703** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/17KXe2lnNRID-HiVvMtzQnONiO74oGs91?usp=sharing) |
-| **embed** LLM finetune (LoRA) | Qwen2-1.5B-Instruct | - | 0.546 | **0.694** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jj1kBQWFcuQ3a7P9ttnl1hgX7H8WA_Za?usp=sharing) |
+| **embed** LLM finetune (LoRA) | Qwen2-1.5B-Instruct | - | 0.546 | **0.695** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jj1kBQWFcuQ3a7P9ttnl1hgX7H8WA_Za?usp=sharing) |
 | **rerank** cross encoder | bge-reranker-base | - | 0.666 | **0.706** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QvbUkZtG56SXomGYidwI4RQzwODQrWNm?usp=sharing) |
 | **rerank** colbert | chinese-roberta-wwm-ext | - | 0.643 | **0.687** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QVtqhQ080ZMltXoJyODMmvEQYI6oo5kO?usp=sharing) |
 | **rerank** LLM (LoRA) | Qwen2-1.5B-Instruct | - | 0.531 | **0.699** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1fzq1iV7-f8hNKFnjMmpVhVxadqPb9IXk?usp=sharing) |

README_zh-CN.md (+3 -3)

@@ -39,10 +39,10 @@
 - 支持全套重排微调,cross encoder、ColBERT、LLM
 - 支持定制化RAG,支持在Transformers、Langchain、LlamaIndex中便捷使用微调后的模型
 
-| 实验 | 模型 | 尺寸| 原分数 | 微调分数 | Demo代码 |
-|-----------------------|-------------------------|----|-------|-----------|-------------------------------------------------------------------------------------------------------------------------------------|
+| 实验 | 模型 | 尺寸| 原分数 | 微调分数 | Demo代码 |
+|----------------------|-------------------------|----|-------|-----------|-------------------------------------------------------------------------------------------------------------------------------------|
 | pairwise微调**向量** | bge-base-zh-v1.5 | - | 0.657 | **0.703** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/17KXe2lnNRID-HiVvMtzQnONiO74oGs91?usp=sharing) |
-| 大模型LoRA微调**向量** | Qwen2-1.5B-Instruct | - | 0.546 | **0.694** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jj1kBQWFcuQ3a7P9ttnl1hgX7H8WA_Za?usp=sharing) |
+| 大模型LoRA微调**向量** | Qwen2-1.5B-Instruct | - | 0.546 | **0.695** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1jj1kBQWFcuQ3a7P9ttnl1hgX7H8WA_Za?usp=sharing) |
 | cross encoder**重排** | bge-reranker-base | - | 0.666 | **0.706** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QvbUkZtG56SXomGYidwI4RQzwODQrWNm?usp=sharing) |
 | colbert**重排** | chinese-roberta-wwm-ext | - | 0.643 | **0.687** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QVtqhQ080ZMltXoJyODMmvEQYI6oo5kO?usp=sharing) |
 | LLM**重排** | Qwen2-1.5B-Instruct | - | 0.531 | **0.699** | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1fzq1iV7-f8hNKFnjMmpVhVxadqPb9IXk?usp=sharing) |

examples/README.md (+4 -2)

@@ -1,14 +1,16 @@
 # Open-Retrievals examples
 
-## Basic Usage
-
 - [embedding-pairwise finetune](./embedding_pairwise_finetune.py)
 - [embedding-llm pairwise finetune](./embedding_llm_finetune.py)
 - [rerank-cross encoder](./rerank_cross_encoder.py)
 - [rerank-colbert](./rerank_colbert.py)
 - [rerank-llm finetune](rerank_llm_finetune.py)
 - [RAG with Langchain](./rag_langchain_demo.py)
 
+Check the whole pipeline
+- [t2-ranking dataset](./t2_ranking/README.md)
+- [scifact dataset](./scifact/README.md)
+
 
 ## Embedding
 
examples/scifact/README.md (+26 -9)

@@ -1,19 +1,39 @@
 # scifact
+
+| Model                  | mrr@10 | recall@10 | ndcg@10 |
+|------------------------|--------|-----------|---------|
+| bge-base-en-v1.5       | 0.703  | 0.862     | 0.744   |
+| + **fine-tuning**      | 0.757  | 0.900     | 0.793   |
+| e5-mistral-7b-instruct | 0.589  | 0.748     | 0.630   |
+| + **fine-tuning**      | 0.763  | 0.940     | 0.806   |
+
+
+## Fine-tuning embedding
 - [scifact data](https://huggingface.co/datasets/Tevatron/scifact)
 - [scifact corpus](https://huggingface.co/datasets/Tevatron/scifact-corpus)
 
-## Fine-tuning Embedding
 ```shell
 sh embed_pairwise_train.sh
 ```
 
+Optional: llm embedding
+```shell
+sh embed_llm_train.sh
+```
+
 ## Encoding corpus
 - save the pair of `(embedding vector, id)` for each corpus example, support for multiple files
+- for llm embed encoding, remember to use the same instruction
 
 ```shell
 sh encode_corpus.sh
 ```
 
+Optional: llm encoding
+```shell
+sh encode_llm_corpus.sh
+```
+
 ## Encoding query
 - save the pair of `(embedding vector, id)` for each query example
 - use `Tevatron/scifact/dev` or `Tevatron/scifact/test` so we can choose to encode the dev or test file
@@ -22,6 +42,11 @@ sh encode_corpus.sh
 sh encode_query.sh
 ```
 
+Optional: llm encoding
+```shell
+sh encode_llm_query.sh
+```
+
 ## Retrieval
 ```shell
 sh retrieve.sh
@@ -37,11 +62,3 @@ sh rerank.sh
 ```shell
 python evaluate.py
 ```
-
-```
-{
-  "mrr@10": 0.7567949735449735,
-  "recall@10": 0.9002222222222223,
-  "ndcg@10": 0.7927846698591741
-}
-```
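
The retrieve step above consumes the `(embedding vector, id)` pairs written by the encode scripts. As a rough illustration of what `retrieve.sh` amounts to, here is a minimal brute-force search sketch; the `(embeddings, ids)` unpickling layout and the file paths are assumptions, not the documented output format of `retrievals.pipelines.embed`, so adapt both before relying on it.

```python
# Hypothetical sketch: brute-force inner-product retrieval over the saved embeddings.
# The (embeddings, ids) layout and paths below are assumptions for illustration only.
import pickle

import numpy as np

with open("corpus-embeddings/corpus.pkl", "rb") as f:
    corpus_emb, corpus_ids = pickle.load(f)  # assumed layout: (embeddings, ids)
with open("query-embeddings/query.pkl", "rb") as f:
    query_emb, query_ids = pickle.load(f)    # assumed layout: (embeddings, ids)

corpus_emb = np.asarray(corpus_emb, dtype=np.float32)
query_emb = np.asarray(query_emb, dtype=np.float32)

scores = query_emb @ corpus_emb.T            # inner-product similarity
topk = np.argsort(-scores, axis=1)[:, :10]   # top-10 corpus ids per query

for qi in range(min(3, len(query_ids))):
    print(query_ids[qi], [(corpus_ids[ci], float(scores[qi, ci])) for ci in topk[qi]])
```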

examples/scifact/embed_llm_train.sh (+31)

@@ -0,0 +1,31 @@
+MODEL_NAME="intfloat/e5-mistral-7b-instruct"
+TRAIN_DATA="Tevatron/scifact"
+OUTPUT_DIR="./scifact/ft_llm_out"
+
+torchrun --nproc_per_node 1 \
+  -m retrievals.pipelines.embed \
+  --output_dir $OUTPUT_DIR \
+  --overwrite_output_dir \
+  --model_name_or_path $MODEL_NAME \
+  --pooling_method last \
+  --do_train \
+  --data_name_or_path $TRAIN_DATA \
+  --positive_key positive_passages \
+  --negative_key negative_passages \
+  --use_lora True \
+  --query_instruction "Retrieve the possible answer for query.\nQuery: " \
+  --document_instruction 'Document: ' \
+  --learning_rate 3e-5 \
+  --bf16 \
+  --num_train_epochs 4 \
+  --per_device_train_batch_size 2 \
+  --gradient_accumulation_steps 16 \
+  --dataloader_drop_last True \
+  --query_max_length 64 \
+  --document_max_length 256 \
+  --train_group_size 2 \
+  --logging_strategy steps \
+  --logging_steps 100 \
+  --temperature 0.02 \
+  --use_inbatch_negative false \
+  --save_total_limit 1
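
The script above fine-tunes a decoder-only LLM as an embedder: `--pooling_method last` takes the hidden state of the last non-padding token, the query/document instructions are prepended to every text, and training uses a contrastive loss scaled by `--temperature 0.02`. The snippet below is only a minimal sketch of that recipe in plain `transformers`, not the `retrievals.pipelines.embed` implementation; the pad-token handling and the exact loss details are assumptions.

```python
# Minimal sketch of LLM embedding with last-token pooling and an InfoNCE-style loss.
# Illustrative only; this is not how retrievals.pipelines.embed is implemented.
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

model_name = "intfloat/e5-mistral-7b-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:        # assumption: reuse EOS as the padding token
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"       # last-token pooling below assumes right padding
model = AutoModel.from_pretrained(model_name, torch_dtype=torch.bfloat16)

def last_token_pool(hidden, attention_mask):
    # pick the hidden state of the last non-padding token in each sequence
    last = attention_mask.sum(dim=1) - 1
    return hidden[torch.arange(hidden.size(0)), last]

def encode(texts, max_length):
    batch = tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
    hidden = model(**batch).last_hidden_state
    return F.normalize(last_token_pool(hidden, batch["attention_mask"]), dim=-1)

queries = ["Retrieve the possible answer for query.\nQuery: what causes rain?"]
documents = [
    "Document: Rain forms when atmospheric water vapour condenses into droplets ...",  # positive
    "Document: The committee approved the annual budget on Tuesday ...",               # hard negative
]

q_emb, d_emb = encode(queries, max_length=64), encode(documents, max_length=256)
logits = q_emb @ d_emb.T / 0.02                                              # temperature as above
loss = F.cross_entropy(logits, torch.zeros(len(queries), dtype=torch.long))  # positive is index 0
```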

examples/scifact/embed_pairwise_train.sh (-1)

@@ -2,7 +2,6 @@ MODEL_NAME="BAAI/bge-base-en-v1.5"
 TRAIN_DATA="Tevatron/scifact"
 OUTPUT_DIR="./scifact/ft_out"
 
-
 torchrun --nproc_per_node 1 \
   -m retrievals.pipelines.embed \
   --output_dir $OUTPUT_DIR \

examples/scifact/encode_corpus.sh (+1 -2)

@@ -3,11 +3,10 @@ MODEL_DIR="./scifact/ft_out"
 CORPUS=Tevatron/scifact-corpus
 mkdir $ENCODE_CORPUS_DIR
 
-
 python -m retrievals.pipelines.embed \
   --model_name_or_path $MODEL_DIR \
   --output_dir $ENCODE_CORPUS_DIR \
-  --encode_save_file corpus.pkl \
+  --encoding_save_file corpus.pkl \
   --do_encode \
   --fp16 \
   --per_device_eval_batch_size 256 \

examples/scifact/encode_llm_corpus.sh (+20)

@@ -0,0 +1,20 @@
+ENCODE_CORPUS_DIR=./scifact/corpus-embeddings
+MODEL_NAME="intfloat/e5-mistral-7b-instruct"
+LORA_DIR=./ft_llm_out
+CORPUS=Tevatron/scifact-corpus
+mkdir -p $ENCODE_CORPUS_DIR
+
+python -m retrievals.pipelines.embed \
+  --model_name_or_path $MODEL_NAME \
+  --lora_path $LORA_DIR \
+  --pooling_method last \
+  --output_dir $ENCODE_CORPUS_DIR \
+  --encoding_save_file corpus.pkl \
+  --do_encode \
+  --bf16 \
+  --per_device_eval_batch_size 128 \
+  --data_name_or_path $CORPUS \
+  --query_key text \
+  --document_instruction "Document: " \
+  --document_max_length 256 \
+  --is_query false

examples/scifact/encode_llm_query.sh (+20)

@@ -0,0 +1,20 @@
+ENCODE_QUERY_DIR=./query-embeddings
+MODEL_NAME="intfloat/e5-mistral-7b-instruct"
+LORA_DIR=./ft_llm_out
+QUERY=Tevatron/scifact/dev
+mkdir -p $ENCODE_QUERY_DIR
+
+python -m retrievals.pipelines.embed \
+  --model_name_or_path $MODEL_NAME \
+  --lora_path $LORA_DIR \
+  --pooling_method last \
+  --output_dir $ENCODE_QUERY_DIR \
+  --encoding_save_file query.pkl \
+  --do_encode \
+  --bf16 \
+  --per_device_eval_batch_size 256 \
+  --data_name_or_path $QUERY \
+  --query_key query \
+  --query_instruction "Retrieve the possible answer for query.\nQuery: " \
+  --query_max_length 64 \
+  --is_query true

examples/scifact/encode_query.sh (+1 -1)

@@ -6,7 +6,7 @@ mkdir $ENCODE_QUERY_DIR
 python -m retrievals.pipelines.embed \
   --model_name_or_path $MODEL_DIR \
   --output_dir $ENCODE_QUERY_DIR \
-  --encode_save_file query.pkl \
+  --encoding_save_file query.pkl \
   --do_encode \
   --fp16 \
   --per_device_eval_batch_size 256 \

examples/t2_ranking/README.md (+8 -11)

@@ -1,19 +1,16 @@
 # T2_ranking
 
-An end-to-end example with [t2-reranking data](https://huggingface.co/datasets/C-MTEB/T2Reranking)
-
-## Experiment
-
-bge-base-zh-v1.5
-- "map": 0.6569549236524207, "mrr": 0.7683207806932297
-- embed/pairwise/infonce: "map": 0.7012381232799435, "mrr": 0.81575288845697
-
-bge-reranker-base
-- "map": 0.6660360850586858, "mrr": 0.76091472303207
-- rerank/cross-encoder: "map": 0.6906494118852755, "mrr": 0.8064902548320916
+| Model              | map   | mrr   |
+|--------------------|-------|-------|
+| bge-base-zh-v1.5   | 0.657 | 0.768 |
+| + **fine-tuning**  | 0.701 | 0.816 |
+| bge-reranker-base  | 0.666 | 0.761 |
+| + **fine-tuning**  | 0.691 | 0.806 |
 
 
 ## 1. Prepare dataset
+
+- [t2-reranking data](https://huggingface.co/datasets/C-MTEB/T2Reranking)
 ```shell
 python prepare_t2ranking_data.py
 ```

src/retrievals/data/dataset.py (+5 -1)

@@ -293,6 +293,7 @@ def __init__(
         dataset_language = args.dataset_language
         dataset_split = args.dataset_split
         text_key = args.query_key
+        instruction = args.query_instruction or args.document_instruction or instruction
 
         if isinstance(data_name_or_path, datasets.Dataset):
             self.encode_data = data_name_or_path
@@ -310,14 +311,16 @@ def __init__(
         self.id_key = id_key
         self.text_key = text_key
         self.instruction = instruction
-        self.args = args
+        if len(instruction) > 0:
+            logger.info(f'Add prefix instruction: {self.instruction}')
 
     def __len__(self):
         return len(self.encode_data)
 
     def __getitem__(self, item) -> [str, BatchEncoding]:
         if self.id_key is not None:
             text_id, text = (self.encode_data[item][f] for f in [self.id_key, self.text_key])
+            text = self.instruction + text
             encoded_text = self.tokenizer.encode_plus(
                 text,
                 max_length=self.max_length,
@@ -328,6 +331,7 @@ def __getitem__(self, item) -> [str, BatchEncoding]:
             return text_id, encoded_text
         else:
             text = self.encode_data[item][self.text_key]
+            text = self.instruction + text
             encoded_text = self.tokenizer.encode_plus(
                 text,
                 max_length=self.max_length,
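
The dataset change above makes the encode dataset prepend the configured query or document instruction to every text before tokenization, which is why the encode scripts must pass the same `--query_instruction`/`--document_instruction` used at training time. A tiny standalone illustration of the behaviour (not the library's dataset class itself; the tokenizer choice here is arbitrary):

```python
# Standalone illustration of the instruction prefixing added above; any tokenizer works here.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
instruction = "Retrieve the possible answer for query.\nQuery: "  # same prefix as training
text = "what causes rain?"

encoded = tokenizer.encode_plus(
    instruction + text,   # mirrors `text = self.instruction + text`
    max_length=64,
    truncation=True,
)
print(tokenizer.decode(encoded["input_ids"]))
```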

src/retrievals/models/embedding_auto.py (+32 -3)

@@ -510,7 +510,7 @@ def from_pretrained(
             model.print_trainable_parameters()
 
         if lora_path is not None:
-            logger.info('Load pretrained with LoRA adapter')
+            logger.info(f'Load pretrained with LoRA adapter {lora_path}')
             from peft import LoraConfig, PeftModel
 
             model = PeftModel.from_pretrained(model, lora_path)
@@ -689,9 +689,38 @@ def forward(
 
     def _unsorted_segment_mean(self, data: torch.Tensor, segment_ids: torch.Tensor, num_segments: int) -> torch.Tensor:
         result_shape = (num_segments, data.size(1))
-        segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))
+        segment_ids = segment_ids.unsqueeze(-1).expand(-1, data.size(1))  # (batch, num_embedding)
         result = data.new_full(result_shape, 0)  # init empty result tensor
         count = data.new_full(result_shape, 0)
-        result.scatter_add_(0, segment_ids, data)
+        result.scatter_add_(0, segment_ids, data)  # accumulate rows of data into their segment slot
         count.scatter_add_(0, segment_ids, torch.ones_like(data))
         return result / count.clamp(min=1)
+
+    def _sorted_segment_mean(self, data: torch.Tensor, segment_ids: torch.Tensor, num_segments: int) -> torch.Tensor:
+        """
+        Compute the mean of each segment in data based on sorted segment_ids.
+
+        Args:
+            data (torch.Tensor): Input data tensor of shape (batch_size, num_embedding).
+            segment_ids (torch.Tensor): Sorted segment IDs tensor of shape (batch_size,).
+            num_segments (int): Number of unique segments.
+
+        Returns:
+            torch.Tensor: Tensor of shape (num_segments, num_embedding) containing the mean of each segment.
+        """
+        result = torch.zeros((num_segments, data.size(1)), dtype=data.dtype, device=data.device)
+        count = torch.zeros((num_segments,), dtype=torch.int32, device=data.device)
+
+        start_idx = 0
+        for i in range(num_segments):
+            # Find the contiguous block of rows belonging to the current segment (segment_ids is sorted)
+            segment_start = start_idx
+            while start_idx < segment_ids.size(0) and segment_ids[start_idx] == i:
+                start_idx += 1
+
+            if start_idx > segment_start:
+                result[i] = data[segment_start:start_idx].sum(dim=0)
+                count[i] = start_idx - segment_start
+
+        result /= count.clamp(min=1).unsqueeze(-1)
+        return result
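
Both helpers above compute a per-segment mean of row vectors, for example averaging sub-token embeddings that share a segment id; the sorted variant simply exploits the fact that a sorted id tensor makes each segment a contiguous slice of rows. As a quick sanity check of the scatter_add-based version, here is a standalone sketch (outside the class) on a tiny example:

```python
# Standalone check of scatter_add-based segment mean: average rows of `data`
# that share the same segment id.
import torch

def unsorted_segment_mean(data: torch.Tensor, segment_ids: torch.Tensor, num_segments: int) -> torch.Tensor:
    idx = segment_ids.unsqueeze(-1).expand(-1, data.size(1))   # broadcast ids across the feature dim
    total = data.new_zeros((num_segments, data.size(1)))
    count = data.new_zeros((num_segments, data.size(1)))
    total.scatter_add_(0, idx, data)                           # per-segment sums
    count.scatter_add_(0, idx, torch.ones_like(data))          # per-segment counts
    return total / count.clamp(min=1)

data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
segment_ids = torch.tensor([0, 0, 1])   # rows 0 and 1 belong to segment 0, row 2 to segment 1
print(unsorted_segment_mean(data, segment_ids, num_segments=2))
# tensor([[2., 3.],
#         [5., 6.]])
```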

src/retrievals/models/rerank.py (+3 -3)

@@ -553,7 +553,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], safe_serializ
         self.model.save_pretrained(
             save_directory, state_dict=state_dict_fn(state_dict), safe_serialization=safe_serialization
         )
-        torch.save(state_dict_fn(self.linear.state_dict()), os.path.join(save_directory, 'colbert_linear.pt'))
+        torch.save(state_dict_fn(self.linear.state_dict()), os.path.join(save_directory, 'linear.pt'))
         self.tokenizer.save_pretrained(save_directory)
 
     @classmethod
@@ -570,9 +570,9 @@ def from_pretrained(
         model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **kwargs)
 
         linear_layer = nn.Linear(model.config.hidden_size, colbert_dim, dtype=torch.float32, bias=False)
-        if os.path.exists(path=os.path.join(model_name_or_path, 'colbert_linear.pt')):
+        if os.path.exists(path=os.path.join(model_name_or_path, 'linear.pt')):
             logger.info(f'Loading colbert_linear weight from {model_name_or_path}')
-            colbert_state_dict = torch.load(os.path.join(model_name_or_path, 'colbert_linear.pt'), map_location='cpu')
+            colbert_state_dict = torch.load(os.path.join(model_name_or_path, 'linear.pt'), map_location='cpu')
             linear_layer.load_state_dict(colbert_state_dict)
         else:
             logger.info('Xavier uniform random colbert linear layer')
