Skip to content
This repository has been archived by the owner on Aug 10, 2023. It is now read-only.

Commit

Permalink
May 2022 update
Browse files (browse the repository at this point in the history)
  • Loading branch information
hfxunlp committed May 7, 2022
1 parent eb77366 commit f3e4369
Show file tree
Hide file tree
Showing 135 changed files with 1,513 additions and 732 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Neutron
Neutron: A pytorch based implementation of the [Transformer](https://arxiv.org/abs/1706.03762) and its variants.

This project is developed with python 3.8.
This project is developed with python 3.10.

## Setup dependencies

Expand Down
16 changes: 8 additions & 8 deletions adv/eva/eva_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

import torch

from tqdm import tqdm
from utils.tqdm import tqdm

import h5py
from utils.h5serial import h5File

import cnfg.probe as cnfg
from cnfg.ihyp import *
Expand Down Expand Up @@ -40,8 +40,8 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False):
with torch.no_grad():
for i in tqdm(range(nd), mininterval=tqdm_mininterval):
bid = str(i)
seq_batch = torch.from_numpy(src_grp[bid][:])
seq_o = torch.from_numpy(tgt_grp[bid][:])
seq_batch = torch.from_numpy(src_grp[bid][()])
seq_o = torch.from_numpy(tgt_grp[bid][()])
lo = seq_o.size(1) - ind_shift
if mv_device:
seq_batch = seq_batch.to(mv_device)
Expand All @@ -65,10 +65,10 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False):
w = float(w)
return sum_loss / w, (w - r) / w * 100.0

td = h5py.File(sys.argv[1], "r")
td = h5File(sys.argv[1], "r")

ntest = td["ndata"][:].item()
nword = td["nword"][:].tolist()
ntest = td["ndata"][()].item()
nword = td["nword"][()].tolist()
nwordi, nwordt = nword[0], nword[-1]

mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_layer_fwd)
Expand All @@ -83,7 +83,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False):

mymodel.eval()

lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=pad_id, reduction='sum', forbidden_index=cnfg.forbidden_indexes)
lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=pad_id, reduction="sum", forbidden_index=cnfg.forbidden_indexes)

use_cuda = cnfg.use_cuda
gpuid = cnfg.gpuid
Expand Down
14 changes: 7 additions & 7 deletions adv/predict/predict_ape.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

import torch

from tqdm import tqdm
from utils.tqdm import tqdm

import h5py
from utils.h5serial import h5File

import cnfg.base as cnfg
from cnfg.ihyp import *
Expand All @@ -24,10 +24,10 @@ def load_fixing(module):
if hasattr(module, "fix_load"):
module.fix_load()

td = h5py.File(cnfg.test_data, "r")
td = h5File(cnfg.test_data, "r")

ntest = td["ndata"][:].item()
nwordi = td["nword"][:].tolist()[0]
ntest = td["ndata"][()].item()
nwordi = td["nword"][()].tolist()[0]
vcbt, nwordt = ldvocab(sys.argv[2])
vcbt = reverse_dict(vcbt)

Expand Down Expand Up @@ -70,8 +70,8 @@ def load_fixing(module):
src_grp, mt_grp = td["src"], td["tgt"]
with open(sys.argv[1], "wb") as f, torch.no_grad():
for i in tqdm(range(ntest), mininterval=tqdm_mininterval):
seq_batch = torch.from_numpy(src_grp[str(i)][:])
seq_mt = torch.from_numpy(mt_grp[str(i)][:])
seq_batch = torch.from_numpy(src_grp[str(i)][()])
seq_mt = torch.from_numpy(mt_grp[str(i)][()])
if cuda_device:
seq_batch = seq_batch.to(cuda_device)
seq_mt = seq_mt.to(cuda_device)
Expand Down
12 changes: 6 additions & 6 deletions adv/predict/predict_doc_para.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

import torch

from tqdm import tqdm
from utils.tqdm import tqdm

import h5py
from utils.h5serial import h5File

import cnfg.docpara as cnfg
from cnfg.ihyp import *
Expand All @@ -24,10 +24,10 @@ def load_fixing(module):
if hasattr(module, "fix_load"):
module.fix_load()

td = h5py.File(cnfg.test_data, "r")
td = h5File(cnfg.test_data, "r")

tl = [(str(nsent), str(_curd),) for nsent, ndata in zip(td["nsent"][:].tolist(), td["ndata"][:].tolist()) for _curd in range(ndata)]
nwordi = td["nword"][:].tolist()[0]
tl = [(str(nsent), str(_curd),) for nsent, ndata in zip(td["nsent"][()].tolist(), td["ndata"][()].tolist()) for _curd in range(ndata)]
nwordi = td["nword"][()].tolist()[0]
vcbt, nwordt = ldvocab(sys.argv[2])
vcbt = reverse_dict(vcbt)

