Skip to content

Commit

Permalink
add evaluation code
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyi-3 committed Oct 21, 2021
1 parent cab3650 commit ff5d855
Show file tree
Hide file tree
Showing 5 changed files with 303 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.*/
_*/
1 change: 1 addition & 0 deletions checkpoints/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Put the pretrained models here (e.g. `s6.pth`, `gp.pth`, `ip.pth`, which `test.py` loads from `./checkpoints/`).
103 changes: 103 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import torch
import torch.nn as nn


class UNetSeeInDark(nn.Module):
    """U-Net style encoder/decoder built from 3x3 convolutions.

    The encoder has four stages of two convolutions followed by a 2x max
    pooling, a two-convolution bottleneck, and the decoder mirrors the encoder
    with four 2x transposed convolutions whose outputs are concatenated with
    the matching encoder features (skip connections).  A final 1x1 convolution
    maps the 32 feature channels to ``out_channels``.

    Because the input is halved four times and the up-sampled features are
    concatenated with the encoder features, the input height and width must be
    multiples of 16; otherwise ``torch.cat`` fails on mismatched sizes.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the output tensor.
    """

    def __init__(self, in_channels=4, out_channels=4):
        super().__init__()

        # Encoder.  Attribute names are kept stable so existing checkpoints
        # (state_dict keys) keep loading.
        self.conv1_1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
        self.conv1_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        self.conv2_1 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.conv3_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.conv3_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        self.conv4_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv4_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2)

        # Bottleneck.
        self.conv5_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)

        # Decoder.  Each convX_1 takes twice the channels of its upvX output
        # because the up-sampled tensor is concatenated with a skip tensor.
        self.upv6 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv6_1 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
        self.conv6_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)

        self.upv7 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv7_1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
        self.conv7_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)

        self.upv8 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv8_1 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.conv8_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.upv9 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.conv9_1 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
        self.conv9_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)

        self.conv10_1 = nn.Conv2d(32, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        """Run the network on ``x`` of shape (N, in_channels, H, W).

        Returns a tensor of shape (N, out_channels, H, W).  H and W must be
        multiples of 16 (see the class docstring).
        """
        # Encoder: each stage uses its own pooling module.  All four pools are
        # parameter-free MaxPool2d(kernel_size=2), so this is numerically the
        # same as re-using ``pool1`` everywhere.
        conv1 = self.lrelu(self.conv1_1(x))
        conv1 = self.lrelu(self.conv1_2(conv1))
        pool1 = self.pool1(conv1)

        conv2 = self.lrelu(self.conv2_1(pool1))
        conv2 = self.lrelu(self.conv2_2(conv2))
        pool2 = self.pool2(conv2)

        conv3 = self.lrelu(self.conv3_1(pool2))
        conv3 = self.lrelu(self.conv3_2(conv3))
        pool3 = self.pool3(conv3)

        conv4 = self.lrelu(self.conv4_1(pool3))
        conv4 = self.lrelu(self.conv4_2(conv4))
        pool4 = self.pool4(conv4)

        conv5 = self.lrelu(self.conv5_1(pool4))
        conv5 = self.lrelu(self.conv5_2(conv5))

        # Decoder: up-sample, concatenate the skip features along the channel
        # axis, then apply two convolutions.
        up6 = torch.cat([self.upv6(conv5), conv4], 1)
        conv6 = self.lrelu(self.conv6_1(up6))
        conv6 = self.lrelu(self.conv6_2(conv6))

        up7 = torch.cat([self.upv7(conv6), conv3], 1)
        conv7 = self.lrelu(self.conv7_1(up7))
        conv7 = self.lrelu(self.conv7_2(conv7))

        up8 = torch.cat([self.upv8(conv7), conv2], 1)
        conv8 = self.lrelu(self.conv8_1(up8))
        conv8 = self.lrelu(self.conv8_2(conv8))

        up9 = torch.cat([self.upv9(conv8), conv1], 1)
        conv9 = self.lrelu(self.conv9_1(up9))
        conv9 = self.lrelu(self.conv9_2(conv9))

        # No activation after the final 1x1 projection.
        return self.conv10_1(conv9)

    def _initialize_weights(self):
        """Re-initialise convolution weights from N(0, 0.02).

        Conv2d layers get both weight and bias re-drawn; ConvTranspose2d layers
        only get their weight re-drawn (their bias keeps its current value).
        This method is not called by ``__init__``; callers must invoke it
        explicitly if they want this initialisation.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0.0, 0.02)
                if m.bias is not None:
                    nn.init.normal_(m.bias, 0.0, 0.02)
            elif isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, 0.0, 0.02)

    def lrelu(self, x):
        """Leaky ReLU with negative slope 0.2, written as ``max(0.2 * x, x)``."""
        return torch.max(0.2 * x, x)
92 changes: 92 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import os
import argparse
import torch

import numpy as np
import torch.nn.functional as F
import scipy.io as sio

from skimage.metrics import peak_signal_noise_ratio, structural_similarity

import utils
from model import UNetSeeInDark


def _tile_starts(length, patch_size, shift):
    """Return the distinct tile start offsets that cover ``length`` pixels.

    Tiles are ``patch_size`` long and advance by ``shift``; the last tile is
    clamped so it ends exactly at ``length``, and a tile never starts below 0
    (which happens when ``length`` is smaller than ``patch_size``).
    """
    starts = []
    for pos in range(0, length, shift):
        end = min(pos + patch_size, length)
        start = max(end - patch_size, 0)
        # Several positions can clamp to the same tile; run each tile once.
        if not starts or starts[-1] != start:
            starts.append(start)
    return starts


def forward_patches(model, noisy, patch_size=256 * 3, pad=32):
    """Denoise a 2-D raw image tile by tile and stitch the results.

    The image is packed into four half-resolution planes with
    ``utils.raw2stack``, reflect-padded by ``pad`` on every side, and processed
    in overlapping ``patch_size`` x ``patch_size`` tiles.  Only the interior of
    each tile output (``pad`` pixels removed on every side) is written to the
    result, so tile borders never reach the output.

    Args:
        model: callable mapping a (1, 4, h, w) tensor to a tensor of the same
            shape.  The input is moved to the device of the model parameters.
        noisy: 2-D array-like raw image.
        patch_size: tile size in packed (half-resolution) pixels.
        pad: reflect padding / tile overlap margin in packed pixels.  Reflect
            padding requires ``pad`` to be smaller than the packed image size.

    Returns:
        numpy array with the same shape as ``noisy``, clipped to [0, 1].

    Raises:
        ValueError: if ``patch_size`` is not larger than ``2 * pad``.
    """
    shift = patch_size - pad * 2
    if shift <= 0:
        raise ValueError('patch_size must be larger than 2 * pad')

    # Follow the model instead of hard-coding CUDA, so CPU inference works too.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    noisy = torch.as_tensor(noisy, dtype=torch.float32).to(device)
    noisy = utils.raw2stack(noisy).unsqueeze(0)
    noisy = F.pad(noisy, (pad, pad, pad, pad), mode='reflect')
    denoised = torch.zeros_like(noisy)

    _, _, H, W = noisy.shape
    with torch.no_grad():
        for top in _tile_starts(H, patch_size, shift):
            bottom = min(top + patch_size, H)
            for left in _tile_starts(W, patch_size, shift):
                right = min(left + patch_size, W)

                out_var = model(noisy[..., top:bottom, left:right])
                # Explicit end indices (rather than ``pad:-pad``) keep the
                # slicing correct when pad == 0.
                denoised[..., top + pad: bottom - pad, left + pad: right - pad] = \
                    out_var[..., pad: bottom - top - pad, pad: right - left - pad]

    # Remove the reflect padding and unpack back to the full-resolution mosaic.
    denoised = denoised[..., pad: H - pad, pad: W - pad]
    denoised = utils.stack2raw(denoised[0]).detach().cpu().numpy()

    return denoised.clip(0, 1)


if __name__ == '__main__':
    # Evaluate a pretrained denoiser on the selected camera and report PSNR.
    parser = argparse.ArgumentParser()
    parser.add_argument('--root', default='/mnt/lustre/zhangyi3/data/SIDD_Medium/Data/')
    parser.add_argument('--camera', choices=['s6', 'gp', 'ip'], required=True, help='camera name')
    args = parser.parse_args()

    camera = args.camera
    root = args.root

    print('test', camera, 'root', root)

    # Keep entries whose second '_'-separated field is 2, 3 or 5 and whose
    # name contains the camera code.  Entries that do not follow this naming
    # scheme (stray files, hidden directories) are skipped instead of crashing
    # the int() conversion.  Sorting makes the processing order deterministic.
    test_ids = (2, 3, 5)
    test_data_list = []
    for item in sorted(os.listdir(root)):
        try:
            item_id = int(item.split('_')[1])
        except (IndexError, ValueError):
            continue
        if item_id in test_ids and camera in item.lower():
            test_data_list.append(item)

    # build model
    model = UNetSeeInDark()
    model = model.cuda()
    model = torch.nn.DataParallel(model)

    model_path = './checkpoints/%s.pth' % camera.lower()
    model.load_state_dict(torch.load(model_path, map_location='cpu'))

    psnr_list = []
    for idx, item in enumerate(test_data_list):
        head = item[:4]
        scene_dir = os.path.join(root, item)
        for tail in ['GT_RAW_010', 'GT_RAW_011']:
            print('processing', idx, item, tail, end=' ')
            gt_path = os.path.join(scene_dir, '%s_%s.MAT' % (head, tail))
            noisy_path = os.path.join(
                scene_dir, '%s_%s.MAT' % (head, tail.replace('GT', 'NOISY')))
            meta_path = os.path.join(
                scene_dir, '%s_%s.MAT' % (head, tail.replace('GT', 'METADATA')))

            # Copy the data out and close the HDF5 handles right away.
            with utils.open_hdf5(gt_path) as mat:
                gt = np.array(mat['x'], dtype=np.float32)
            with utils.open_hdf5(noisy_path) as mat:
                noisy = np.array(mat['x'], dtype=np.float32)

            # transform to rggb pattern
            pattern = utils.extract_metainfo(meta_path)['pattern']
            noisy = utils.transform_to_rggb(noisy, pattern)
            gt = utils.transform_to_rggb(gt, pattern)

            denoised = forward_patches(model, noisy)

            psnr = peak_signal_noise_ratio(gt, denoised, data_range=1)
            psnr_list.append(psnr)
            print('psnr %.2f' % psnr)

    if psnr_list:
        print('Camera %s, average PSNR %.2f' % (camera, np.mean(psnr_list)))
    else:
        # Avoid reporting a NaN average when nothing matched the filter.
        print('Camera %s, no test data found under %s' % (camera, root))
105 changes: 105 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import h5py
import time
import torch

import scipy.io as sio

import numpy as np


def open_hdf5(filename, max_retries=None, wait=3):
    """Open an HDF5 file read-only, retrying while the open fails.

    Args:
        filename: path of the HDF5 file.
        max_retries: maximum number of retries after a failed open; ``None``
            (the default) retries indefinitely, as before.
        wait: seconds to sleep between attempts.

    Returns:
        The opened ``h5py.File``.  The caller is responsible for closing it.

    Raises:
        FileNotFoundError: immediately, if the open reports a missing file;
            waiting cannot fix a wrong path.  NOTE(review): whether h5py raises
            this subclass or a plain OSError for a missing file depends on the
            h5py version; with a plain OSError the call retries as before.
        OSError: if ``max_retries`` is set and exhausted.
    """
    attempts = 0
    while True:
        try:
            return h5py.File(filename, 'r')
        except FileNotFoundError:
            raise
        except OSError:
            attempts += 1
            if max_retries is not None and attempts > max_retries:
                raise
            print(filename, ' waiting')
            time.sleep(wait)  # Wait a bit


def extract_metainfo(path='0151_METADATA_RAW_010.MAT'):
    """Read selected camera metadata from a MATLAB ``metadata`` struct.

    The file is loaded with ``scipy.io.loadmat``; the first element of its
    ``metadata`` entry is flattened into a field-name -> value mapping, from
    which the values below are picked.  Files whose model name contains
    'iphone', or whose bit depth is not 16, keep some tags in nested
    sub-structures and are read through a different set of index paths.

    Args:
        path: path of the metadata ``.MAT`` file.

    Returns:
        dict with keys 'device' (lower-cased model string), 'pattern' (string
        over the letters R/G/B), 'iso', 'noise' (first two entries of a tag
        value), 'time' (exposure time), 'wb' (array ``[r_gain, 1, b_gain]``)
        and 'ccm' (3x3 float matrix from ``ColorMatrix1``).
    """
    record = sio.loadmat(path)['metadata'][0][0]
    py_dict = {descr[0]: record[descr[0]] for descr in record.dtype.descr}

    device = py_dict['Model'][0].lower()
    bit_depth = py_dict['BitDepth'][0][0]

    nested_layout = 'iphone' in device or bit_depth != 16
    if nested_layout:
        noise = py_dict['UnknownTags'][-2][0][-1][0][:2]
        iso = py_dict['DigitalCamera'][0, 0]['ISOSpeedRatings'][0][0]
        pattern = py_dict['SubIFDs'][0][0]['UnknownTags'][0][0][1][0][-1][0]
        exposure = py_dict['DigitalCamera'][0, 0]['ExposureTime'][0][0]
    else:
        noise = py_dict['UnknownTags'][-1][0][-1][0][:2]
        iso = py_dict['ISOSpeedRatings'][0][0]
        pattern = py_dict['UnknownTags'][1][0][-1][0]
        exposure = py_dict['ExposureTime'][0][0]  # first row, first column

    # Colour indices 0/1/2 become the letters R/G/B.
    pattern = ''.join('RGB'[i] for i in pattern)

    # NOTE(review): the first AsShotNeutral entry is used as the blue gain and
    # the third as the red gain -- confirm this ordering is intended.
    b_gain, _, r_gain = py_dict['AsShotNeutral'][0]
    wb = np.array([r_gain, 1, b_gain])

    # only load ccm1
    ccm = py_dict['ColorMatrix1'][0].astype(float).reshape((3, 3))

    return {
        'device': device,
        'pattern': pattern,
        'iso': iso,
        'noise': noise,
        'time': exposure,
        'wb': wb,
        'ccm': ccm,
    }


def transform_to_rggb(img, pattern):
    """Shift a Bayer mosaic so that its pattern becomes RGGB.

    The image is rolled circularly by one pixel along the axes needed to move
    the red sample to position (0, 0); the last row/column wraps around to the
    front, so the output has the same shape as the input.

    Args:
        img: 2-D ``numpy.ndarray`` holding the mosaic.
        pattern: one of 'rggb', 'bggr', 'grbg', 'gbrg' (case-insensitive).

    Returns:
        The shifted array; for 'rggb' the input array itself is returned.

    Raises:
        TypeError: if ``img`` is not a numpy array.
        ValueError: if ``img`` is not 2-D or ``pattern`` is not supported.
    """
    if not isinstance(img, np.ndarray):
        raise TypeError('img must be a numpy.ndarray, got %s' % type(img).__name__)
    if img.ndim != 2:
        raise ValueError('img must be 2-D, got %d dimensions' % img.ndim)

    key = pattern.lower()
    if key == 'rggb':
        return img
    if key == 'bggr':
        return np.roll(np.roll(img, 1, axis=1), 1, axis=0)
    if key == 'grbg':
        return np.roll(img, 1, axis=1)
    if key == 'gbrg':
        return np.roll(img, 1, axis=0)

    # The original ``assert 'no support'`` could never fail (a non-empty string
    # is truthy), so unsupported patterns were silently passed through.
    raise ValueError('unsupported Bayer pattern: %r' % (pattern,))


def raw2stack(var):
    """Pack a 2-D mosaic tensor into four half-resolution planes.

    Plane order is (even row, even col), (even row, odd col),
    (odd row, even col), (odd row, odd col).

    Args:
        var: 2-D tensor of shape (h, w); h and w must be even, otherwise the
            plane sizes do not match and the assignment fails.

    Returns:
        float32 tensor of shape (4, h // 2, w // 2) on the same device as
        ``var``.
    """
    h, w = var.shape
    # Allocate on var's own device rather than through the legacy
    # torch.cuda.FloatTensor constructor, which uses the current CUDA device.
    res = torch.zeros((4, h // 2, w // 2), dtype=torch.float32, device=var.device)
    res[0] = var[0::2, 0::2]
    res[1] = var[0::2, 1::2]
    res[2] = var[1::2, 0::2]
    res[3] = var[1::2, 1::2]
    return res


def stack2raw(var):
    """Interleave four half-resolution planes back into a 2-D mosaic.

    Planes 0..3 are written to (even row, even col), (even row, odd col),
    (odd row, even col) and (odd row, odd col) respectively.

    Args:
        var: tensor of shape (c, h, w) with at least four planes; only the
            first four are used.

    Returns:
        float32 tensor of shape (2 * h, 2 * w) on the same device as ``var``.
    """
    _, h, w = var.shape
    # Every element is overwritten below, so uninitialised storage is fine.
    # Allocate on var's own device instead of the legacy typed constructors.
    res = torch.empty((h * 2, w * 2), dtype=torch.float32, device=var.device)
    res[0::2, 0::2] = var[0]
    res[0::2, 1::2] = var[1]
    res[1::2, 0::2] = var[2]
    res[1::2, 1::2] = var[3]
    return res

0 comments on commit ff5d855

Please sign in to comment.