-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutil.py
125 lines (94 loc) · 3.37 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import os
import numpy as np
import datetime
import time
import data
#from objax.functional.loss import cross_entropy_logits_sparse
import jax.numpy as jnp
from jax import pmap, host_id, jit
from jax.tree_util import tree_map
def DTS():
    """Return the current local time as a ``YYYYmmdd_HHMMSS`` string.

    Handy as a sortable suffix for run names / output directories.
    """
    now = datetime.datetime.now()
    stamp = now.strftime('%Y%m%d_%H%M%S')
    return stamp
def ensure_dir_exists(directory):
    """Create ``directory`` (and any missing parents) if it does not exist.

    Args:
        directory: path of the directory to create. An empty string denotes
            the current working directory, which always exists, so it is a
            no-op (``os.makedirs('')`` would raise FileNotFoundError).

    Raises:
        FileExistsError: if ``directory`` exists but is not a directory.
    """
    if not directory:
        return
    # exist_ok=True closes the race between checking for the directory and
    # creating it (e.g. several hosts/processes creating the same output dir).
    os.makedirs(directory, exist_ok=True)
def ensure_dir_exists_for_file(fname):
    """Ensure the parent directory of the file path ``fname`` exists.

    A bare filename (no directory component) lives in the current working
    directory, so there is nothing to create. ``os.path.dirname`` returns ''
    in that case and it is skipped rather than passed on for creation.

    Args:
        fname: path of a file that is about to be written.
    """
    directory = os.path.dirname(fname)
    if directory:
        ensure_dir_exists(directory)
def shard(x):
    """Distribute ``x`` across devices along its leading axis.

    Runs an identity function under ``pmap``, so slice ``i`` of axis 0 ends
    up on device ``i``. The leading-axis size must not exceed the number of
    devices available to pmap.
    """
    def _passthrough(leaf):
        return leaf

    sharded = pmap(_passthrough)(x)
    return sharded
def replicate(x, replicas=8):
    """Stack ``replicas`` copies of ``x`` on a new leading axis and shard it.

    The result has shape ``(replicas, *x.shape)``, and copy ``i`` is placed on
    device ``i`` via ``shard``. ``replicas`` presumably must match the number
    of devices (default 8, e.g. one TPU host) — confirm against the caller.
    """
    copies = [x for _ in range(replicas)]
    stacked = jnp.stack(copies)
    return shard(stacked)
def shapes_of(pytree):
    """Return a pytree with the same structure where every leaf is replaced
    by the tuple ``(leaf.shape, type(leaf))``.

    Useful for logging/inspecting parameter trees without printing values.
    """
    def _describe(leaf):
        return (leaf.shape, type(leaf))

    return tree_map(_describe, pytree)
def reshape_leading_axis(x, s, from_axis=1):
    """Replace the first ``from_axis`` dimensions of ``x`` with the shape ``s``.

    The dimensions from ``from_axis`` onward are kept unchanged. For example,
    an array of shape ``(8, 4, 5)`` reshaped with ``s=(2, 4)`` becomes
    ``(2, 4, 4, 5)``.
    """
    trailing = x.shape[from_axis:]
    target_shape = tuple(s) + tuple(trailing)
    return x.reshape(target_shape)
def primary_host():
    """Return True when this process's JAX host id is 0.

    Use this to restrict side effects (logging, checkpointing) to a single
    host in multi-host runs.
    """
    is_primary = host_id() == 0
    return is_primary
class EarlyStopping:
    """Decide when to stop training based on a value that should decrease.

    Call ``should_stop(value)`` once per evaluation (e.g. with a validation
    loss). Stopping is triggered when either

    * more than ``max_runtime`` seconds have elapsed since construction
      (checked on every call, including during burn-in), or
    * after the first ``burn_in`` calls, the (optionally smoothed) value fails
      to set a new minimum on ``patience`` consecutive calls.

    The first value seen after burn-in sets the baseline minimum and never
    counts against patience. Once stopping has been decided the decision is
    sticky: every later call returns True.

    Args:
        patience: consecutive non-improving calls tolerated; the object stops
            on the ``patience``-th one. Values <= 0 stop on the first call
            that does not improve on the baseline.
        burn_in: number of initial calls ignored by the patience logic. These
            calls still feed the smoothed value.
        max_runtime: optional wall-clock budget in seconds, or None.
        smoothing: exponential smoothing factor in [0, 1]. 0.0 means no
            smoothing (the raw value is used). Larger values weight history
            more; 1.0 freezes the smoothed value at the first observation.

    Raises:
        ValueError: if ``smoothing`` is outside [0, 1] (or NaN).
    """

    def __init__(self, patience=3, burn_in=5, max_runtime=None,
                 smoothing=0.0):
        # Chained comparison also rejects NaN, which the old `<`/`>` test let through.
        if not 0.0 <= smoothing <= 1.0:
            raise ValueError("invalid smoothing value %s" % smoothing)
        self.original_patience = patience
        self.patience = patience
        self.burn_in = burn_in
        self.lowest_value = None
        self.decided_to_stop = False
        # Absolute deadline in time.time() seconds, or None for no time limit.
        if max_runtime is not None:
            self.exit_time = time.time() + max_runtime
        else:
            self.exit_time = None
        # NOTE: stores the weight given to each *new* value (1 - smoothing),
        # not the user-supplied smoothing factor itself.
        self.smoothing = 1.0 - smoothing
        self.smoothed_value = None

    def should_stop(self, value):
        """Record ``value`` and return True if training should stop now."""
        if self.decided_to_stop:
            return True

        # Exponential moving average: s <- s + w * (value - s).
        if self.smoothed_value is None:
            self.smoothed_value = value
        else:
            self.smoothed_value += self.smoothing * \
                (value - self.smoothed_value)

        # Wall-clock budget exhausted?
        if self.exit_time is not None and time.time() > self.exit_time:
            self.decided_to_stop = True
            return True

        # Ignore the first burn_in calls for the patience logic.
        if self.burn_in > 0:
            self.burn_in -= 1
            return False

        # The first post-burn-in value establishes the baseline. It trivially
        # equals itself, so it must not be counted as a non-improvement;
        # otherwise one unit of patience would be lost before any comparison.
        if self.lowest_value is None:
            self.lowest_value = self.smoothed_value
            return False

        # New minimum: reset patience and record it.
        if self.smoothed_value < self.lowest_value:
            self.patience = self.original_patience
            self.lowest_value = self.smoothed_value
            return False

        # No improvement: spend patience. `<=` (not `==`) so that a
        # non-positive initial patience still stops instead of counting down
        # forever.
        self.patience -= 1
        if self.patience <= 0:
            self.decided_to_stop = True
            return True
        return False

    def stopped(self):
        """Return True once ``should_stop`` has decided to stop."""
        return self.decided_to_stop
def accuracy(predict_fn, dataset):
    """Return the fraction of examples whose prediction equals the label.

    Args:
        predict_fn: callable mapping a batch of inputs to predictions that are
            compared elementwise with the labels. Presumably it returns class
            ids shaped like ``labels``; confirm against the caller.
        dataset: iterable of ``(imgs, labels)`` batches.

    Returns:
        ``correct / total`` as a 0-d JAX array. If ``dataset`` yields no
        batches, the integer division ``0 / 0`` raises ZeroDivisionError.
    """
    correct = 0
    seen = 0
    for batch_imgs, batch_labels in dataset:
        batch_preds = predict_fn(batch_imgs)
        correct = correct + jnp.sum(batch_preds == batch_labels)
        seen = seen + len(batch_labels)
    return correct / seen