
Commit 5bd0dd5

fix: support colbert fine-tune from pretrained model
1 parent a55316a commit 5bd0dd5

File tree

15 files changed: +271 -88 lines changed


Diff for: README.md

+2 -2

@@ -9,7 +9,7 @@
 [lint-image]: https://github.com/LongxingTan/open-retrievals/actions/workflows/lint.yml/badge.svg?branch=master
 [lint-url]: https://github.com/LongxingTan/open-retrievals/actions/workflows/lint.yml?query=branch%3Amaster
 [docs-image]: https://readthedocs.org/projects/open-retrievals/badge/?version=latest
-[docs-url]: https://open-retrievals.readthedocs.io/en/latest/?version=latest
+[docs-url]: https://open-retrievals.readthedocs.io/en/master/
 [coverage-image]: https://codecov.io/gh/longxingtan/open-retrievals/branch/master/graph/badge.svg
 [coverage-url]: https://codecov.io/github/longxingtan/open-retrievals?branch=master
 [contributing-image]: https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat
@@ -29,7 +29,7 @@
 [![Code Coverage][coverage-image]][coverage-url]
 [![Contributing][contributing-image]][contributing-url]

-**[Documentation](https://open-retrievals.readthedocs.io)** | **[中文](https://github.com/LongxingTan/open-retrievals/blob/master/README_zh-CN.md)** | **[日本語](https://github.com/LongxingTan/open-retrievals/blob/master/README_ja-JP.md)**
+**[Documentation](https://open-retrievals.readthedocs.io/en/master/)** | **[中文](https://github.com/LongxingTan/open-retrievals/blob/master/README_zh-CN.md)** | **[日本語](https://github.com/LongxingTan/open-retrievals/blob/master/README_ja-JP.md)**

 </div>

Diff for: README_ja-JP.md

+2 -2

@@ -9,7 +9,7 @@
 [lint-image]: https://github.com/LongxingTan/open-retrievals/actions/workflows/lint.yml/badge.svg?branch=master
 [lint-url]: https://github.com/LongxingTan/open-retrievals/actions/workflows/lint.yml?query=branch%3Amaster
 [docs-image]: https://readthedocs.org/projects/open-retrievals/badge/?version=latest
-[docs-url]: https://open-retrievals.readthedocs.io/en/latest/?version=latest
+[docs-url]: https://open-retrievals.readthedocs.io/en/master/
 [coverage-image]: https://codecov.io/gh/longxingtan/open-retrievals/branch/master/graph/badge.svg
 [coverage-url]: https://codecov.io/github/longxingtan/open-retrievals?branch=master
 [contributing-image]: https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat
@@ -29,7 +29,7 @@
 [![Code Coverage][coverage-image]][coverage-url]
 [![Contributing][contributing-image]][contributing-url]

-**[ドキュメント](https://open-retrievals.readthedocs.io)** | **[英語](https://github.com/LongxingTan/open-retrievals/blob/master/README.md)** | **[中文](https://github.com/LongxingTan/open-retrievals/blob/master/README_zh-CN.md)**
+**[ドキュメント](https://open-retrievals.readthedocs.io/en/master/)** | **[英語](https://github.com/LongxingTan/open-retrievals/blob/master/README.md)** | **[中文](https://github.com/LongxingTan/open-retrievals/blob/master/README_zh-CN.md)**
 </div>

 ![structure](./docs/source/_static/structure.png)

Diff for: README_zh-CN.md

+5 -5

@@ -9,7 +9,7 @@
 [lint-image]: https://github.com/LongxingTan/open-retrievals/actions/workflows/lint.yml/badge.svg?branch=master
 [lint-url]: https://github.com/LongxingTan/open-retrievals/actions/workflows/lint.yml?query=branch%3Amaster
 [docs-image]: https://readthedocs.org/projects/open-retrievals/badge/?version=latest
-[docs-url]: https://open-retrievals.readthedocs.io/en/latest/?version=latest
+[docs-url]: https://open-retrievals.readthedocs.io/en/master/
 [coverage-image]: https://codecov.io/gh/longxingtan/open-retrievals/branch/master/graph/badge.svg
 [coverage-url]: https://codecov.io/github/longxingtan/open-retrievals?branch=master
 [contributing-image]: https://img.shields.io/badge/contributions-welcome-brightgreen.svg?style=flat
@@ -29,7 +29,7 @@
 [![Code Coverage][coverage-image]][coverage-url]
 [![Contributing][contributing-image]][contributing-url]

-**[中文wiki](https://github.com/LongxingTan/open-retrievals/wiki)** | **[英文文档](https://open-retrievals.readthedocs.io)**
+**[中文wiki](https://github.com/LongxingTan/open-retrievals/wiki)** | **[英文文档](https://open-retrievals.readthedocs.io/en/master/)**
 </div>

 ![structure](./docs/source/_static/structure.png)
@@ -80,8 +80,8 @@ from retrievals import AutoModelForEmbedding

 sentences = [
     "在1974年,第一次在东南亚打自由搏击就得了冠军",
-    "1982年打赢了日本重炮手雷龙,接着连续三年打败所有日本空手道高手,赢得全日本自由搏击冠军",
     "中国古拳法唯一传人鬼王达,被喻为空手道的克星,绰号魔鬼筋肉人",
+    "1982年打赢了日本重炮手雷龙,接着连续三年打败所有日本空手道高手,赢得全日本自由搏击冠军",
     "古人有云,有功夫,无懦夫"
 ]

@@ -97,12 +97,12 @@ print(scores.tolist())
 from retrievals import AutoModelForEmbedding, AutoModelForRetrieval

 index_path = './database/faiss/faiss.index'
-sentences = ['A dog is chasing car.', 'A man is playing a guitar.']
+sentences = ['在中国是中国人', '在美国是美国人', '2000人民币大于3000美元']
 model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
 model = AutoModelForEmbedding.from_pretrained(model_name_or_path)
 model.build_index(sentences, index_path=index_path)

-query_embed = model.encode("He plays guitar.")
+query_embed = model.encode("在加拿大是加拿大人")
 matcher = AutoModelForRetrieval()
 dists, indices = matcher.search(query_embed, index_path=index_path)
 print(indices)

Diff for: docs/source/embed.rst

+17 -22

@@ -3,50 +3,45 @@ Embedding

 .. _embed:

-Use embedding from open-retrievals
+1. Use embedding from open-retrievals
 ---------------------------------------

 we can use `AutoModelForEmbedding` to get the sentence embedding from pretrained transformer or large language model.

 The Transformer model could get the representation vector from a sentence.

-
-.. epigraph::
-   :align: left
-
 Choose the right `pooling_method` when use the pretrained embedding, check in `huggingface <https://huggingface.co/models>`_


-Fine-tune
+2. Fine-tune
 ------------------

-point-wise
+- point-wise

-- `{(query, label), (document, label)}`
+`{(query, label), (document, label), ...}`


-pairwise
+- pairwise

-- `{(query, positive, label), (query, negative, label)}`
+`{(query, positive, negative), {query, positive, negative}, ...}`

-- `{(query, positive, negative), {query, positive, negative}}`
+`{(query, positive, negative1, negative2, negative3), (query, positive, negative1, negative2, negative3), ...}`

-- `{(query, positive, negative1, negative2, negative3...)}`
+`{(query, positive, label), (query, negative, label), ...}`

-listwise
-
-- `{(query+positive)}`
+- listwise


 Loss function
 ~~~~~~~~~~~~~~~~~~~~~~

-- binary classification:
-- similarity(query, positive) > similarity(query, negative)
-- hinge loss: max(0, similarity(query, positive) - similarity(query, negative) + margin)
-- logistic loss: logistic(similarity(query, positive) - similarity(query, negative))
-- multi-label classification:
-- similarity(query, positive), similarity(query, negative1), similarity(query, negative2)
+binary classification:
+- similarity(query, positive) > similarity(query, negative)
+- hinge loss: max(0, similarity(query, positive) - similarity(query, negative) + margin)
+- logistic loss: logistic(similarity(query, positive) - similarity(query, negative))
+
+multi-label classification:
+- similarity(query, positive), similarity(query, negative1), similarity(query, negative2)


 Pair wise
@@ -112,7 +107,7 @@ arcface
 List wise
 ~~~~~~~~~~~~~~

-Training skills
+3. Training skills
 -----------------------------------

 multiple gpus
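
A note on the loss functions listed in the hunk above: the pairwise hinge and logistic losses can be sketched in a few lines of plain PyTorch. The snippet below is illustrative only; the function names, the margin value, and the toy tensors are assumptions made for this note rather than loss classes shipped by open-retrievals, and the hinge form follows the usual convention of requiring the positive to outscore the negative by a margin.

import torch
import torch.nn.functional as F

def hinge_pairwise_loss(pos_scores: torch.Tensor, neg_scores: torch.Tensor, margin: float = 1.0) -> torch.Tensor:
    # penalize pairs where the positive does not beat the negative by at least `margin`
    return torch.clamp(margin - pos_scores + neg_scores, min=0).mean()

def logistic_pairwise_loss(pos_scores: torch.Tensor, neg_scores: torch.Tensor) -> torch.Tensor:
    # -log sigmoid(s_pos - s_neg), written as softplus(s_neg - s_pos)
    return F.softplus(neg_scores - pos_scores).mean()

# toy usage with cosine similarities between query/positive/negative embeddings
query = F.normalize(torch.randn(8, 768), dim=-1)
positive = F.normalize(torch.randn(8, 768), dim=-1)
negative = F.normalize(torch.randn(8, 768), dim=-1)
pos_scores = (query * positive).sum(-1)
neg_scores = (query * negative).sum(-1)
print(hinge_pairwise_loss(pos_scores, neg_scores).item(), logistic_pairwise_loss(pos_scores, neg_scores).item())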

Diff for: docs/source/index.rst

+8 -2

@@ -38,9 +38,15 @@ Run a simple example

 .. code-block:: python

-    import retrievals
+    from retrievals import AutoModelForEmbedding

+    sentences = ["Hello NLP", "Open-retrievals is designed for retrieval, rerank and RAG"]
+    model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
+    model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
+    sentence_embeddings = model.encode(sentences, normalize_embeddings=True, convert_to_tensor=True)
+    print(sentence_embeddings)

+Open-retrievals support to fine-tune the embedding model, reranking model, llm easily for custom usage.

 * `Pairwise embedding fine-tuning <https://github.com/LongxingTan/open-retrievals/blob/master/examples/embedding_pairwise_finetune.py>`_
 * `Pairwise LLM embedding fine-tuning <https://github.com/LongxingTan/open-retrievals/blob/master/examples/embedding_llm_finetune.py>`_
@@ -49,7 +55,7 @@ Run a simple example
 * `LLM reranking fine-tuning <https://github.com/LongxingTan/open-retrievals/blob/master/examples/rerank_llm_finetune.py>`_


-More datasets
+More datasets examples

 * `T2 ranking dataset <https://github.com/LongxingTan/open-retrievals/tree/master/examples/t2_ranking>`_
 * `scifact dataset <https://github.com/LongxingTan/open-retrievals/tree/master/examples/scifact>`_

Diff for: docs/source/quick-start.rst

+64 -8

@@ -5,6 +5,10 @@ Quick start

 We can easily use Open-retrievals to fine-tune the model easily for information retrieval and RAG application.

+.. image:: https://colab.research.google.com/assets/colab-badge.svg
+   :target: https://colab.research.google.com/drive/1-WBMisdWLeHUKlzJ2DrREXY_kSV8vjP3?usp=sharing
+   :alt: Open In Colab
+

 1. Embedding
 -----------------------------
@@ -15,18 +19,26 @@ We can use the pretrained embedding easily from transformers or sentence-transfo

 from retrievals import AutoModelForEmbedding

-sentences = ["Hello NLP", "Open-retrievals is designed for retrieval, rerank and RAG"]
-model_name_or_path = "sentence-transformers/all-MiniLM-L6-v2"
+sentences = [
+    'query: how much protein should a female eat',
+    'query: summit define',
+    "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. ",
+    "passage: Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level."
+]
+model_name_or_path = 'intfloat/e5-base-v2'
+# sentence embedding mode
 model = AutoModelForEmbedding.from_pretrained(model_name_or_path, pooling_method="mean")
-sentence_embeddings = model.encode(sentences, normalize_embeddings=True, convert_to_tensor=True)
-print(sentence_embeddings)
+# encode the sentence to embedding vector
+embeddings = model.encode(sentences, normalize_embeddings=True, convert_to_tensor=True)
+scores = (embeddings[:2] @ embeddings[2:].T) * 100
+print(scores.tolist())

 .. code::

-output
+    [[89.92379760742188, 68.0742416381836], [68.93356323242188, 91.32250213623047]]


-Embedding fine-tuned
+Fine-tune embedding
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 If we want to further improve the retrieval performance, an optional method is to fine tune the embedding model weights. It will project the vector of query and answer to similar representation space.
@@ -99,15 +111,59 @@ If we have multiple retrieval source or a better sequence, we can add the rerank

 from retrievals import AutoModelForRanking

+sentences = [
+    ["In 1974, I won the championship in Southeast Asia in my first kickboxing match", "In 1982, I defeated the heavy hitter Ryu Long."],
+    ['A dog is chasing car.', 'A man is playing a guitar.'],
+]
+
 model_name_or_path: str = "BAAI/bge-reranker-base"
 rerank_model = AutoModelForRanking.from_pretrained(model_name_or_path)
-scores_list = rerank_model.compute_score(["In 1974, I won the championship in Southeast Asia in my first kickboxing match", "In 1982, I defeated the heavy hitter Ryu Long."])
+scores_list = rerank_model.compute_score(sentences)
 print(scores_list)

+.. code::

-Rerank fine-tuned
+    [-5.075257778167725, -10.194067001342773]
+
+
+Fine-tune reranking
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

+.. code-block:: python
+
+    from transformers import AutoTokenizer, TrainingArguments, get_cosine_schedule_with_warmup, AdamW
+    from retrievals import RerankCollator, AutoModelForRanking, RerankTrainer, RerankTrainDataset
+
+    model_name_or_path: str = "microsoft/deberta-v3-base"
+    max_length: int = 128
+    learning_rate: float = 3e-5
+    batch_size: int = 4
+    epochs: int = 3
+
+    train_dataset = RerankTrainDataset('./t2rank.json', positive_key='pos', negative_key='neg')
+    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
+    model = AutoModelForRanking.from_pretrained(model_name_or_path)
+    optimizer = AdamW(model.parameters(), lr=learning_rate)
+    num_train_steps = int(len(train_dataset) / batch_size * epochs)
+    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=0.05 * num_train_steps, num_training_steps=num_train_steps)
+
+    training_args = TrainingArguments(
+        learning_rate=learning_rate,
+        per_device_train_batch_size=batch_size,
+        num_train_epochs=epochs,
+        output_dir='./checkpoints',
+        remove_unused_columns=False,
+    )
+    trainer = RerankTrainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        data_collator=RerankCollator(tokenizer, max_length=max_length),
+    )
+    trainer.optimizer = optimizer
+    trainer.scheduler = scheduler
+    trainer.train()
+

 4. RAG
 -----------------------------

Diff for: docs/source/rerank.rst

+26 -9

@@ -3,22 +3,30 @@ Rerank

 .. _rerank:

-Use Rerank from open-retrievals
-------------------------------------
+1. Use reranking from open-retrievals
+-------------------------------------------

 .. code-block:: python

     from retrievals import AutoModelForRanking

+    sentences = [
+        ["In 1974, I won the championship in Southeast Asia in my first kickboxing match", "In 1982, I defeated the heavy hitter Ryu Long."],
+        ['A dog is chasing car.', 'A man is playing a guitar.'],
+    ]
+
     model_name_or_path: str = "BAAI/bge-reranker-base"
     rerank_model = AutoModelForRanking.from_pretrained(model_name_or_path)
-    scores_list = rerank_model.compute_score(["In 1974, I won the championship in Southeast Asia in my first kickboxing match", "In 1982, I defeated the heavy hitter Ryu Long."])
+    scores_list = rerank_model.compute_score(sentences)
     print(scores_list)

+.. code::
+
+    [-5.075257778167725, -10.194067001342773]


-Fine tuning Cross-encoder
-----------------------------
+2. Fine-tune cross-encoder reranking model
+-----------------------------------------------

 .. image:: https://colab.research.google.com/assets/colab-badge.svg
    :target: https://colab.research.google.com/drive/1QvbUkZtG56SXomGYidwI4RQzwODQrWNm?usp=sharing
@@ -61,12 +69,21 @@ Fine tuning Cross-encoder
     trainer.train()


-Fine tuning ColBERT
-----------------------------
+3. Fine-tune ColBERT reranking model
+----------------------------------------
+
+.. image:: https://colab.research.google.com/assets/colab-badge.svg
+   :target: https://colab.research.google.com/drive/1QVtqhQ080ZMltXoJyODMmvEQYI6oo5kO?usp=sharing
+   :alt: Open In Colab


-Fine tuning LLM ranker
-----------------------------
+4. Fine-tune LLM reranker
+-------------------------------------
+
+.. image:: https://colab.research.google.com/assets/colab-badge.svg
+   :target: https://colab.research.google.com/drive/1fzq1iV7-f8hNKFnjMmpVhVxadqPb9IXk?usp=sharing
+   :alt: Open In Colab
+

 - Point-wise style prompt:

Diff for: examples/README.md

+12

@@ -16,11 +16,20 @@ Check the whole pipeline examples
 ## Embedding

 **Data Format**
+
+- In-batch negative fine-tuning
+```
+{'query': TEXT_TYPE, 'positive': List[TEXT_TYPE]}
+...
+```
+
+- Hard negative (+ In-batch negative) fine-tuning
 ```
 {'query': TEXT_TYPE, 'positive': List[TEXT_TYPE], 'negative': List[TEXT_TYPE]}
 ...
 ```

+
 **Pairwise embedding finetune**
 ```shell
 MODEL_NAME="BAAI/bge-base-zh-v1.5"
@@ -216,3 +225,6 @@ The grad_norm during training is always zero?
 The fine-tuned embedding performance during inference is worse than original?
 - check whether the pooling_method is correct
 - check whether the prompt is the same as training for LLM model
+
+How can we fine-tune the `BAAI/bge-m3` ColBERT model?
+- download the weights first using `snapshot_download` from huggingface_hub to model_dir, then use ColBERT.from_pretrained(model_dir)
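
A rough sketch of that FAQ answer, which is the point of this commit (fine-tuning ColBERT from a pretrained checkpoint): the snippet below is illustrative only; the `retrievals` import path for `ColBERT` and the local directory name are assumptions rather than a verified recipe from this repository.

# hypothetical sketch: download the BAAI/bge-m3 weights to a local directory,
# then load them with ColBERT.from_pretrained as the FAQ above suggests
from huggingface_hub import snapshot_download
from retrievals import ColBERT  # assumed import path

model_dir = snapshot_download(repo_id="BAAI/bge-m3", local_dir="./bge-m3")
colbert = ColBERT.from_pretrained(model_dir)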

Diff for: src/retrievals/models/embedding_auto.py

+1 -2

@@ -9,8 +9,7 @@
 import torch
 import torch.nn as nn
 from torch.utils.data import DataLoader
-from tqdm.auto import tqdm
-from tqdm.autonotebook import trange
+from tqdm.auto import tqdm, trange
 from transformers import (
     AutoConfig,
     AutoModel,
