greedy_mp.py
import logging
import math
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.parallel import DataParallel

logger = logging.getLogger(__name__)


class GreedyMentionProposer(torch.nn.Module):
    """Greedily proposes the top-scoring, non-crossing candidate spans as mentions."""

    def __init__(
        self,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
    def forward(
        self,
        spans: torch.IntTensor,
        span_mention_scores: torch.FloatTensor,
        span_mask: torch.FloatTensor,
        token_num: torch.IntTensor,
        num_spans_to_keep: int,
        take_top_spans_per_sentence=False,
        flat_span_sent_ids=None,
        ratio=0.4,
    ):
        if not take_top_spans_per_sentence:
            # Keep the top-scoring, non-crossing spans over the whole document.
            top_span_indices = masked_topk_non_overlap(
                span_mention_scores,
                span_mask,
                num_spans_to_keep,
                spans,
            )
            top_spans = spans[top_span_indices]
            top_span_scores = span_mention_scores[top_span_indices]
            return top_span_scores, top_span_indices, top_spans, 0., None
        else:
            # Select top spans per sentence, with a budget of `ratio` times the
            # sentence length. `flat_span_sent_ids` maps each candidate span to
            # its sentence id; candidate spans are assumed to be grouped by sentence.
            top_span_indices, top_span_scores, top_spans = [], [], []
            prev_sent_id, prev_span_id = 0, 0
            for span_id, sent_id in enumerate(flat_span_sent_ids.tolist()):
                if sent_id != prev_sent_id:
                    sent_span_indices = masked_topk_non_overlap(
                        span_mention_scores[prev_span_id:span_id],
                        span_mask[prev_span_id:span_id],
                        int(ratio * (token_num[prev_sent_id])),  # [CLS], [SEP]
                        spans[prev_span_id:span_id],
                        non_crossing=True,
                    ) + prev_span_id
                    top_span_indices.append(sent_span_indices)
                    top_span_scores.append(span_mention_scores[sent_span_indices])
                    top_spans.append(spans[sent_span_indices])
                    prev_sent_id, prev_span_id = sent_id, span_id
            # Last sentence.
            sent_span_indices = masked_topk_non_overlap(
                span_mention_scores[prev_span_id:],
                span_mask[prev_span_id:],
                int(ratio * (token_num[-1])),
                spans[prev_span_id:],
                non_crossing=True,
            ) + prev_span_id
            top_span_indices.append(sent_span_indices)
            top_span_scores.append(span_mention_scores[sent_span_indices])
            top_spans.append(spans[sent_span_indices])

            # Pad the per-sentence selections to a common length and build a
            # mask over the valid (non-padded) positions.
            num_top_spans = [x.size(0) for x in top_span_indices]
            max_num_top_span = max(num_top_spans)
            top_spans = torch.stack(
                [torch.cat([x, x.new_zeros((max_num_top_span - x.size(0), 2))], dim=0) for x in top_spans], dim=0
            )
            top_span_masks = torch.stack(
                [torch.cat([x.new_ones((x.size(0),)), x.new_zeros((max_num_top_span - x.size(0),))], dim=0) for x in top_span_indices], dim=0
            )
            top_span_indices = torch.stack(
                [torch.cat([x, x.new_zeros((max_num_top_span - x.size(0),))], dim=0) for x in top_span_indices], dim=0
            )
            return top_span_scores, top_span_indices, top_spans, 0., None, top_span_masks


def masked_topk_non_overlap(
    span_scores,
    span_mask,
    num_spans_to_keep,
    spans,
    non_crossing=True
):
    # Masked-out spans get log(0) = -inf added to their score, so they sort last.
    sorted_scores, sorted_indices = torch.sort(span_scores + span_mask.log(), descending=True)
    sorted_indices = sorted_indices.tolist()
    spans = spans.tolist()
    if not non_crossing:
        # Plain top-k: take the best-scoring spans and return them in textual order.
        selected_candidate_idx = sorted(sorted_indices[:num_spans_to_keep], key=lambda idx: (spans[idx][0], spans[idx][1]))
        selected_candidate_idx = span_scores.new_tensor(selected_candidate_idx, dtype=torch.long)
        return selected_candidate_idx

    # Greedy selection: accept spans in descending score order, skipping any
    # span that crosses (partially overlaps) an already accepted span.
    selected_candidate_idx = []
    start_to_max_end, end_to_min_start = {}, {}
    for candidate_idx in sorted_indices:
        if len(selected_candidate_idx) >= num_spans_to_keep:
            break
        # Perform overlapping check
        span_start_idx = spans[candidate_idx][0]
        span_end_idx = spans[candidate_idx][1]
        cross_overlap = False
        for token_idx in range(span_start_idx, span_end_idx + 1):
            max_end = start_to_max_end.get(token_idx, -1)
            if token_idx > span_start_idx and max_end > span_end_idx:
                cross_overlap = True
                break
            min_start = end_to_min_start.get(token_idx, -1)
            if token_idx < span_end_idx and 0 <= min_start < span_start_idx:
                cross_overlap = True
                break
        if not cross_overlap:
            # Pass check; select idx and update dict stats
            selected_candidate_idx.append(candidate_idx)
            max_end = start_to_max_end.get(span_start_idx, -1)
            if span_end_idx > max_end:
                start_to_max_end[span_start_idx] = span_end_idx
            min_start = end_to_min_start.get(span_end_idx, -1)
            if min_start == -1 or span_start_idx < min_start:
                end_to_min_start[span_end_idx] = span_start_idx
    # Sort selected candidates by span idx
    selected_candidate_idx = sorted(selected_candidate_idx, key=lambda idx: (spans[idx][0], spans[idx][1]))
    selected_candidate_idx = span_scores.new_tensor(selected_candidate_idx, dtype=torch.long)
    return selected_candidate_idx
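

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): a minimal check
# of masked_topk_non_overlap and GreedyMentionProposer.forward on made-up
# candidate spans. The spans, scores, and token counts below are invented for
# illustration only; nothing is assumed beyond the signatures defined above.
if __name__ == "__main__":
    # Candidate spans as (start, end) token indices. Span (1, 4) crosses the
    # higher-scoring (0, 2), while (1, 2) is nested inside it and (5, 6) is
    # disjoint, so greedy non-crossing selection should keep indices 0, 2, 3.
    spans = torch.tensor([[0, 2], [1, 4], [1, 2], [5, 6]])
    span_scores = torch.tensor([0.9, 0.8, 0.7, 0.5])
    span_mask = torch.ones(4)

    kept = masked_topk_non_overlap(span_scores, span_mask, num_spans_to_keep=3, spans=spans)
    print(kept)  # expected: tensor([0, 2, 3])

    # The same selection through the module interface (document-level branch).
    proposer = GreedyMentionProposer()
    top_scores, top_indices, top_spans, _, _ = proposer(
        spans, span_scores, span_mask,
        token_num=torch.tensor([7]),
        num_spans_to_keep=3,
    )
    print(top_spans)  # expected: tensor([[0, 2], [1, 2], [5, 6]])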