Commit 9094bca

fix: t2_ranking example
1 parent aceef3a commit 9094bca

23 files changed: +291 −56 lines

.gitignore (+2 −7)

@@ -82,7 +82,7 @@ coverage.xml
 .hypothesis/
 .pytest_cache/

-# yuetan
+# specific
 **/nohup.out
 /reference/*
 /examples/data/*
@@ -92,12 +92,7 @@ coverage.xml
 /data/*.zip
 /data/raw/*
 /data/web/*
-/weights/scaler.pkl
-/weights/saved_model.pb
-/weights/variables/*
-/weights/checkpoint
-/weights/checkpoint.data-00000-of-00001
-/weights/checkpoint.index
+/weights/*
 /conda/*
 **/.pdf
 /encode.py

README_zh-CN.md (+1 −1)

@@ -37,7 +37,7 @@
 **Open-Retrievals** unifies embedding, retrieval, and reranking, helping developers conveniently optimize information retrieval, LLM RAG, and related applications
 - Supports full embedding fine-tuning: contrastive learning, LLM-based, point-wise, pairwise, listwise
 - Supports full reranking fine-tuning: cross encoder, ColBERT, LLM
-- Supports customized RAG, with easy use of the fine-tuned models in Transformers, Langchain, LlamaIndex
+- Supports customized, modular RAG, with easy use of the fine-tuned models in Transformers, Langchain, LlamaIndex

 | Experiment | Model | Size | Original score | Fine-tuned score | Demo code |
 |------------|-------|------|----------------|------------------|-----------|

codecov.yml (+1 −1)

@@ -5,7 +5,7 @@ coverage:
   status:
     project:
       default:
-        threshold: 2%
+        threshold: 3%

     patch:
       default:

docs/source/embed.rst (+9 −2)

@@ -124,5 +124,12 @@ offline hard mining
 online hard mining


-Ensemble embedding
-~~~~~~~~~~~~~~~~~~~~~~
+Matryoshka Representation Learning
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+
+Contrastive loss
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+
+cosent loss
+- similar to circle loss, but with cosine
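
The `cosent loss` stub added above gets only a one-line description. For reference, a minimal standalone PyTorch sketch of the CoSENT objective (this is not open-retrievals' own implementation; `scale` plays the usual temperature-like role):

```python
import torch

def cosent_loss(cos_scores: torch.Tensor, labels: torch.Tensor, scale: float = 20.0) -> torch.Tensor:
    """CoSENT: log(1 + sum of exp(scale * (s_i - s_j))) over all pairs (i, j)
    with label_i < label_j, i.e. a less-similar pair out-scored a more-similar one."""
    scores = cos_scores * scale
    diff = scores[:, None] - scores[None, :]           # diff[i, j] = s_i - s_j
    keep = (labels[:, None] < labels[None, :]).float()
    diff = diff - (1.0 - keep) * 1e12                  # mask out non-qualifying (i, j)
    diff = torch.cat([torch.zeros(1, device=diff.device), diff.flatten()])
    return torch.logsumexp(diff, dim=0)                # the prepended 0 supplies the "+1"
```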

docs/source/rag.rst (+3 −0)

@@ -105,6 +105,9 @@ Enhance RAG Performance
 * Meta data of documents


+Graph RAG
+-------------------
+

 pdf parse
 --------------

docs/source/retrieval.rst (+17 −0)

@@ -6,6 +6,23 @@ Offline indexing
 ----------------------------


+Ensemble retrieval
+---------------------
+
+We can use `RRF_fusion` to ensemble multiple retrievers and improve retrieval performance.
+

 Query retrieval
 ----------------------------
+
+
+Faiss retrieval
+-----------------------
+
+
+BM25 retrieval
+-----------------------
+
+
+Elastic search retrieval
+---------------------------
examples/README.md (+1 −1)

@@ -7,7 +7,7 @@
 - [rerank-llm finetune](rerank_llm_finetune.py)
 - [RAG with Langchain](./rag_langchain_demo.py)

-Check the whole pipeline
+Check the whole pipeline examples
 - [t2-ranking dataset](./t2_ranking/README.md)
 - [scifact dataset](./scifact/README.md)
examples/eval/README.md (+3 −32)

@@ -6,38 +6,9 @@ pip install datasets mteb[beir]
 pip install open-retrievals[eval]
 ```

-
-```python
-from typing import List, Union, Dict
-import numpy as np
-from retrievals import AutoModelForEmbedding
-
-
-class AutoModelForEmbeddingEval(AutoModelForEmbedding):
-    def __init__(self, **kwargs):
-        super(AutoModelForEmbeddingEval, self).__init__(**kwargs)
-
-    def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
-        """For MTEB eval
-        This function will be used for retrieval task
-        if there is an instruction for queries, we will add it to the query text
-        """
-        if self.query_instruction is not None:
-            input_texts = ['{}{}'.format(self.query_instruction, q) for q in queries]
-        else:
-            input_texts = queries
-        return self.encode_from_text(input_texts, batch_size=4)
-
-    def encode_corpus(self, corpus: List[Union[Dict[str, str], str]], **kwargs) -> np.ndarray:
-        """For MTEB eval
-        This function will be used for retrieval task
-        encode corpus for retrieval task
-        """
-        if isinstance(corpus[0], dict):
-            input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
-        else:
-            input_texts = corpus
-        return self.encode_from_text(input_texts, batch_size=4)
+**Eval**
+```shell
+python run_eval.py --model_name stella-base-zh --output_dir ./zh_results/stella-base
 ```
examples/eval/run_eval.py (+81)

@@ -0,0 +1,81 @@
+"""Evaluation of embedding model"""
+
+import argparse
+import functools
+import random
+from typing import Dict, List
+
+import numpy as np
+import torch
+
+# from C_MTEB.tasks import *
+from mteb import MTEB, DRESModel
+
+from retrievals import AutoModelForEmbedding
+
+TASKS_WITH_PROMPTS = [
+    "T2Retrieval",
+    "MMarcoRetrieval",
+    "DuRetrieval",
+    "CovidRetrieval",
+    "CmedqaRetrieval",
+    "EcomRetrieval",
+    "MedicalRetrieval",
+    "VideoRetrieval",
+]
+
+parser = argparse.ArgumentParser(description='evaluation for CMTEB')
+parser.add_argument('--model_name', default='bert-base-uncased', type=str, help='which model to use')
+parser.add_argument('--output_dir', default='zh_results/', type=str, help='output directory')
+parser.add_argument('--max_len', default=512, type=int, help='max length')
+
+args = parser.parse_args()
+
+
+class RetrievalModel(DRESModel):
+    def __init__(self, encoder, query_instruction='', document_instruction='', **kwargs):
+        self.encoder = encoder
+        self.query_instruction = query_instruction
+        self.document_instruction = document_instruction
+
+    def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
+        """For MTEB eval
+        This function will be used for retrieval task
+        if there is an instruction for queries, we will add it to the query text
+        """
+        input_texts = [self.query_instruction + q for q in queries]
+        return self._do_encode(input_texts)
+
+    def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs) -> np.ndarray:
+        """For MTEB eval
+        This function will be used for retrieval task
+        encode corpus for retrieval task
+        """
+        if isinstance(corpus[0], dict):
+            input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus]
+        else:
+            input_texts = corpus
+
+        input_texts = [self.document_instruction + t for t in input_texts]
+        return self._do_encode(input_texts)
+
+    @torch.no_grad()
+    def _do_encode(self, input_texts: List[str]) -> np.ndarray:
+        return self.encoder.encode(
+            sentences=input_texts, batch_size=256, normalize_embeddings=True, convert_to_numpy=True
+        )
+
+
+if __name__ == '__main__':
+    encoder = AutoModelForEmbedding.from_pretrained(args.model_name)
+    encoder.encode = functools.partial(encoder.encode, normalize_embeddings=True)
+
+    task_names = [t.description["name"] for t in MTEB(task_langs=['zh', 'zh-CN']).tasks]
+    random.shuffle(task_names)
+
+    for task in task_names:
+        evaluation = MTEB(tasks=[task], task_langs=['zh', 'zh-CN'])
+        if task in TASKS_WITH_PROMPTS:
+            evaluation.run(RetrievalModel(encoder), output_folder=args.output_dir, overwrite_results=False)
+        else:
+            evaluation.run(encoder, output_folder=args.output_dir, overwrite_results=False)
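
For retrieval tasks with prompts, the script wraps the raw encoder in `RetrievalModel` so instructions can be prepended; a hypothetical instantiation with explicit instruction strings (the values below are illustrative, not from the commit):

```python
# RetrievalModel is defined in run_eval.py above; instruction values are illustrative
model = RetrievalModel(
    encoder,
    query_instruction='Query: ',       # prepended to every query
    document_instruction='Passage: ',  # prepended to every corpus text
)
```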

examples/msmacro/README.md (+16)

@@ -0,0 +1,16 @@
+# msmacro
+
+## Download the data
+- [msmacro data](https://microsoft.github.io/msmarco/Datasets.html)
+
+```shell
+sh download_data.sh
+```
+
+## Prepare data
+```shell
+python prepare_data.py
+```
+
+
+## Evaluation
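
The `## Evaluation` section is still empty in this commit. For reference, MS MARCO passage ranking is conventionally reported as MRR@10; a minimal generic scorer, not the repo's script (`results` maps query id to a ranked passage-id list, `qrels` to the set of relevant ids):

```python
from typing import Dict, List, Set

def mrr_at_10(results: Dict[str, List[str]], qrels: Dict[str, Set[str]]) -> float:
    """Mean reciprocal rank of the first relevant passage within the top 10."""
    total = 0.0
    for qid, ranked in results.items():
        for rank, pid in enumerate(ranked[:10], start=1):
            if pid in qrels.get(qid, set()):
                total += 1.0 / rank
                break
    return total / max(len(results), 1)
```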

examples/msmacro/download_data.sh (+9)

@@ -0,0 +1,9 @@
+
+wget https://msmarco.z22.web.core.windows.net/msmarcoranking/collectionandqueries.tar.gz
+wget https://msmarco.z22.web.core.windows.net/msmarcoranking/triples.train.small.tar.gz
+wget https://msmarco.z22.web.core.windows.net/msmarcoranking/top1000.eval.tar.gz
+
+tar -xzvf top1000.eval.tar.gz
+tar -xzvf triples.train.small.tar.gz
+tar -xzvf collectionandqueries.tar.gz
+rm *.gz

examples/t2_ranking/README.md (+7 −10)

@@ -18,28 +18,25 @@ python prepare_t2ranking_data.py
 ## 2. Finetune embedding

 ```shell
-sh pairwise_embed_train.sh
+sh embed_pairwise_train.sh
 ```

-## Indexing
-Encode corpus
 ```shell
-sh encode_corpus.sh
+sh embed_llm_train.sh
 ```

-Encode Query
+
+## Rerank
 ```shell
-sh encode_query.sh
+sh rerank_cross_encoder.sh
 ```

-## Retrieve
 ```shell
-sh retrieve.sh
+sh rerank_colbert.sh
 ```

-## Rerank
 ```shell
-sh rerank.sh
+sh rerank_llm.sh
 ```

 ## Evaluate
examples/t2_ranking/embed_llm_train.sh (+31)

@@ -0,0 +1,31 @@
+MODEL_NAME="Qwen/Qwen2-1.5B-Instruct"
+TRAIN_DATA="/t2_ranking.jsonl"
+OUTPUT_DIR="/t2_output"
+
+torchrun --nproc_per_node 1 \
+  -m retrievals.pipelines.embed \
+  --output_dir $OUTPUT_DIR \
+  --overwrite_output_dir \
+  --model_name_or_path $MODEL_NAME \
+  --pooling_method last \
+  --do_train \
+  --data_name_or_path $TRAIN_DATA \
+  --positive_key positive \
+  --negative_key negative \
+  --use_lora True \
+  --query_instruction "Retrieve the possible answer for query.\nQuery: " \
+  --document_instruction 'Document: ' \
+  --learning_rate 1e-4 \
+  --bf16 \
+  --num_train_epochs 3 \
+  --per_device_train_batch_size 4 \
+  --gradient_accumulation_steps 16 \
+  --dataloader_drop_last True \
+  --query_max_length 64 \
+  --document_max_length 256 \
+  --train_group_size 4 \
+  --logging_strategy steps \
+  --logging_steps 100 \
+  --temperature 0.02 \
+  --use_inbatch_negative false \
+  --save_total_limit 1
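
After training, the fine-tuned LLM embedder can be loaded for encoding. A sketch under the assumption that `AutoModelForEmbedding.from_pretrained` accepts the same pooling and instruction settings as the training flags above:

```python
from retrievals import AutoModelForEmbedding

# assumption: from_pretrained mirrors the training flags (pooling_method, query_instruction)
model = AutoModelForEmbedding.from_pretrained(
    "/t2_output",  # OUTPUT_DIR from the script above
    pooling_method="last",
    query_instruction="Retrieve the possible answer for query.\nQuery: ",
)
embeddings = model.encode(["如何缓解干眼症?"], normalize_embeddings=True)
```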

examples/t2_ranking/rerank.sh (whitespace-only changes)

examples/t2_ranking/rerank_colbert.sh (+26)

@@ -0,0 +1,26 @@
+MODEL_NAME='hfl/chinese-roberta-wwm-ext'
+TRAIN_DATA="t2_ranking.jsonl"
+OUTPUT_DIR="t2_output"
+
+torchrun --nproc_per_node 1 \
+  --module retrievals.pipelines.rerank \
+  --output_dir $OUTPUT_DIR \
+  --overwrite_output_dir \
+  --model_name_or_path $MODEL_NAME \
+  --tokenizer_name $MODEL_NAME \
+  --model_type colbert \
+  --do_train \
+  --data_name_or_path $TRAIN_DATA \
+  --positive_key positive \
+  --negative_key negative \
+  --learning_rate 5e-5 \
+  --bf16 \
+  --num_train_epochs 5 \
+  --per_device_train_batch_size 32 \
+  --dataloader_drop_last True \
+  --max_length 256 \
+  --train_group_size 4 \
+  --unfold_each_positive false \
+  --save_total_limit 1 \
+  --logging_steps 100 \
+  --use_inbatch_negative False
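
ColBERT reranks via late interaction: each query token is matched against its best document token and the maxima are summed. A minimal MaxSim sketch (generic, not the library's internals):

```python
import torch

def colbert_score(q: torch.Tensor, d: torch.Tensor) -> torch.Tensor:
    """MaxSim late interaction. q: (num_query_tokens, dim), d: (num_doc_tokens, dim),
    both rows L2-normalized token embeddings."""
    sim = q @ d.T                       # pairwise token cosine similarities
    return sim.max(dim=1).values.sum()  # best doc token per query token, then sum
```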
examples/t2_ranking/rerank_cross_encoder.sh (+22)

@@ -0,0 +1,22 @@
+MODEL_NAME="BAAI/bge-reranker-base"
+TRAIN_DATA="t2_ranking.jsonl"
+OUTPUT_DIR="t2_rank_output"
+
+torchrun --nproc_per_node 1 \
+  -m retrievals.pipelines.rerank \
+  --output_dir $OUTPUT_DIR \
+  --overwrite_output_dir \
+  --model_name_or_path $MODEL_NAME \
+  --model_type cross-encoder \
+  --do_train \
+  --data_name_or_path $TRAIN_DATA \
+  --positive_key positive \
+  --negative_key negative \
+  --learning_rate 2e-5 \
+  --fp16 \
+  --num_train_epochs 3 \
+  --per_device_train_batch_size 64 \
+  --dataloader_drop_last True \
+  --max_length 512 \
+  --save_total_limit 1 \
+  --logging_steps 100
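
A cross-encoder scores each (query, passage) pair jointly in one forward pass. A generic scoring sketch with plain transformers on the trained checkpoint (the loading path is illustrative, not necessarily the library's own helper):

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t2_rank_output")  # OUTPUT_DIR above
model = AutoModelForSequenceClassification.from_pretrained("t2_rank_output")

pairs = [("什么是土壤污染?", "土壤污染是指污染物进入土壤并不断累积的现象。")]
inputs = tokenizer(
    [q for q, _ in pairs], [p for _, p in pairs],
    padding=True, truncation=True, max_length=512, return_tensors="pt",
)
with torch.no_grad():
    scores = model(**inputs).logits.squeeze(-1)  # higher score = more relevant
```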

examples/t2_ranking/rerank_llm.sh (+29)

@@ -0,0 +1,29 @@
+MODEL_NAME="Qwen/Qwen2-1.5B-Instruct"
+TRAIN_DATA="t2_ranking.jsonl"
+OUTPUT_DIR="t2_output"
+
+torchrun --nproc_per_node 1 \
+  -m retrievals.pipelines.rerank \
+  --output_dir ${OUTPUT_DIR} \
+  --overwrite_output_dir \
+  --model_name_or_path $MODEL_NAME \
+  --model_type llm \
+  --causal_lm True \
+  --use_lora True \
+  --data_name_or_path $TRAIN_DATA \
+  --task_prompt "Given a query A and a passage B, determine whether the passage contains an answer to the query by providing a prediction of either 'Yes' or 'No'." \
+  --query_instruction "A: " \
+  --document_instruction 'B: ' \
+  --positive_key positive \
+  --negative_key negative \
+  --learning_rate 2e-4 \
+  --num_train_epochs 3 \
+  --per_device_train_batch_size 4 \
+  --gradient_accumulation_steps 16 \
+  --dataloader_drop_last True \
+  --max_len 256 \
+  --train_group_size 4 \
+  --logging_steps 10 \
+  --save_steps 20000 \
+  --save_total_limit 1 \
+  --bf16
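
With a yes/no task prompt like the one above, the usual scoring scheme for a causal-LM reranker takes the logit of the "Yes" token at the last position. A generic sketch of that idea (not the library's exact code; it assumes "Yes" is a single token in the tokenizer's vocabulary):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")

prompt = ("Given a query A and a passage B, determine whether the passage contains "
          "an answer to the query by providing a prediction of either 'Yes' or 'No'.\n"
          "A: 什么是深度学习?\nB: 深度学习是机器学习的一个分支。\n")
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits[:, -1, :]        # next-token logits at the last position
yes_id = tokenizer.convert_tokens_to_ids("Yes")      # assumes "Yes" is a single vocab token
score = logits[0, yes_id].item()                     # higher = more relevant
```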

examples/t2_ranking/retrieve.sh (whitespace-only changes)
