forked from xinntao/Real-ESRGAN
-
Notifications
You must be signed in to change notification settings - Fork 419
/
Copy pathtest_model.py
126 lines (112 loc) · 4.75 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import torch
import yaml
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.data.paired_image_dataset import PairedImageDataset
from basicsr.losses.losses import GANLoss, L1Loss, PerceptualLoss
from realesrgan.archs.discriminator_arch import UNetDiscriminatorSN
from realesrgan.models.realesrgan_model import RealESRGANModel
from realesrgan.models.realesrnet_model import RealESRNetModel
def test_realesrnet_model():
    """Smoke-test RealESRNetModel construction, ``feed_data`` and validation.

    The model is built from ``tests/data/test_realesrnet_model.yml``. The test
    checks the generator, pixel-loss and optimizer types. It then pushes a random
    32x32 ground-truth batch through ``feed_data``, first with the configured
    probabilities and then with the noise/second-blur probabilities forced to 0.
    Finally it runs ``nondist_validation`` over the paired images under
    ``tests/data`` and checks the model reports training mode before and after.
    """
    config_path = 'tests/data/test_realesrnet_model.yml'
    with open(config_path, mode='r') as stream:
        opt = yaml.load(stream, Loader=yaml.FullLoader)

    model = RealESRNetModel(opt)

    # Components the options file is expected to produce.
    assert model.__class__.__name__ == 'RealESRNetModel'
    assert isinstance(model.net_g, RRDBNet)
    assert isinstance(model.cri_pix, L1Loss)
    assert isinstance(model.optimizers[0], torch.optim.Adam)

    # Random ground truth plus the three 5x5 kernels passed to feed_data.
    # torch.rand is drawn in the order gt, kernel1, kernel2, sinc_kernel.
    gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
    kernel1, kernel2, sinc_kernel = (torch.rand((1, 5, 5), dtype=torch.float32) for _ in range(3))
    batch = {'gt': gt, 'kernel1': kernel1, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel}

    def assert_shapes():
        # lq is expected at 1/4 of the GT resolution (8x8); gt keeps its 32x32 size.
        assert model.lq.shape == (1, 3, 8, 8)
        assert model.gt.shape == (1, 3, 32, 32)

    # The second call is the one the original labels "check dequeue".
    # Presumably it exercises the model's internal training-pair queue (TODO confirm).
    for _ in range(2):
        model.feed_data(batch)
    assert_shapes()

    # Zero these probabilities so feed_data takes its alternate if/else branches.
    for prob_key in ('gaussian_noise_prob', 'gray_noise_prob', 'second_blur_prob', 'gaussian_noise_prob2',
                     'gray_noise_prob2'):
        model.opt[prob_key] = 0
    model.feed_data(batch)
    assert_shapes()

    # ----------------- test nondist_validation -------------------- #
    val_set = PairedImageDataset({
        'name': 'Demo',
        'dataroot_gt': 'tests/data/gt',
        'dataroot_lq': 'tests/data/lq',
        'io_backend': {'type': 'disk'},
        'scale': 4,
        'phase': 'val',
    })
    val_loader = torch.utils.data.DataLoader(dataset=val_set, batch_size=1, shuffle=False, num_workers=0)

    # Validation must not leave the model out of training mode.
    # The positional args are assumed to be (dataloader, current_iter, tb_logger, save_img),
    # following BasicSR's BaseModel; verify.
    assert model.is_train is True
    model.nondist_validation(val_loader, 1, None, False)
    assert model.is_train is True
def test_realesrgan_model():
    """Smoke-test RealESRGANModel construction, data feeding, validation and one optimisation step.

    The model is built from ``tests/data/test_realesrgan_model.yml``. The test
    checks the generator, discriminator, loss and optimizer types and then
    exercises ``feed_data`` with both the configured and the zeroed
    noise/second-blur probabilities. It runs ``nondist_validation``, which must
    leave ``is_train`` True. Finally it performs one ``optimize_parameters`` call
    and checks the output shape and the logged loss keys.
    """
    config_path = 'tests/data/test_realesrgan_model.yml'
    with open(config_path, mode='r') as stream:
        opt = yaml.load(stream, Loader=yaml.FullLoader)

    model = RealESRGANModel(opt)

    # Components the options file is expected to produce.
    assert model.__class__.__name__ == 'RealESRGANModel'
    assert isinstance(model.net_g, RRDBNet)  # generator
    assert isinstance(model.net_d, UNetDiscriminatorSN)  # discriminator
    assert isinstance(model.cri_pix, L1Loss)
    assert isinstance(model.cri_perceptual, PerceptualLoss)
    assert isinstance(model.cri_gan, GANLoss)
    for optimizer in model.optimizers[:2]:
        assert isinstance(optimizer, torch.optim.Adam)

    # Random ground truth plus the three 5x5 kernels passed to feed_data.
    # torch.rand is drawn in the order gt, kernel1, kernel2, sinc_kernel.
    gt = torch.rand((1, 3, 32, 32), dtype=torch.float32)
    kernel1, kernel2, sinc_kernel = (torch.rand((1, 5, 5), dtype=torch.float32) for _ in range(3))
    batch = {'gt': gt, 'kernel1': kernel1, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel}

    def assert_shapes():
        # lq is expected at 1/4 of the GT resolution (8x8); gt keeps its 32x32 size.
        assert model.lq.shape == (1, 3, 8, 8)
        assert model.gt.shape == (1, 3, 32, 32)

    # The second call is the one the original labels "check dequeue".
    # Presumably it exercises the model's internal training-pair queue (TODO confirm).
    for _ in range(2):
        model.feed_data(batch)
    assert_shapes()

    # Zero these probabilities so feed_data takes its alternate if/else branches.
    for prob_key in ('gaussian_noise_prob', 'gray_noise_prob', 'second_blur_prob', 'gaussian_noise_prob2',
                     'gray_noise_prob2'):
        model.opt[prob_key] = 0
    model.feed_data(batch)
    assert_shapes()

    # ----------------- test nondist_validation -------------------- #
    val_set = PairedImageDataset({
        'name': 'Demo',
        'dataroot_gt': 'tests/data/gt',
        'dataroot_lq': 'tests/data/lq',
        'io_backend': {'type': 'disk'},
        'scale': 4,
        'phase': 'val',
    })
    val_loader = torch.utils.data.DataLoader(dataset=val_set, batch_size=1, shuffle=False, num_workers=0)

    # Validation must not leave the model out of training mode.
    # The positional args are assumed to be (dataloader, current_iter, tb_logger, save_img),
    # following BasicSR's BaseModel; verify.
    assert model.is_train is True
    model.nondist_validation(val_loader, 1, None, False)
    assert model.is_train is True

    # ----------------- test optimize_parameters -------------------- #
    # The options still have the zeroed probabilities set above.
    model.feed_data(batch)
    model.optimize_parameters(1)  # argument presumably the current iteration; confirm
    assert model.output.shape == (1, 3, 32, 32)
    assert isinstance(model.log_dict, dict)

    # Every one of these entries must be present in the log (extra keys are allowed).
    required_log_keys = {'l_g_pix', 'l_g_percep', 'l_g_gan', 'l_d_real', 'out_d_real', 'l_d_fake', 'out_d_fake'}
    assert required_log_keys <= set(model.log_dict)