-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhelper.py
93 lines (74 loc) · 3.5 KB
/
helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
import math
from PIL import Image
# Alias for the builtin: make_grid's ``range`` parameter shadows ``range`` inside it.
irange = range


def make_grid(tensor, nrow=8, padding=2,
              normalize=False, range=None, scale_each=False, pad_value=0):
    """Arrange a mini-batch of images into a single padded image grid.

    Args:
        tensor: A 4D B x C x H x W tensor, a 3D C x H x W tensor, a 2D H x W
            tensor, or a list of equally shaped image tensors (stacked along
            a new batch dimension). Single-channel input is replicated to
            three channels.
        nrow: Maximum number of images placed in one row of the grid.
        padding: Number of padding pixels around and between images.
        normalize: If true, shift and scale values into roughly [0, 1] using
            ``range`` or, when ``range`` is None, the tensor's own min/max.
            The input is cloned first, so the caller's tensor is not changed.
        range: Optional ``(min, max)`` tuple used for normalization; values
            are clamped into this interval before rescaling.
        scale_each: If true (and ``normalize`` is set), each image of the
            batch is normalized with its own bounds instead of shared ones.
        pad_value: Value used to fill the padding pixels.

    Returns:
        A C x H' x W' tensor holding the grid. When the batch contains a
        single image, that image is returned with only the batch dimension
        removed and no padding added; without ``normalize`` this may be a view
        of the input tensor.

    Raises:
        TypeError: If ``tensor`` is neither a tensor nor a list of tensors.
    """
    if not (torch.is_tensor(tensor) or
            (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor)))

    # if list of tensors, convert to a 4D mini-batch Tensor
    if isinstance(tensor, list):
        tensor = torch.stack(tensor, dim=0)

    if tensor.dim() == 2:  # single image H x W
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 3:  # single image C x H x W
        if tensor.size(0) == 1:  # if single-channel, convert to 3-channel
            tensor = torch.cat((tensor, tensor, tensor), 0)
        tensor = tensor.unsqueeze(0)
    if tensor.dim() == 4 and tensor.size(1) == 1:  # single-channel images
        tensor = torch.cat((tensor, tensor, tensor), 1)

    if normalize:
        tensor = tensor.clone()  # avoid modifying the caller's tensor in-place
        if range is not None:
            assert isinstance(range, tuple), \
                "range has to be a tuple (min, max) if specified. min and max are numbers"

        def norm_ip(img, min, max):
            # Clamp into [min, max], then map that interval onto [0, 1);
            # the epsilon avoids division by zero for constant images.
            img.clamp_(min=min, max=max)
            img.add_(-min).div_(max - min + 1e-5)

        def norm_range(t, range):
            if range is not None:
                norm_ip(t, range[0], range[1])
            else:
                norm_ip(t, float(t.min()), float(t.max()))

        if scale_each:
            for t in tensor:  # loop over mini-batch dimension
                norm_range(t, range)
        else:
            norm_range(tensor, range)

    if tensor.size(0) == 1:
        # Drop only the batch dim; a bare squeeze() would also remove an
        # H or W of size 1 and leave a tensor of the wrong rank.
        return tensor.squeeze(0)

    # make the mini-batch of images into a grid
    nmaps = tensor.size(0)
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
    # Use the real channel count rather than assuming RGB (e.g. RGBA input).
    num_channels = tensor.size(1)
    grid = tensor.new_full((num_channels, height * ymaps + padding, width * xmaps + padding), pad_value)
    k = 0
    for y in irange(ymaps):
        for x in irange(xmaps):
            if k >= nmaps:
                break
            # Each cell is (height, width) including one padding strip; the
            # image is copied just past that strip.
            grid.narrow(1, y * height + padding, height - padding) \
                .narrow(2, x * width + padding, width - padding) \
                .copy_(tensor[k])
            k += 1
    return grid
def norm_ip(img, min, max):
    """Rescale ``img`` in place from ``[min, max]`` to roughly ``[0, 1]``.

    Values outside ``[min, max]`` are clamped first. The shifted values are
    divided by ``max - min + 1e-5``; the small epsilon keeps a constant image
    from causing a division by zero. The same tensor object is returned.
    """
    img.clamp_(min=min, max=max).add_(-min).div_(max - min + 1e-5)
    return img
def save_image(tensor, filename, nrow=8, padding=2,
               normalize=False, range=None, scale_each=False, pad_value=0):
    """Lay out ``tensor`` as an image grid with ``make_grid`` and save it.

    All arguments except ``filename`` are forwarded to ``make_grid``. The grid
    is assumed to hold values in [0, 1]: they are multiplied by 255, rounded,
    clamped to [0, 255], converted to uint8 and written with PIL. The caller's
    tensor is never modified.
    """
    # Pixel export needs no autograd; this also lets .numpy() succeed on
    # tensors that require grad.
    with torch.no_grad():
        grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value,
                         normalize=normalize, range=range, scale_each=scale_each)
        # Out-of-place mul: for a single un-normalized image make_grid returns
        # (a view of) the caller's tensor, which must not be rescaled in place.
        # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
        ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    im.save(filename)
def save_single_image(tensor, filename):
    """Min-max normalize one image tensor and save it to ``filename``.

    ``tensor`` is a C x H x W image; a 2D H x W tensor is treated as a single
    channel. Values are rescaled to [0, 1] using the image's own minimum and
    maximum, then scaled to [0, 255], rounded, converted to uint8 and written
    with PIL. A single-channel image is passed to PIL as a 2D array. Integer
    tensors are converted to float first. The caller's tensor is left
    unmodified.
    """
    with torch.no_grad():
        # Work on a copy: norm_ip and the in-place ops below mutate their input.
        img = tensor.clone()
        if not img.is_floating_point():
            # In-place true division is not possible on integer tensors.
            img = img.float()
        if img.dim() == 2:  # H x W -> 1 x H x W
            img = img.unsqueeze(0)
        img = norm_ip(img, float(img.min()), float(img.max()))
        # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
        ndarr = img.mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
    if ndarr.shape[2] == 1:
        # Image.fromarray rejects H x W x 1 uint8 arrays; pass H x W instead.
        ndarr = ndarr[:, :, 0]
    im = Image.fromarray(ndarr)
    im.save(filename)