-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathcsn_eval_oai.py
More file actions
54 lines (40 loc) · 1.41 KB
/
csn_eval_oai.py
File metadata and controls
54 lines (40 loc) · 1.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import time
import tqdm
import torch
from openai.embeddings_utils import get_embedding
from src import datasets_loader
from src.utils import retrieval_eval
# --- Configuration -----------------------------------------------------------
# Local cache directory for the CodeSearchNet dataset.
CSN_GFG_DATA_PATH = "/mnt/colab_public/datasets/joao/CodeSearchNet"
# OpenAI embedding model used for both sides of each (source, target) pair.
EMBEDDING_MODEL_ID = "text-embedding-ada-002"
# Upper bound on raw example length; longer examples are filtered by the loader.
MAX_RAW_LEN = 15000
# Pause between consecutive API calls to avoid rate-limit errors.
SLEEP_SECONDS_BETWEEN_QUERIES = 1.0

# Validation split of CodeSearchNet, yielding (source, target) text pairs.
test_data = datasets_loader.get_dataset(
    dataset_name="code_search_net",
    path_to_cache=CSN_GFG_DATA_PATH,
    split="validation",
    maximum_raw_length=MAX_RAW_LEN,
)


def _embed(text):
    """Embed *text* via the OpenAI API; return a (1, dim) float32 tensor.

    NOTE(review): `openai.embeddings_utils.get_embedding` is the legacy
    (pre-v1.0) OpenAI SDK interface — confirm the pinned `openai` version
    still ships it before running.
    """
    return torch.tensor(get_embedding(text, engine=EMBEDDING_MODEL_ID)).unsqueeze(0)


source_embeddings_list = []
target_embeddings_list = []
for source, target in tqdm.tqdm(test_data, total=len(test_data), desc="embedding"):
    source_embeddings_list.append(_embed(source))
    target_embeddings_list.append(_embed(target))
    # Throttle between pairs so get_embedding() does not hit the rate limit.
    time.sleep(SLEEP_SECONDS_BETWEEN_QUERIES)

# Stack the per-example (1, dim) rows into (num_examples, dim) matrices.
source_embeddings = torch.cat(source_embeddings_list, 0)
target_embeddings = torch.cat(target_embeddings_list, 0)

# Row i of `source_embeddings` should retrieve row i of `target_embeddings`.
recall_at_1, recall_at_5, mean_reciprocal_rank = retrieval_eval(
    source_embeddings, target_embeddings
)
print(f"R@1: {recall_at_1}, R@5: {recall_at_5}, MRR: {mean_reciprocal_rank}")
"""
R@1: , R@5: , MRR:
"""