-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtoolbox.py
36 lines (29 loc) · 829 Bytes
/
toolbox.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import numpy as np
import math
def extract_mnist(data, img_shape=(28, 28, 1), label=None):
    """Load a dataset, optionally keep a single class, and return scaled images.

    Train and test images are merged into one array, reshaped to
    ``(n_samples, img_shape[0], img_shape[1], img_shape[2])``, cast to
    ``float32`` and mapped through ``x / 127.5 - 1.0``.

    Args:
        data: Object exposing ``load_data()`` which returns
            ``((x_train, y_train), (x_test, y_test))`` as NumPy arrays
            (presumably a Keras-style dataset module -- confirm with callers).
        img_shape: Target shape of one sample; its first three entries are
            used.
        label: When not ``None``, only the samples whose label equals this
            value are kept. ``0`` is a valid label.

    Returns:
        ``float32`` array holding the (optionally filtered) train samples
        followed by the test samples. The result lies in ``[-1, 1]`` provided
        the raw pixel values lie in ``[0, 255]``.
    """
    (x_train, y_train), (x_test, y_test) = data.load_data()

    # Compare with None explicitly: label 0 is falsy, so a plain truthiness
    # test would skip the filter when class 0 is requested.
    if label is not None:
        x_train = x_train[y_train.reshape(-1) == label]
        x_test = x_test[y_test.reshape(-1) == label]

    images = np.concatenate([x_train, x_test])

    # Reshape every sample to the requested layout and switch to float32
    # before scaling.
    images = images.reshape(
        (images.shape[0], img_shape[0], img_shape[1], img_shape[2])
    ).astype(np.float32)

    # Normalization: 0 -> -1.0, 255 -> 1.0
    return images / 127.5 - 1.0
def format_time(time):
    """Render a duration in seconds as a compact string such as ``"1h 2m 3s"``.

    Units, largest first: ``y`` (12 months), ``M`` (4 weeks), ``w`` (7 days),
    ``j`` (24 hours), ``h``, ``m`` and ``s``. Units whose count is zero are
    omitted and any fractional second is dropped.

    Args:
        time: Duration in seconds (int or float).

    Returns:
        Space-separated ``<count><unit>`` parts, or ``"0s"`` when the
        duration is shorter than one second.
    """
    minute = 60
    hour = 60 * minute
    day = 24 * hour
    week = 7 * day
    month = 4 * week
    year = 12 * month
    units = (
        (year, "y"),
        (month, "M"),
        (week, "w"),
        (day, "j"),
        (hour, "h"),
        (minute, "m"),
        (1, "s"),
    )

    parts = []
    remaining = time
    for size, suffix in units:
        if remaining >= size:
            # int() keeps the count free of a trailing ".0" for float input.
            count = int(remaining // size)
            remaining %= size
            parts.append(f"{count}{suffix}")

    # Without this fallback a sub-second duration would yield an empty string.
    return " ".join(parts) if parts else "0s"