diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..206bea1
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+.*/
+_*/
diff --git a/checkpoints/README.md b/checkpoints/README.md
new file mode 100644
index 0000000..a5e6e07
--- /dev/null
+++ b/checkpoints/README.md
@@ -0,0 +1 @@
+Put pretrained models here, named after the camera: `s6.pth`, `gp.pth` or `ip.pth` (see `test.py`).
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..6989301
--- /dev/null
+++ b/model.py
@@ -0,0 +1,103 @@
+import torch
+import torch.nn as nn
+
+
+class UNetSeeInDark(nn.Module):
+    def __init__(self, in_channels=4, out_channels=4):
+        super(UNetSeeInDark, self).__init__()
+
+        # device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        # Encoder: five conv blocks, each downsampled by 2x max-pooling.
+        self.conv1_1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
+        self.conv1_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
+        self.pool1 = nn.MaxPool2d(kernel_size=2)
+
+        self.conv2_1 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
+        self.conv2_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+        self.pool2 = nn.MaxPool2d(kernel_size=2)
+
+        self.conv3_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
+        self.conv3_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
+        self.pool3 = nn.MaxPool2d(kernel_size=2)
+
+        self.conv4_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
+        self.conv4_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
+        self.pool4 = nn.MaxPool2d(kernel_size=2)
+
+        self.conv5_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
+        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
+
+        # Decoder: transposed convs upsample; skips concatenate encoder features.
+        self.upv6 = nn.ConvTranspose2d(512, 256, 2, stride=2)
+        self.conv6_1 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1)
+        self.conv6_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
+
+        self.upv7 = nn.ConvTranspose2d(256, 128, 2, stride=2)
+        self.conv7_1 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1)
+        self.conv7_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
+
+        self.upv8 = nn.ConvTranspose2d(128, 64, 2, stride=2)
+        self.conv8_1 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
+        self.conv8_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+
+        self.upv9 = nn.ConvTranspose2d(64, 32, 2, stride=2)
+        self.conv9_1 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1)
+        self.conv9_2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
+
+        self.conv10_1 = nn.Conv2d(32, out_channels, kernel_size=1, stride=1)
+
+    def forward(self, x):
+        conv1 = self.lrelu(self.conv1_1(x))
+        conv1 = self.lrelu(self.conv1_2(conv1))
+        pool1 = self.pool1(conv1)
+
+        conv2 = self.lrelu(self.conv2_1(pool1))
+        conv2 = self.lrelu(self.conv2_2(conv2))
+        pool2 = self.pool2(conv2)
+
+        conv3 = self.lrelu(self.conv3_1(pool2))
+        conv3 = self.lrelu(self.conv3_2(conv3))
+        pool3 = self.pool3(conv3)
+
+        conv4 = self.lrelu(self.conv4_1(pool3))
+        conv4 = self.lrelu(self.conv4_2(conv4))
+        pool4 = self.pool4(conv4)
+
+        conv5 = self.lrelu(self.conv5_1(pool4))
+        conv5 = self.lrelu(self.conv5_2(conv5))
+
+        up6 = self.upv6(conv5)
+        up6 = torch.cat([up6, conv4], 1)
+        conv6 = self.lrelu(self.conv6_1(up6))
+        conv6 = self.lrelu(self.conv6_2(conv6))
+
+        up7 = self.upv7(conv6)
+        up7 = torch.cat([up7, conv3], 1)
+        conv7 = self.lrelu(self.conv7_1(up7))
+        conv7 = self.lrelu(self.conv7_2(conv7))
+
+        up8 = self.upv8(conv7)
+        up8 = torch.cat([up8, conv2], 1)
+        conv8 = self.lrelu(self.conv8_1(up8))
+        conv8 = self.lrelu(self.conv8_2(conv8))
+
+        up9 = self.upv9(conv8)
+        up9 = torch.cat([up9, conv1], 1)
+        conv9 = self.lrelu(self.conv9_1(up9))
+        conv9 = self.lrelu(self.conv9_2(conv9))
+
+        conv10 = self.conv10_1(conv9)
+        # out = nn.functional.pixel_shuffle(conv10, 2)
+        out = conv10
+        return out
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                m.weight.data.normal_(0.0, 0.02)
+                if m.bias is not None:
+                    m.bias.data.normal_(0.0, 0.02)
+            if isinstance(m, nn.ConvTranspose2d):
+                m.weight.data.normal_(0.0, 0.02)
+
+    def lrelu(self, x):
+        # LeakyReLU with negative slope 0.2, written as an elementwise max.
+        return torch.max(0.2 * x, x)
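For orientation, a minimal smoke test of the network above; this is a sketch only, assuming a 4-channel packed-Bayer input whose sides are divisible by 16 (four 2x pooling stages). Since the final `pixel_shuffle` is commented out, the output keeps the input's 4-channel shape:

```python
import torch

from model import UNetSeeInDark

# Toy input: a batch of one 4-channel packed-Bayer patch (256 is divisible by 16).
model = UNetSeeInDark(in_channels=4, out_channels=4)
x = torch.randn(1, 4, 256, 256)
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 4, 256, 256])
```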
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..f3bff42
--- /dev/null
+++ b/test.py
@@ -0,0 +1,92 @@
+import os
+import argparse
+import torch
+
+import numpy as np
+import torch.nn.functional as F
+
+from skimage.metrics import peak_signal_noise_ratio
+
+import utils
+from model import UNetSeeInDark
+
+
+def forward_patches(model, noisy, patch_size=256 * 3, pad=32):
+    # Denoise the frame in overlapping patches; only each patch's interior
+    # (the patch minus a `pad`-wide apron) is written back, so seams fall
+    # inside the overlap region.
+    shift = patch_size - pad * 2
+
+    noisy = torch.FloatTensor(noisy).cuda()
+    noisy = utils.raw2stack(noisy).unsqueeze(0)
+    noisy = F.pad(noisy, (pad, pad, pad, pad), mode='reflect')
+    denoised = torch.zeros_like(noisy)
+
+    _, _, H, W = noisy.shape
+    for i in np.arange(0, H, shift):
+        for j in np.arange(0, W, shift):
+            # Anchor the last patch in each row/column at the image border.
+            h_end, w_end = min(i + patch_size, H), min(j + patch_size, W)
+            h_start, w_start = h_end - patch_size, w_end - patch_size
+
+            input_var = noisy[..., h_start: h_end, w_start: w_end]
+            with torch.no_grad():
+                out_var = model(input_var)
+            denoised[..., h_start + pad: h_end - pad, w_start + pad: w_end - pad] = \
+                out_var[..., pad:-pad, pad:-pad]
+
+    denoised = denoised[..., pad:-pad, pad:-pad]
+    denoised = utils.stack2raw(denoised[0]).detach().cpu().numpy()
+
+    return denoised.clip(0, 1)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--root', default='/mnt/lustre/zhangyi3/data/SIDD_Medium/Data/')
+    parser.add_argument('--camera', choices=['s6', 'gp', 'ip'], required=True, help='camera name')
+    args = parser.parse_args()
+
+    camera = args.camera
+    root = args.root
+
+    # save_dir = './results/' + camera
+    # if not os.path.exists(save_dir):
+    #     os.makedirs(save_dir)
+    print('test', camera, 'root', root)
+
+    # Scenes 2, 3 and 5 of the chosen camera are held out for testing.
+    test_data_list = [item for item in os.listdir(root)
+                      if int(item.split('_')[1]) in [2, 3, 5] and camera in item.lower()]
+
+    # build model
+    model = UNetSeeInDark()
+    model = model.cuda()
+    model = torch.nn.DataParallel(model)
+
+    model_path = './checkpoints/%s.pth' % camera.lower()
+    model.load_state_dict(torch.load(model_path, map_location='cpu'))
+
+    psnr_list = []
+    for idx, item in enumerate(test_data_list):
+        head = item[:4]
+        for tail in ['GT_RAW_010', 'GT_RAW_011']:
+            print('processing', idx, item, tail, end=' ')
+            mat = utils.open_hdf5(os.path.join(root, item, '%s_%s.MAT' % (head, tail)))
+            gt = np.array(mat['x'], dtype=np.float32)
+            mat = utils.open_hdf5(os.path.join(root, item, '%s_%s.MAT' % (head, tail.replace('GT', 'NOISY'))))
+            noisy = np.array(mat['x'], dtype=np.float32)
+
+            # transform to rggb pattern
+            py_meta = utils.extract_metainfo(
+                os.path.join(root, item, '%s_%s.MAT' % (head, tail.replace('GT', 'METADATA'))))
+            pattern = py_meta['pattern']
+            noisy = utils.transform_to_rggb(noisy, pattern)
+            gt = utils.transform_to_rggb(gt, pattern)
+
+            denoised = forward_patches(model, noisy)
+
+            psnr = peak_signal_noise_ratio(gt, denoised, data_range=1)
+            psnr_list.append(psnr)
+            print('psnr %.2f' % psnr)
+
+    print('Camera %s, average PSNR %.2f' % (camera, np.mean(psnr_list)))
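`forward_patches` runs inference tile by tile: the frame is packed to 4 channels, reflect-padded by 32 pixels, processed in 768x768 windows that advance by 704 pixels, and only each window's interior is stitched back so seams land inside the overlap. A hedged usage sketch follows; the checkpoint path and the random input are placeholders, and a CUDA device is assumed since the function calls `.cuda()`:

```python
import numpy as np
import torch

from model import UNetSeeInDark
from test import forward_patches  # importing test.py skips its __main__ block

# Placeholder checkpoint path; keys carry the DataParallel 'module.' prefix,
# matching how test.py wraps and loads the model.
model = torch.nn.DataParallel(UNetSeeInDark().cuda())
model.load_state_dict(torch.load('./checkpoints/s6.pth', map_location='cpu'))
model.eval()

noisy = np.random.rand(1024, 1024).astype(np.float32)  # placeholder RGGB mosaic in [0, 1]
denoised = forward_patches(model, noisy)               # same shape, clipped to [0, 1]
print(denoised.shape, denoised.min(), denoised.max())
```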
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..f2eb508
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,105 @@
+import h5py
+import time
+import torch
+
+import scipy.io as sio
+
+import numpy as np
+
+
+def open_hdf5(filename):
+    # Retry until the file is readable (it may still be syncing on shared storage).
+    while True:
+        try:
+            return h5py.File(filename, 'r')
+        except OSError:
+            print(filename, 'waiting')
+            time.sleep(3)  # wait a bit before retrying
+
+
+def extract_metainfo(path='0151_METADATA_RAW_010.MAT'):
+    meta = sio.loadmat(path)['metadata']
+    mat_vals = meta[0][0]
+    mat_keys = mat_vals.dtype.descr
+
+    keys = [item[0] for item in mat_keys]
+    py_dict = {key: mat_vals[key] for key in keys}
+
+    device = py_dict['Model'][0].lower()
+    bit_depth = py_dict['BitDepth'][0][0]
+    # iPhone metadata (and any non-16-bit file) nests the fields differently.
+    if 'iphone' in device or bit_depth != 16:
+        noise = py_dict['UnknownTags'][-2][0][-1][0][:2]
+        iso = py_dict['DigitalCamera'][0, 0]['ISOSpeedRatings'][0][0]
+        pattern = py_dict['SubIFDs'][0][0]['UnknownTags'][0][0][1][0][-1][0]
+        exp_time = py_dict['DigitalCamera'][0, 0]['ExposureTime'][0][0]
+    else:
+        noise = py_dict['UnknownTags'][-1][0][-1][0][:2]
+        iso = py_dict['ISOSpeedRatings'][0][0]
+        pattern = py_dict['UnknownTags'][1][0][-1][0]
+        exp_time = py_dict['ExposureTime'][0][0]
+
+    rgb = ['R', 'G', 'B']
+    pattern = ''.join([rgb[i] for i in pattern])
+
+    asShotNeutral = py_dict['AsShotNeutral'][0]
+    b_gain, _, r_gain = asShotNeutral
+
+    # only load ccm1
+    ccm = py_dict['ColorMatrix1'][0].astype(float).reshape((3, 3))
+
+    return {'device': device,
+            'pattern': pattern,
+            'iso': iso,
+            'noise': noise,
+            'time': exp_time,
+            'wb': np.array([r_gain, 1, b_gain]),
+            'ccm': ccm}
+
+
+def transform_to_rggb(img, pattern):
+    # Shift the mosaic so that it reads as RGGB, whatever the native pattern.
+    assert img.ndim == 2 and isinstance(img, np.ndarray)
+
+    if pattern.lower() == 'bggr':
+        img = np.roll(np.roll(img, 1, axis=1), 1, axis=0)
+    elif pattern.lower() == 'rggb':
+        pass  # already RGGB
+    elif pattern.lower() == 'grbg':
+        img = np.roll(img, 1, axis=1)
+    elif pattern.lower() == 'gbrg':
+        img = np.roll(img, 1, axis=0)
+    else:
+        raise ValueError('unsupported pattern: %s' % pattern)
+
+    return img
+
+
+def raw2stack(var):
+    # Pack a single-channel Bayer mosaic (H, W) into a 4-channel stack
+    # (4, H/2, W/2) holding the R, G1, G2 and B sub-images.
+    h, w = var.shape
+    res = var.new_zeros((4, h // 2, w // 2))
+    res[0] = var[0::2, 0::2]
+    res[1] = var[0::2, 1::2]
+    res[2] = var[1::2, 0::2]
+    res[3] = var[1::2, 1::2]
+    return res
+
+
+def stack2raw(var):
+    # Inverse of raw2stack: interleave the 4 sub-images back into a mosaic.
+    _, h, w = var.shape
+    res = var.new_empty((h * 2, w * 2))
+    res[0::2, 0::2] = var[0]
+    res[0::2, 1::2] = var[1]
+    res[1::2, 0::2] = var[2]
+    res[1::2, 1::2] = var[3]
+    return res
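`raw2stack` packs an RGGB mosaic into the half-resolution 4-channel stack the network consumes, and `stack2raw` inverts it exactly. A quick round-trip check on a toy tensor:

```python
import torch

from utils import raw2stack, stack2raw

raw = torch.rand(8, 8)     # toy mosaic; height and width must be even
stack = raw2stack(raw)     # (4, 4, 4): R, G1, G2, B sub-images
assert stack.shape == (4, 4, 4)
assert torch.equal(stack2raw(stack), raw)  # lossless round trip
```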