Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update basic_utils.py #5

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 51 additions & 22 deletions src/utils/basic_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging, logging.handlers
import coloredlogs
import torch

import time

def get_logger(name, log_file_path=None, fmt="%(asctime)s %(name)s: %(message)s",
print_lev=logging.DEBUG, write_lev=logging.INFO):
Expand Down Expand Up @@ -38,7 +38,7 @@ def set_seed(seed, use_cuda=True):
torch.cuda.manual_seed_all(seed)


def compute_tiou(pred, gt):
def compute_tiou(pred, gt):
intersection = max(0, min(pred[1], gt[1]) - max(pred[0], gt[0]))
union = max(pred[1], gt[1]) - min(pred[0], gt[0])
return float(intersection) / (union + 1e-9)
Expand All @@ -50,7 +50,7 @@ def compute_overlap(pred, gt):
pred_is_list = isinstance(pred[0], list)
gt_is_list = isinstance(gt[0], list)
pred = pred if pred_is_list else [pred]
gt = gt if gt_is_list else [gt]
gt = gt if gt_is_list else [gt] #将其转换为列表中的列表,才能够使后面找每一对左边界最大值变为广播操作
# compute overlap
pred, gt = np.array(pred), np.array(gt)
inter_left = np.maximum(pred[:, 0, None], gt[None, :, 0])
Expand All @@ -62,7 +62,7 @@ def compute_overlap(pred, gt):
overlap = 1.0 * inter / union
# reformat output
overlap = overlap if gt_is_list else overlap[:, 0]
overlap = overlap if pred_is_list else overlap[0]
overlap = overlap if pred_is_list else overlap[0] #输出得到每一对的重叠比例
return overlap


Expand Down Expand Up @@ -99,63 +99,92 @@ def __init__(self, tiou_threshold=[0.1, 0.3, 0.5], topks=[1, 5, 10, 50, 100]):
self.topks = topks

def eval_instance(self, pred, gt, topk):
""" Compute Recall@topk at predefined tiou threshold for instance
"""_summary_

Args:
pred: predictions of starting/end position; list of [start,end]
gt: ground-truth of starting/end position; [start,end]
topk: rank of predictions; int
Return:
correct: flag of correct at predefined tiou threshold [0.3,0.5,0.7]
pred (_tensor_): _description_:(100,2)表示每个query的100个候选结果(开始和结束时刻)
gt (_tensor_): _description_ :(2) 表示每个query的真实结果
topk (_单值_): _description_:表示从100个候选中选择前topk个

Returns:
_type_: _description_
"""
correct = {str(tiou):0 for tiou in self.tiou_threshold}
find = {str(tiou):False for tiou in self.tiou_threshold}
#{'0.1': 0, '0.2': 0, '0.3': 0}, {'0.1': False, '0.2': False, '0.3': False}
if len(pred) == 0:
return correct

if len(pred) > topk:
pred = pred[:topk]


best_tiou = 0
for loc in pred:
for loc in pred: #从topk个中一个一个计算。表示每一个候选和真实的标签之间是否tiou大于阈值
# print("starting compute_tiou")
cur_tiou = compute_tiou(loc, gt)

if cur_tiou > best_tiou:
best_tiou = cur_tiou

for tiou in self.tiou_threshold:
if (not find[str(tiou)]) and (cur_tiou >= tiou):
#在所有的100个候选中只要有一个大于阈值,后面的便不用再设置,
#但是前面IOU还是要计算,因为要得到最大的IOU
if (not find[str(tiou)]) and (cur_tiou >= tiou):
correct[str(tiou)] = 1
find[str(tiou)] = True

#correct表示当前的query,其top-m个预测结果中是否有在给定阈值下正确的结果,只有是否,没有数值累计
#best_tiou表示当前query的top-m个预测结果中最好的tiou值
return correct, best_tiou

def eval(self, preds, gts):
""" Compute R@1 and R@5 at predefined tiou threshold [0.3,0.5,0.7]
Args:
pred: predictions consisting of starting/end position; list
gt: ground-truth of starting/end position; [start,end]
preds: list;元素个数为query的总数,没有了video的概念 num_all_querys,100,2
gts: list;元素个数为query的总数,没有了video的概念 num_all_querys,2
Return:
correct: flag of correct at predefined tiou threshold [0.3,0.5,0.7]
"""
num_instances = float(len(preds))
print("preds,gts:{},{}".format(len(preds),len(gts))) #CPU/GPU: 72044,72044
print("type of preds,gts:{},{}".format(type(preds),type(gts)))
eval_metric_st=time.time()
num_instances = float(len(preds)) #应该计算的是所有的query
print("num_instances: ",num_instances)
miou = 0
all_rank = dict()
for tiou in self.tiou_threshold:
for topk in self.topks:
#top-k和tiou是分别计算的
all_rank["R{}-{}".format(topk, tiou)] = 0

for pred,gt in zip(preds, gts):

#每个元素表示一个视频数据,用列表而不是张量的形式是因为每个视频的query数量不一样
#在列表中可以做到不同长度的存储,而张量不行,但是这里的评价标准是视频为单位还是query?
count=0
#对于每一个query去单独计算,每一行计算,preds,gts 72044,100,2 ; 72044,2
count_st=time.time()
for pred,gt in zip(preds, gts):
#每次拿出一个query进行计算
for topk in self.topks:
correct, iou = self.eval_instance(pred, gt, topk=topk)
#在eval_instance中计算得到的best_iou并没有参与到后续计算中
correct, iou = self.eval_instance(pred, gt, topk=topk) #因为内部是一个一个计算所以时间开销特别大?
for tiou in self.tiou_threshold:
all_rank["R{}-{}".format(topk, tiou)] += correct[str(tiou)]
count+=1
if count%1000==0:
count_ed=time.time()
print("count:{} ,time:{} ".format(count,count_ed-count_st))
count_st=time.time()


# miou += iou

for tiou in self.tiou_threshold:
#这个指标不是按照每个视频来计算,而是按照query来计算,有点奇怪
print("ending eval compute")
for tiou in self.tiou_threshold:
for topk in self.topks:
all_rank["R{}-{}".format(topk, tiou)] /= num_instances

# miou /= float(num_instances)

return all_rank, miou
eval_metric_et=time.time()
print("eval_metric time: ",eval_metric_et-eval_metric_st)
return all_rank, miou