-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplotutils.py
124 lines (113 loc) · 6.42 KB
/
plotutils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# importing all the necessary packages
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
# importing the style package
from matplotlib import style
# Optional global plot style (currently disabled):
# plt.style.use('seaborn-pastel')
# Plotting helpers for loss curves and segmentation predictions
class Allplots():
    """Matplotlib helpers for loss curves and segmentation predictions.

    Every public method draws with ``matplotlib.pyplot`` and finishes with
    ``plt.show()``; none of them returns a value.

    Data conventions used by the image methods (taken from how the arrays
    are indexed here):

    * ``datax`` is a NumPy array; ``datax[i]`` is a rank-3 image whose axes
      are moved with ``np.transpose(..., (1, 2, 0))`` before display
      (presumably channels-first -> channels-last; confirm with the caller).
    * ``datay[i]`` is handled the same way and only its first channel is
      shown.
    * ``unet`` is called on a float tensor holding one sample
      (``datax[i:i+1]``).  In hybrid mode it must return exactly two
      outputs; otherwise a single tensor.
    """

    # Predictions strictly above this value are drawn as foreground.
    _MASK_THRESHOLD = 0.50

    def __init__(self):
        super(Allplots, self).__init__()

    def plot_losses(self, train_losses, val_losses) -> None:
        """Plot training and validation loss curves on one set of axes.

        Args:
            train_losses: Sequence of training-loss values, one per iteration.
            val_losses: Sequence of validation-loss values.
        """
        plt.plot(train_losses, linestyle="-", linewidth=5, label='train')
        plt.plot(val_losses, linestyle="-", linewidth=5, label='validation')
        plt.legend()
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        plt.show()

    def plot_examples(self, unet, datax, datay, num_examples: int = 3,
                      ishybrid: bool = False) -> None:
        """Show randomly chosen samples next to their prediction and target.

        Each row holds three panels: input image, thresholded prediction and
        target mask.  Sample indices are drawn with ``np.random.randint``.

        Args:
            unet: Model called on a single-sample float tensor.
            datax: Input images (NumPy array, indexed along axis 0).
            datay: Target masks, indexed like ``datax``.
            num_examples: Number of rows (random samples) to draw.
            ishybrid: If truthy, ``unet`` returns two outputs and the first
                one is displayed.
        """
        sample_count = datax.shape[0]
        # squeeze=False keeps the axes grid 2-D, so a single row can be
        # indexed the same way as several rows.
        _, axes = plt.subplots(nrows=num_examples, ncols=3,
                               figsize=(18, 4 * num_examples), squeeze=False)
        head = 0 if ishybrid else None
        for row in range(num_examples):
            image_indx = np.random.randint(sample_count)
            mask = self._predict_mask(unet, datax[image_indx:image_indx + 1], head)
            self._draw_row(axes[row], datax[image_indx], mask, datay[image_indx])
        plt.show()

    def plot_best(self, unet, datax, datay=None, indx=None, index_ranks=None,
                  ishybrid: bool = False) -> None:
        """Show the predictions for an explicit list of sample indices.

        Rows have three panels (input, prediction, target) when ``datay`` is
        given and two panels (input, prediction) otherwise.

        Args:
            unet: Model called on a single-sample float tensor.
            datax: Input images (NumPy array, indexed along axis 0).
            datay: Optional target masks, indexed like ``datax``.
            indx: Sample indices to plot, one row per entry.
            index_ranks: Hybrid mode only.  For each row, ``0`` displays the
                first model output and ``1`` the second.
            ishybrid: If truthy, ``unet`` returns two outputs.

        Raises:
            ValueError: If ``indx`` is missing, or, in hybrid mode, if
                ``index_ranks`` is missing or holds a value other than 0 or 1.
        """
        if indx is None:
            raise ValueError("indx must list the sample indices to plot")
        num_examples = len(indx)

        # Resolve the output head of every row before any figure is created,
        # so invalid arguments fail early instead of leaving a half-drawn
        # figure or reusing the previous row's prediction.
        heads = [None] * num_examples
        if ishybrid:
            if index_ranks is None:
                raise ValueError("index_ranks is required when ishybrid is set")
            for row in range(num_examples):
                rank = index_ranks[row]
                if rank == 0:
                    heads[row] = 0
                elif rank == 1:
                    heads[row] = 1
                else:
                    raise ValueError(
                        "index_ranks[%d] must be 0 or 1, got %r" % (row, rank)
                    )

        has_target = datay is not None
        ncols = 3 if has_target else 2
        # 6 inches per column reproduces the 18- and 12-inch wide layouts.
        _, axes = plt.subplots(nrows=num_examples, ncols=ncols,
                               figsize=(6 * ncols, 4 * num_examples),
                               squeeze=False)
        for row in range(num_examples):
            image_indx = indx[row]
            mask = self._predict_mask(unet, datax[image_indx:image_indx + 1],
                                      heads[row])
            target = datay[image_indx] if has_target else None
            self._draw_row(axes[row], datax[image_indx], mask, target)
        plt.show()

    @staticmethod
    def _model_device(unet):
        """Return the device the input batch should be moved to."""
        parameters = getattr(unet, "parameters", None)
        if callable(parameters):
            first_param = next(iter(parameters()), None)
            if first_param is not None:
                # Follow the model so CPU-resident models can be plotted too.
                return first_param.device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _predict_mask(self, unet, sample, head=None):
        """Run ``unet`` on one sample and return its thresholded first channel.

        ``head`` is ``None`` for single-output models; for hybrid models it
        selects the first (0) or second (1) of the two returned outputs.
        """
        batch = torch.from_numpy(sample).float().to(self._model_device(unet))
        with torch.no_grad():  # display only; no autograd graph needed
            output = unet(batch)
        if head is not None:
            first_output, second_output = output
            output = first_output if head == 0 else second_output
        prediction = output.squeeze(0).detach().cpu().numpy()
        binary = (prediction > self._MASK_THRESHOLD)[0, :, :].astype(int)
        return np.squeeze(binary)

    def _draw_row(self, axes_row, image, mask, target=None) -> None:
        """Fill one row of axes: input, predicted mask and optional target."""
        axes_row[0].imshow(np.transpose(image, (1, 2, 0))[:, :, :].astype(int))
        # NOTE: title text is kept exactly as the original code emitted it.
        axes_row[0].set_title("Orignal Image")
        axes_row[1].imshow(mask, cmap='gray')
        axes_row[1].set_title("Segmented Image localization")
        if target is not None:
            axes_row[2].imshow(np.transpose(target, (1, 2, 0))[:, :, 0], cmap='gray')
            axes_row[2].set_title("Target image")