-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
48 lines (38 loc) · 1.2 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import torch
import json
import os
import pickle
import random
import numpy as np
def reset_random_seed(seed):
    """Seed every global random number generator this module relies on.

    Seeds, in order: Python's ``random`` module, NumPy's legacy global RNG,
    PyTorch's default generator and, when a CUDA runtime is available,
    PyTorch's CUDA generator.

    Args:
        seed: Integer seed passed unchanged to each seeding function.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Only touch the CUDA generator when a GPU runtime is actually present.
    # NOTE(review): recent torch.manual_seed is documented to seed all
    # devices already; confirm whether this extra call is still needed.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
def generate_casual_mask(sz, device, dtype=torch.float32):
    """Build a square additive causal mask of shape ``(sz, sz)``.

    Entry ``[i, j]`` is ``0.0`` when ``j <= i`` and ``-inf`` when ``j > i``:
    the lower triangle (diagonal included) is zero and everything strictly
    above the diagonal is ``-inf``. This is the usual additive form of a
    causal ("look-ahead") attention mask.

    Args:
        sz: Side length of the mask (sequence length).
        device: Device on which the mask is allocated.
        dtype: Floating-point dtype of the result (must be able to hold
            ``-inf``). Defaults to ``torch.float32``, matching the previous
            ``.float()`` output.

    Returns:
        A ``(sz, sz)`` tensor of ``dtype`` on ``device``.
    """
    # Fill with -inf, then keep only the strictly-upper triangle; triu
    # zeroes everything on and below the diagonal.
    return torch.triu(
        torch.full((sz, sz), float("-inf"), dtype=dtype, device=device),
        diagonal=1,
    )


# Correctly spelled alias; the misspelled name above is kept so existing
# callers keep working.
generate_causal_mask = generate_casual_mask
def generate_padding_mask(batch_lengths, device):
    """Build a boolean validity mask from per-sequence lengths.

    Entry ``[b, t]`` is True when ``t < batch_lengths[b]``, i.e. True marks
    positions inside sequence ``b`` and False marks padding.

    NOTE(review): this is the opposite polarity of the ``key_padding_mask``
    argument of ``torch.nn.MultiheadAttention`` (where True means "ignore");
    confirm callers invert it where needed.

    Args:
        batch_lengths: 1-D integer lengths, one per sequence. Accepts a
            tensor on any device or a plain Python sequence.
        device: Device on which the mask is built and returned.

    Returns:
        A ``(len(batch_lengths), max(batch_lengths))`` bool tensor on
        ``device``; an empty ``(0, 0)`` bool tensor for an empty batch.
    """
    # Move/convert once so the comparison below never mixes devices.
    lengths = torch.as_tensor(batch_lengths, device=device)
    if lengths.numel() == 0:
        # torch.max raises on an empty tensor; return an empty mask instead.
        return torch.zeros((0, 0), dtype=torch.bool, device=device)

    max_len = int(lengths.max())
    positions = torch.arange(max_len, device=device)
    # Broadcast (1, max_len) against (batch, 1) -> (batch, max_len).
    return positions.unsqueeze(0) < lengths.unsqueeze(1)
def serialize(obj, path, in_json=False):
    """Write ``obj`` to ``path`` as indented JSON or as a pickle.

    Args:
        obj: Object to store. Must be JSON-serializable when ``in_json`` is
            True; any picklable object otherwise.
        path: Destination file path. An existing file is overwritten.
        in_json: If True, write JSON text with ``indent=2``; otherwise write
            a binary pickle.

    Note:
        The format is chosen by ``in_json`` alone, not by the file suffix.
        ``unserialize`` in this module selects its loader from the suffix,
        so for a round trip pass ``in_json=True`` exactly when ``path`` ends
        in ``.json``.
    """
    if in_json:
        # Explicit encoding: JSON text is specified as UTF-8, while the
        # platform default text encoding varies by locale (PEP 597).
        with open(path, "w", encoding="utf-8") as file:
            json.dump(obj, file, indent=2)
    else:
        with open(path, 'wb') as file:
            pickle.dump(obj, file)
def unserialize(path):
    """Load an object from ``path``, choosing the format by file extension.

    A ``.json`` extension (compared case-insensitively) is read as UTF-8
    JSON text; any other path is read as a binary pickle.

    Args:
        path: File path (str or path-like) to read.

    Returns:
        The decoded object.

    Warning:
        ``pickle.load`` can execute arbitrary code while loading. Only use
        this on pickle files from trusted sources.
    """
    # splitext only reports a real extension ("x.JSON" -> ".JSON"); a bare
    # file named "json" has no extension and is not mistaken for JSON.
    _, extension = os.path.splitext(path)
    if extension.lower() == ".json":
        with open(path, "r", encoding="utf-8") as file:
            return json.load(file)
    with open(path, 'rb') as file:
        return pickle.load(file)