dataset.py
# Dataset class for pre-extracted S3D features
import numpy as np
import pandas as pd
import torch
from torch.nn.utils.rnn import pad_packed_sequence, pack_sequence


def collate_variable_length_seq(batch, padding_value=0, modality='rgb'):
    """Collate a batch of variable-length feature sequences: pad them to the length
    of the longest sequence in the batch and stack the labels into LongTensors."""
    data = [[], []] if modality == 'rgb_flow' else []
    labels = {}
    metadata = {}

    for item in batch:
        d, item_labels, item_metadata = item

        if modality == 'rgb_flow':
            rgb, flow = d
            data[0].append(rgb)
            data[1].append(flow)
        else:
            data.append(d)

        # accumulate per-item labels/metadata into per-key lists
        for stuff, stuff_batch in zip((labels, metadata), (item_labels, item_metadata)):
            if len(stuff) == 0:
                for k in stuff_batch.keys():
                    stuff[k] = []

            for k, v in stuff_batch.items():
                stuff[k].append(v)

    if modality == 'rgb_flow':
        padded_sequence = []

        for modality_data in data:
            mod_seq = pack_sequence(modality_data, enforce_sorted=False)
            mod_pad_seq, _ = pad_packed_sequence(mod_seq, batch_first=True, padding_value=padding_value)
            padded_sequence.append(mod_pad_seq)

        padded_sequence = tuple(padded_sequence)
    else:
        if isinstance(data[0], dict):
            # one sequence per key, padded independently
            keys = data[0].keys()
            padded_sequence = {}

            for k in keys:
                key_data = [x[k] for x in data]
                seq = pack_sequence(key_data, enforce_sorted=False)
                padded_sequence[k], _ = pad_packed_sequence(seq, batch_first=True, padding_value=padding_value)
        else:
            sequence = pack_sequence(data, enforce_sorted=False)
            padded_sequence, _ = pad_packed_sequence(sequence, batch_first=True, padding_value=padding_value)

    labels = {k: torch.LongTensor(v) for k, v in labels.items()}

    return padded_sequence, labels, metadata
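

# Illustrative sketch, not part of the original module: the collate function's keyword
# arguments can be bound with functools.partial before it is handed to a DataLoader.
# The function name, `dataset` and the batch size below are assumptions made for the
# example only.
def example_make_dataloader(dataset, batch_size=8, modality='rgb'):
    from functools import partial
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=batch_size,
                                       shuffle=True,
                                       collate_fn=partial(collate_variable_length_seq,
                                                          padding_value=0,
                                                          modality=modality))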


def get_verbs_adverbs_pairs(train_df, test_df):
    """Build the adverb/verb vocabularies and every (adverb, verb) pair from the
    train and test annotations."""
    df = pd.concat([train_df, test_df])
    adverbs = df.clustered_adverb.value_counts().index.to_list()  # sorted by frequency
    verbs = df.clustered_verb.value_counts().index.to_list()
    adverb2idx = {a: i for i, a in enumerate(adverbs)}
    idx2adverb = {i: a for i, a in enumerate(adverbs)}
    verb2idx = {v: i for i, v in enumerate(verbs)}
    idx2verb = {i: v for i, v in enumerate(verbs)}

    pairs = []

    for a in adverbs:
        for v in verbs:
            pairs.append((a, v))

    dataset_data = {'adverbs': adverbs, 'verbs': verbs,
                    'adverb2idx': adverb2idx, 'idx2adverb': idx2adverb,
                    'verb2idx': verb2idx, 'idx2verb': idx2verb, 'pairs': pairs}

    return dataset_data
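

# Illustrative sketch, not part of the original module: the (adverb, verb) pairs can be
# turned into an index tensor, e.g. to score every composition at once. The function
# name is an assumption made for the example only.
def example_pair_index_tensor(dataset_data):
    pair_idx = torch.LongTensor([[dataset_data['adverb2idx'][a], dataset_data['verb2idx'][v]]
                                 for a, v in dataset_data['pairs']])
    return pair_idx  # shape: (n_adverbs * n_verbs, 2)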


class Dataset(torch.utils.data.Dataset):
    def __init__(self, df, antonyms_df, features_dict, dataset_data, feature_dim, no_antonyms=False):
        self.df = df
        self.random_generator = np.random.default_rng()
        self.no_antonyms = no_antonyms
        self.antonyms = {t.adverb: t.antonym for t in antonyms_df.itertuples()}
        self.features = features_dict['features']
        self.metadata = features_dict['metadata']
        self.feature_dim = feature_dim
        self.adverbs = dataset_data['adverbs']
        self.verbs = dataset_data['verbs']
        self.pairs = dataset_data['pairs']
        self.adverb2idx = dataset_data['adverb2idx']
        self.verb2idx = dataset_data['verb2idx']
        self.idx2verb = dataset_data['idx2verb']
        self.idx2adverb = dataset_data['idx2adverb']
        self.dataset_data = dataset_data

    def __len__(self):
        return len(self.df)

    def get_verb_adv_pair_idx(self, labels):
        v_str = [self.idx2verb[x.item() if isinstance(x, torch.Tensor) else x] for x in labels['verb']]
        a_str = [self.idx2adverb[x.item() if isinstance(x, torch.Tensor) else x] for x in labels['adverb']]
        va_idx = [self.pairs.index((a, v)) for a, v in zip(a_str, v_str)]
        return va_idx

    def get_adverb_with_verb(self, verb):
        verb = verb.item() if isinstance(verb, torch.Tensor) else verb
        verb_str = verb if isinstance(verb, str) else self.idx2verb[verb]
        return [a for a, v in self.pairs if v == verb_str]

    def get_verb_with_adverb_mask(self, adverb):
        adverb = adverb.item() if isinstance(adverb, torch.Tensor) else adverb
        adverb_str = adverb if isinstance(adverb, str) else self.idx2adverb[adverb]
        return [a == adverb_str for a, v in self.pairs]

    def __getitem__(self, item):
        if isinstance(item, torch.Tensor):
            item = item.item()

        segment = self.df.iloc[item]
        verb_label = segment.verb_label
        adverb_label = segment.adverb_label
        adverb = segment.clustered_adverb
        labels = dict(verb=verb_label, adverb=adverb_label)
        metadata = {k: getattr(segment, k) for k in ('seg_id', 'start_time', 'end_time', 'clustered_adverb',
                                                     'clustered_verb') if hasattr(segment, k)}

        data, frame_samples = self.load_features(segment)
        metadata['frame_samples'] = frame_samples

        if self.no_antonyms:
            # sample any other adverb as the negative when antonyms are disabled
            pool = [ai for aa, ai in self.adverb2idx.items() if aa != adverb]
            neg_adverb = self.random_generator.choice(pool)
        else:
            neg_adverb = self.adverb2idx[self.antonyms[adverb]]

        assert adverb_label != neg_adverb
        metadata['negative_adverb'] = neg_adverb

        return data, labels, metadata

    def load_features(self, segment):
        uid = self.get_seg_id(segment)
        features = self.features[uid]
        frame_samples = self.metadata[uid]['frame_samples'].squeeze()
        return features, frame_samples

    @staticmethod
    def get_seg_id(segment, to_str=True):
        seg_id = segment['seg_id'] if isinstance(segment, dict) else getattr(segment, 'seg_id')
        if to_str:
            return str(seg_id)
        else:
            return seg_id
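

# Illustrative end-to-end sketch, not part of the original module: builds a tiny
# synthetic dataset and pushes one batch through the collate function. All segment
# ids, verbs, adverbs, sequence lengths and the feature dimension below are made up
# for the example only.
if __name__ == '__main__':
    feature_dim = 1024
    lengths = (4, 6, 5)  # arbitrary per-segment sequence lengths

    toy_df = pd.DataFrame({
        'seg_id': ['seg0', 'seg1', 'seg2'],
        'clustered_verb': ['cut', 'stir', 'cut'],
        'clustered_adverb': ['slowly', 'quickly', 'quickly'],
        'start_time': [0.0, 1.0, 2.0],
        'end_time': [1.0, 2.0, 3.0],
    })
    antonyms_df = pd.DataFrame({'adverb': ['slowly', 'quickly'],
                                'antonym': ['quickly', 'slowly']})

    # Build the vocabulary (the same toy dataframe stands in for both splits)
    dataset_data = get_verbs_adverbs_pairs(toy_df, toy_df)

    # Derive integer labels consistent with the vocabulary above
    toy_df['verb_label'] = toy_df.clustered_verb.map(dataset_data['verb2idx'])
    toy_df['adverb_label'] = toy_df.clustered_adverb.map(dataset_data['adverb2idx'])

    # Random variable-length feature sequences keyed by segment id
    features_dict = {
        'features': {sid: torch.randn(t, feature_dim) for sid, t in zip(toy_df.seg_id, lengths)},
        'metadata': {sid: {'frame_samples': np.arange(t)} for sid, t in zip(toy_df.seg_id, lengths)},
    }

    dataset = Dataset(toy_df, antonyms_df, features_dict, dataset_data, feature_dim)
    loader = torch.utils.data.DataLoader(dataset, batch_size=3,
                                         collate_fn=collate_variable_length_seq)

    padded, labels, metadata = next(iter(loader))
    print(padded.shape)                      # torch.Size([3, 6, 1024])
    print(labels['verb'], labels['adverb'])  # LongTensors of shape (3,)
    print(metadata['negative_adverb'])       # antonym index for each segment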