|
| 1 | +"""Evaluation of embedding model""" |
| 2 | + |
| 3 | +import argparse |
| 4 | +import functools |
| 5 | +import random |
| 6 | +from typing import Dict, List |
| 7 | + |
| 8 | +import numpy as np |
| 9 | +import torch |
| 10 | + |
| 11 | +# from C_MTEB.tasks import * |
| 12 | +from mteb import MTEB, DRESModel |
| 13 | + |
| 14 | +from retrievals import AutoModelForEmbedding |
| 15 | + |
| 16 | +TASKS_WITH_PROMPTS = [ |
| 17 | + "T2Retrieval", |
| 18 | + "MMarcoRetrieval", |
| 19 | + "DuRetrieval", |
| 20 | + "CovidRetrieval", |
| 21 | + "CmedqaRetrieval", |
| 22 | + "EcomRetrieval", |
| 23 | + "MedicalRetrieval", |
| 24 | + "VideoRetrieval", |
| 25 | +] |
| 26 | + |
| 27 | +parser = argparse.ArgumentParser(description='evaluation for CMTEB') |
| 28 | +parser.add_argument('--model_name', default='bert-base-uncased', type=str, help='which model to use') |
| 29 | +parser.add_argument('--output_dir', default='zh_results/', type=str, help='output directory') |
| 30 | +parser.add_argument('--max_len', default=512, type=int, help='max length') |
| 31 | + |
| 32 | +args = parser.parse_args() |
| 33 | + |
| 34 | + |
| 35 | +class RetrievalModel(DRESModel): |
| 36 | + def __init__(self, encoder, query_instruction='', document_instruction='', **kwargs): |
| 37 | + self.encoder = encoder |
| 38 | + self.query_instruction = query_instruction |
| 39 | + self.document_instruction = document_instruction |
| 40 | + |
| 41 | + def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray: |
| 42 | + """For MTEB eval |
| 43 | + This function will be used for retrieval task |
| 44 | + if there is an instruction for queries, we will add it to the query text |
| 45 | + """ |
| 46 | + input_texts = [self.query_instruction + q for q in queries] |
| 47 | + return self._do_encode(input_texts) |
| 48 | + |
| 49 | + def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs) -> np.ndarray: |
| 50 | + """For MTEB eval |
| 51 | + This function will be used for retrieval task |
| 52 | + encode corpus for retrieval task |
| 53 | + """ |
| 54 | + if isinstance(corpus[0], dict): |
| 55 | + input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus] |
| 56 | + else: |
| 57 | + input_texts = corpus |
| 58 | + |
| 59 | + input_texts = [self.document_instruction + t for t in input_texts] |
| 60 | + return self._do_encode(input_texts) |
| 61 | + |
| 62 | + @torch.no_grad() |
| 63 | + def _do_encode(self, input_texts: List[str]) -> np.ndarray: |
| 64 | + return self.encoder.encode( |
| 65 | + sentences=input_texts, batch_size=256, normalize_embeddings=True, convert_to_numpy=True |
| 66 | + ) |
| 67 | + |
| 68 | + |
| 69 | +if __name__ == '__main__': |
| 70 | + encoder = AutoModelForEmbedding.from_pretrained(args.model_name) |
| 71 | + encoder.encode = functools.partial(encoder.encode, normalize_embeddings=True) |
| 72 | + |
| 73 | + task_names = [t.description["name"] for t in MTEB(task_langs=['zh', 'zh-CN']).tasks] |
| 74 | + random.shuffle(task_names) |
| 75 | + |
| 76 | + for task in task_names: |
| 77 | + evaluation = MTEB(tasks=[task], task_langs=['zh', 'zh-CN']) |
| 78 | + if task in TASKS_WITH_PROMPTS: |
| 79 | + evaluation.run(RetrievalModel(encoder), output_folder=args.output_dir, overwrite_results=False) |
| 80 | + else: |
| 81 | + evaluation.run(encoder, output_folder=args.output_dir, overwrite_results=False) |
0 commit comments