Expand Down Expand Up @@ -71,7 +71,7 @@ def load_fixing(module):
src_grp = td["src"]
with open(sys.argv[1], "wb") as f, torch.no_grad():
for nsent, i_d in tqdm(tl, mininterval=tqdm_mininterval):
seq_batch = torch.from_numpy(src_grp[nsent][i_d][:])
seq_batch = torch.from_numpy(src_grp[nsent][i_d][()])
if cuda_device:
seq_batch = seq_batch.to(cuda_device)
seq_batch = seq_batch.long()
Expand Down
14 changes: 7 additions & 7 deletions adv/predict/predict_mulang.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

import torch

from tqdm import tqdm
from utils.tqdm import tqdm

import h5py
from utils.h5serial import h5File

import cnfg.mulang as cnfg
from cnfg.ihyp import *
Expand All @@ -24,10 +24,10 @@ def load_fixing(module):
if hasattr(module, "fix_load"):
module.fix_load()

td = h5py.File(cnfg.test_data, "r")
td = h5File(cnfg.test_data, "r")

ntest = td["ndata"][:].tolist()
nwordi, ntask = td["nword"][:].tolist()
ntest = td["ndata"][()].tolist()
nwordi, ntask = td["nword"][()].tolist()
vcbt, nwordt = ldvocab(sys.argv[2])
vcbt = reverse_dict(vcbt)

Expand Down Expand Up @@ -65,11 +65,11 @@ def load_fixing(module):

ens = "\n".encode("utf-8")

ntest = [(str(i), _task,) for _nd, _task in zip(ntest, td["taskorder"][:].tolist()) for i in range(_nd)]
ntest = [(str(i), _task,) for _nd, _task in zip(ntest, td["taskorder"][()].tolist()) for i in range(_nd)]

with open(sys.argv[1], "wb") as f, torch.no_grad():
for i_d, taskid in tqdm(ntest, mininterval=tqdm_mininterval):
seq_batch = torch.from_numpy(td[str(taskid)]["src"][i_d][:])
seq_batch = torch.from_numpy(td[str(taskid)]["src"][i_d][()])
if cuda_device:
seq_batch = seq_batch.to(cuda_device)
seq_batch = seq_batch.long()
Expand Down
12 changes: 6 additions & 6 deletions adv/predict/predict_probe_enc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

import torch

from tqdm import tqdm
from utils.tqdm import tqdm

import h5py
from utils.h5serial import h5File

import cnfg.probe as cnfg
from cnfg.ihyp import *
Expand All @@ -24,10 +24,10 @@ def load_fixing(module):
if hasattr(module, "fix_load"):
module.fix_load()

td = h5py.File(sys.argv[1], "r")
td = h5File(sys.argv[1], "r")

ntest = td["ndata"][:].item()
nwordi = td["nword"][:].tolist()[0]
ntest = td["ndata"][()].item()
nwordi = td["nword"][()].tolist()[0]
vcbt, nwordt = ldvocab(sys.argv[3])
vcbt = reverse_dict(vcbt)

Expand Down Expand Up @@ -59,7 +59,7 @@ def load_fixing(module):
with open(sys.argv[4], "wb") as fwrt, torch.no_grad():
for i in tqdm(range(ntest), mininterval=tqdm_mininterval):
bid = str(i)
seq_batch = torch.from_numpy(src_grp[bid][:])
seq_batch = torch.from_numpy(src_grp[bid][()])
if cuda_device:
seq_batch = seq_batch.to(cuda_device)
seq_batch = seq_batch.long()
Expand Down
20 changes: 10 additions & 10 deletions adv/rank/doc/para/rank_loss_para.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

import torch

from tqdm import tqdm
from utils.tqdm import tqdm

import h5py
from utils.h5serial import h5File

import cnfg.docpara as cnfg
from cnfg.ihyp import *
Expand All @@ -31,10 +31,10 @@ def load_fixing(module):
if hasattr(module, "fix_load"):
module.fix_load()

td = h5py.File(sys.argv[2], "r")
td = h5File(sys.argv[2], "r")

ntest = td["ndata"][:].item()
nword = td["nword"][:].tolist()
ntest = td["ndata"][()].item()
nword = td["nword"][()].tolist()
nwordi, nwordt = nword[0], nword[-1]

if len(sys.argv) == 4:
Expand All @@ -56,7 +56,7 @@ def load_fixing(module):

mymodel.eval()

lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=pad_id, reduction='none', forbidden_index=cnfg.forbidden_indexes)
lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=pad_id, reduction="none", forbidden_index=cnfg.forbidden_indexes)

use_cuda, cuda_device, cuda_devices, multi_gpu = parse_cuda(cnfg.use_cuda, cnfg.gpuid)
use_amp = cnfg.use_amp and use_cuda
Expand All @@ -79,8 +79,8 @@ def load_fixing(module):
with open(sys.argv[1], "wb") as f, torch.no_grad():
for i in tqdm(range(ntest), mininterval=tqdm_mininterval):
_curid = str(i)
seq_batch = torch.from_numpy(src_grp[_curid][:])
seq_o = torch.from_numpy(tgt_grp[_curid][:])
seq_batch = torch.from_numpy(src_grp[_curid][()])
seq_o = torch.from_numpy(tgt_grp[_curid][()])
lo = seq_o.size(-1) - 1
if cuda_device:
seq_batch = seq_batch.to(cuda_device)
Expand All @@ -93,12 +93,12 @@ def load_fixing(module):
ot = seq_o.narrow(-1, 1, lo).contiguous()
with autocast(enabled=use_amp):
output = mymodel(seq_batch.narrow(1, 1, _nsent_use).contiguous(), oi, seq_batch.narrow(1, 0, _nsent_use).contiguous()).view(bsize, _nsent_use, lo, -1)
loss = lossf(output, ot).sum(-1).view(bsize, -1).sum(-1)
loss = lossf(output, ot).view(bsize, -1).sum(-1)
if norm_token:
lenv = ot.ne(pad_id).int().view(bsize, -1).sum(-1).to(loss)
loss = loss / lenv
f.write("\n".join([str(rsu) for rsu in loss.tolist()]).encode("utf-8"))
loss = output = ot = seq_batch = seq_o = None
f.write(ens)
loss = output = ot = seq_batch = seq_o = None

td.close()
20 changes: 10 additions & 10 deletions adv/rank/doc/rank_loss_sent.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

import torch

from tqdm import tqdm
from utils.tqdm import tqdm

import h5py
from utils.h5serial import h5File

import cnfg.base as cnfg
from cnfg.ihyp import *
Expand All @@ -31,10 +31,10 @@ def load_fixing(module):
if hasattr(module, "fix_load"):
module.fix_load()

td = h5py.File(sys.argv[2], "r")
td = h5File(sys.argv[2], "r")

ntest = td["ndata"][:].item()
nword = td["nword"][:].tolist()
ntest = td["ndata"][()].item()
nword = td["nword"][()].tolist()
nwordi, nwordt = nword[0], nword[-1]

if len(sys.argv) == 4:
Expand All @@ -56,7 +56,7 @@ def load_fixing(module):

mymodel.eval()

lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=pad_id, reduction='none', forbidden_index=cnfg.forbidden_indexes)
lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=pad_id, reduction="none", forbidden_index=cnfg.forbidden_indexes)

use_cuda, cuda_device, cuda_devices, multi_gpu = parse_cuda(cnfg.use_cuda, cnfg.gpuid)
use_amp = cnfg.use_amp and use_cuda
Expand All @@ -77,8 +77,8 @@ def load_fixing(module):
with open(sys.argv[1], "wb") as f, torch.no_grad():
for i in tqdm(range(ntest), mininterval=tqdm_mininterval):
_curid = str(i)
seq_batch = torch.from_numpy(src_grp[_curid][:])
seq_o = torch.from_numpy(tgt_grp[_curid][:])
seq_batch = torch.from_numpy(src_grp[_curid][()])
seq_o = torch.from_numpy(tgt_grp[_curid][()])
bsize, nsent = seq_batch.size()[:2]
ebsize = bsize * nsent
if cuda_device:
Expand All @@ -89,12 +89,12 @@ def load_fixing(module):
ot = seq_o.narrow(-1, 1, lo).contiguous()
with autocast(enabled=use_amp):
output = mymodel(seq_batch.view(ebsize, -1), seq_o.narrow(-1, 0, lo).contiguous().view(ebsize, -1)).view(bsize, nsent, lo, -1)
loss = lossf(output, ot).sum(-1).view(bsize, -1).sum(-1)
loss = lossf(output, ot).view(bsize, -1).sum(-1)
if norm_token:
lenv = ot.ne(pad_id).int().view(bsize, -1).sum(-1).to(loss)
loss = loss / lenv
f.write("\n".join([str(rsu) for rsu in loss.tolist()]).encode("utf-8"))
loss = output = ot = seq_batch = seq_o = None
f.write(ens)
loss = output = ot = seq_batch = seq_o = None

td.close()
Loading

0 comments on commit f3e4369

Please sign in to comment.