Skip to content

Commit 9018fc4

Browse files
committed
add
1 parent 71e204f commit 9018fc4

5 files changed

+401
-155
lines changed

eval.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from beir import util, LoggingHandler
2-
from beir.datasets.data_loader import GenericDataLoader
31
from beir.retrieval.evaluation import EvaluateRetrieval
42
from beir.retrieval.search.lexical import BM25Search as BM25
3+
from typing import Type, List, Dict, Union, Tuple
4+
55

66
import pathlib, os, random, json, argparse
77
import logging
@@ -25,6 +25,38 @@
2525
datefmt="%m/%d/%Y %H:%M:%S",
2626
level=logging.INFO
2727
)
28+
29+
30+
def calculate_top_k_accuracy(
        qrels: Dict[str, Dict[str, int]],
        results: Dict[str, Dict[str, float]],
        k_values: List[int]) -> Dict[str, float]:
    """Compute Accuracy@k for a retrieval run.

    A query counts as a "hit" at cutoff k if at least one relevant document
    (qrels score > 0) appears among its k highest-scoring retrieved documents.
    The accuracy is the fraction of queries in ``qrels`` that are hits.

    Args:
        qrels: query_id -> {doc_id -> relevance}; relevance > 0 means relevant.
        results: query_id -> {doc_id -> retrieval score}; higher is better.
        k_values: cutoffs to evaluate, e.g. [1, 5, 10].

    Returns:
        Mapping "Accuracy@k" -> accuracy rounded to 5 decimal places.
        (The original annotated this as Tuple[Dict[str, float]], but a single
        dict is returned — the annotation is corrected here.)

    Note:
        Every query_id in ``results`` is looked up in ``qrels``; a query
        present in results but absent from qrels raises KeyError, matching
        the original behavior — TODO confirm callers guarantee alignment.
    """
    top_k_acc = {f"Accuracy@{k}": 0.0 for k in k_values}

    # Guard: original code divided by len(qrels) unconditionally, which
    # raises ZeroDivisionError on an empty qrels. Return all-zero metrics.
    if not qrels:
        return top_k_acc

    # Rank each query's documents once at the largest cutoff; smaller k
    # values are prefixes of this ranking.
    k_max = max(k_values)
    top_hits = {
        query_id: [doc_id for doc_id, _ in
                   sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[:k_max]]
        for query_id, doc_scores in results.items()
    }
    logging.info("\n")

    for query_id, ranked_docs in top_hits.items():
        relevant = {doc_id for doc_id, rel in qrels[query_id].items() if rel > 0}
        for k in k_values:
            # Set intersection replaces the original per-document slice
            # membership loop; one hit per query per cutoff, same semantics.
            if relevant.intersection(ranked_docs[:k]):
                top_k_acc[f"Accuracy@{k}"] += 1.0

    for k in k_values:
        top_k_acc[f"Accuracy@{k}"] = round(top_k_acc[f"Accuracy@{k}"] / len(qrels), 5)
        logging.info("Accuracy@{}: {:.4f}".format(k, top_k_acc[f"Accuracy@{k}"]))

    return top_k_acc
59+
2860
#### /print debug information to stdout
2961

3062
#### load dataset
@@ -64,10 +96,9 @@
6496
results = retriever.retrieve(corpus, queries)
6597

6698
#### Evaluate your retrieval using NDCG@k, MAP@K ...
67-
#retriever.k_values = [1,2,3,4,5,6,7,8,9,10,20,50,100]
6899
retriever.k_values = [2]
69100
logging.info("Retriever evaluation for k in: {}".format(retriever.k_values))
70-
top_k_accuracy = retriever.evaluate_custom(qrels, results, retriever.k_values, metric="top_k_accuracy")
101+
top_k_accuracy = calculate_top_k_accuracy(qrels, results, retriever.k_values)
71102
print('top_k_accuracy: {}'.format(top_k_accuracy))
72103
logging.info(print('top_k_accuracy: {}'.format(top_k_accuracy)))
73104

@@ -79,4 +110,7 @@
79110
scores = sorted(scores_dict.items(), key=lambda item: item[1], reverse=True)
80111
for rank in range(10):
81112
doc_id = scores[rank][0]
82-
logging.info("Doc %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
113+
logging.info("Doc %d: %s [%s] - %s\n" % (rank+1, doc_id, corpus[doc_id].get("title"), corpus[doc_id].get("text")))
114+
115+
116+

generate.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
parser = argparse.ArgumentParser()
1313

1414
parser.add_argument("--valid_proportaion", default=0.7, type=float, help="")
15-
parser.add_argument("--keywords_num", default=10, type=int, help="")
15+
parser.add_argument("--keywords_num", default=2, type=int, help="")
1616
args = parser.parse_args()
1717

1818
def make_query_qrel_for_each_cluster(cluster_num):
@@ -60,7 +60,7 @@ def make_query_qrel_for_each_cluster(cluster_num):
6060
str_likes_with_community_likes = str_likes + ' ' + str_community_likes
6161

6262
json_query_preprocessed[user_key] = str_likes_with_community_likes
63-
63+
6464
# preprocess qrel for current cluster
6565
lst_likes_doc = [str(i) for i in json_data[user_key]['likes_doc']]
6666

@@ -87,11 +87,8 @@ def make_query_qrel_for_each_cluster(cluster_num):
8787

8888
# merge query for each cluster
8989
for k, v in i_cluster_json_query_preprocessed.items():
90-
#total_cluster_json_query_preprocessed[k] = total_cluster_json_query_preprocessed[k] + ' , ' + cluster_like_keywords if k in total_cluster_json_query_preprocessed else v
91-
9290
# if user (key) exists in both clusters, then concat two queries (value)
9391
if k in total_cluster_json_query_preprocessed:
94-
#import pdb; pdb.set_trace()
9592
total_cluster_json_query_preprocessed[k] = total_cluster_json_query_preprocessed[k] + ' , ' + cluster_like_keywords
9693
else:
9794
total_cluster_json_query_preprocessed[k] = v
@@ -101,6 +98,7 @@ def make_query_qrel_for_each_cluster(cluster_num):
10198
for k, v in i_cluster_json_qrel_preprocessed.items():
10299
total_cluster_json_qrel_preprocessed[k] = v
103100

101+
print(len(total_cluster_json_query_preprocessed))
104102
with open('/home/syjeong/Starlab/data/preprocessed/ver3/keywords_num/'+str(args.keywords_num)+'/'+'total_cluster_ver3_30_users_query_penguin.json', "w") as writer:
105103
writer.write(json.dumps(total_cluster_json_query_preprocessed, indent=4) + "\n")
106104

0 commit comments

Comments
 (0)