-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathconll_utils.lua
79 lines (69 loc) · 1.66 KB
/
conll_utils.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
-- Load Torch7 first: functions in this file call torch.save and torch.load.
require 'torch'
-- Declare this file as the 'conll_utils' module using the Lua 5.1-style
-- module() call. With package.seeall the module body can still read ordinary
-- globals, while the global functions and variables defined in this file
-- become fields of the conll_utils table.
-- NOTE(review): module() is deprecated and absent from Lua 5.2+; this file
-- therefore presumably targets Lua 5.1 / LuaJIT -- confirm before porting.
module('conll_utils', package.seeall)
-- Module-level switch read by to_cuda(): when truthy, to_cuda calls x:cuda();
-- with the default value false, to_cuda returns its argument unchanged.
cuda = false
--- Write a network to disk with torch.save.
-- The destination is always "<prefix>-best-f1.torch"; the file name depends
-- only on `prefix`, so repeated calls with the same prefix target the same
-- path.
-- @param net     object handed to torch.save as the value to serialize
-- @param prefix  leading part of the output file name
-- @param epoch   accepted for interface compatibility; not used here
function save_model(net, prefix, epoch)
  torch.save(prefix .. "-best-f1.torch", net)
end
--- Make a one-level copy of a value.
-- Tables get a fresh table holding the same keys and the very same value
-- references (nested tables are shared, the metatable is not carried over).
-- Any non-table value is returned as is.
-- @param orig  value to copy
-- @return the new table, or `orig` itself when it is not a table
function shallowcopy(orig)
  -- Numbers, strings, booleans, nil, functions, ... need no copying.
  if type(orig) ~= 'table' then
    return orig
  end

  local duplicate = {}
  for key, value in pairs(orig) do
    duplicate[key] = value
  end
  return duplicate
end
--- Recursively copy a value.
-- Tables are copied key by key (keys and values are themselves deep-copied)
-- and the copy receives a deep copy of the original's metatable. Userdata is
-- duplicated by calling its `clone()` method; every other type is returned
-- unchanged.
--
-- Tables already visited during one top-level call are remembered, so
-- self-referencing structures (including the common `Class.__index = Class`
-- metatable) terminate instead of recursing forever, and a table referenced
-- from several places is copied once and stays shared in the result.
--
-- @param orig  value to copy
-- @param seen  (optional, internal) map from already-copied tables to their
--              copies; callers normally omit it
-- @return the copy
-- NOTE(review): the userdata branch assumes the value exposes :clone()
-- (presumably torch tensors); userdata without it raises an error -- confirm
-- against callers.
function deepcopy(orig, seen)
  local orig_type = type(orig)

  if orig_type == 'table' then
    seen = seen or {}
    local known = seen[orig]
    if known ~= nil then
      return known
    end

    local copy = {}
    -- Register before descending so cycles resolve to this same copy.
    seen[orig] = copy
    for orig_key, orig_value in next, orig, nil do
      copy[deepcopy(orig_key, seen)] = deepcopy(orig_value, seen)
    end
    -- Attach the metatable last so __newindex cannot interfere with filling.
    setmetatable(copy, deepcopy(getmetatable(orig), seen))
    return copy
  elseif orig_type == 'userdata' then
    return orig:clone()
  end

  -- number, string, boolean, nil, function, thread
  return orig
end
--- Collect the elements t[first] .. t[last] into a new array.
-- Entries that are nil (for instance when `last` runs past the end of `t`)
-- are skipped, so the result is always a gap-free sequence starting at 1.
-- @param t      source array
-- @param first  index of the first element to take
-- @param last   index of the last element to take (inclusive)
-- @return a new table with the selected elements
function subrange(t, first, last)
  local slice = {}
  local count = 0
  for idx = first, last do
    local item = t[idx]
    -- Appending nil would be a no-op, so only advance on real values.
    if item ~= nil then
      count = count + 1
      slice[count] = item
    end
  end
  return slice
end
--- Return a new table containing every key/value pair of `mytable`.
-- The copy is one level deep: values are shared with the original, and the
-- metatable is not copied.
-- @param mytable  table to copy
-- @return a fresh table with the same entries
function copytable(mytable)
  -- Must be local: a bare assignment would create a global (a module field
  -- under module()) that leaks out of the function.
  local newtable = {}
  for k, v in pairs(mytable) do
    newtable[k] = v
  end
  return newtable
end
--- Load several serialized sets with torch.load and merge them into one.
-- The first loaded object becomes the result; the `data` and `labels` arrays
-- of every further object are appended to it in place. A summary line with
-- the resulting sizes is printed before returning.
-- @param files  table whose values are the paths to load
-- @return the merged set (the object loaded from the first visited file)
-- Raises an error when `files` yields nothing to load.
-- NOTE(review): `files` is walked with pairs(), whose order is unspecified;
-- if the merge order matters, confirm callers pass a plain array.
function get_set_from_files(files)
  local set = nil
  for _, file in pairs(files) do
    local loaded = torch.load(file)
    if set == nil then
      set = loaded
    else
      local data, labels = set.data, set.labels
      for _, v in ipairs(loaded.data) do
        data[#data + 1] = v
      end
      for _, v in ipairs(loaded.labels) do
        labels[#labels + 1] = v
      end
    end
  end

  -- Without this guard an empty `files` failed later with an obscure
  -- "table expected, got nil" from the size report.
  if set == nil then
    error("get_set_from_files: no files to load", 2)
  end

  -- '#' replaces the deprecated table.getn (removed in Lua 5.2+).
  print('union set size is ', #set, #set.data, #set.labels)
  return set
end
--- Move a value to the GPU when the module-level `cuda` flag is set.
-- With the flag falsy, `x` is returned untouched. Otherwise the result of
-- x:cuda() is returned, falling back to `x` itself should that call yield
-- nil or false.
-- @param x  value exposing a :cuda() method (presumably a torch tensor or
--           module -- TODO confirm)
-- @return x:cuda() or x
function to_cuda(x)
  if not cuda then
    return x
  end
  return x:cuda() or x
end