- from beir import util, LoggingHandler
- from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.lexical import BM25Search as BM25
+ from typing import Type, List, Dict, Union, Tuple
+

import pathlib, os, random, json, argparse
import logging
...
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO
)
28
+
+
+ def calculate_top_k_accuracy(
+         qrels: Dict[str, Dict[str, int]],
+         results: Dict[str, Dict[str, float]],
+         k_values: List[int]) -> Dict[str, float]:
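+     """Fraction of queries with at least one relevant document (qrels score > 0)
+     among the top-k retrieved results, for each k in k_values."""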
+
+     top_k_acc = {}
+
+     for k in k_values:
+         top_k_acc[f"Accuracy@{k}"] = 0.0
+
+     k_max, top_hits = max(k_values), {}
+     logging.info("\n")
+
+     for query_id, doc_scores in results.items():
+         top_hits[query_id] = [item[0] for item in sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max]]
+
+     for query_id in top_hits:
+         query_relevant_docs = set([doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0])
+         for k in k_values:
+             for relevant_doc_id in query_relevant_docs:
+                 if relevant_doc_id in top_hits[query_id][0:k]:
+                     top_k_acc[f"Accuracy@{k}"] += 1.0
+                     break
+
+     for k in k_values:
+         top_k_acc[f"Accuracy@{k}"] = round(top_k_acc[f"Accuracy@{k}"] / len(qrels), 5)
+         logging.info("Accuracy@{}: {:.4f}".format(k, top_k_acc[f"Accuracy@{k}"]))
+
+     return top_k_acc
+
#### /print debug information to stdout

#### load dataset
...
results = retriever.retrieve(corpus, queries)

#### Evaluate your retrieval using NDCG@k, MAP@K ...
- #retriever.k_values = [1,2,3,4,5,6,7,8,9,10,20,50,100]
retriever.k_values = [2]
logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
- top_k_accuracy = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="top_k_accuracy")
+ top_k_accuracy = calculate_top_k_accuracy(qrels, results, retriever.k_values)
print('top_k_accuracy: {}'.format(top_k_accuracy))
logging.info('top_k_accuracy: {}'.format(top_k_accuracy))

...
scores = sorted(scores_dict.items(), key=lambda item: item[1], reverse=True)
for rank in range(10):
    doc_id = scores[rank][0]
-     logging.info("Doc %d: %s [%s] - %s\n" % (rank + 1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
+     logging.info("Doc %d: %s [%s] - %s\n" % (rank + 1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
+
+
+
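For reference, a minimal sketch of how the new calculate_top_k_accuracy helper could be exercised on its own, using made-up qrels/results dictionaries (the query and document ids below are purely illustrative and are not part of this commit):

# Assumes calculate_top_k_accuracy from the diff above is in scope.
qrels = {"q1": {"d1": 1, "d3": 0}, "q2": {"d2": 1}}   # query_id -> {doc_id: relevance label}
results = {                                           # query_id -> {doc_id: retrieval score}
    "q1": {"d1": 0.9, "d2": 0.4, "d3": 0.1},
    "q2": {"d3": 0.8, "d2": 0.7, "d1": 0.2},
}
print(calculate_top_k_accuracy(qrels, results, k_values=[1, 2]))
# q1 ranks its relevant doc d1 first, q2 ranks its relevant doc d2 second,
# so this prints {'Accuracy@1': 0.5, 'Accuracy@2': 1.0}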