Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jiexiaou committed Nov 14, 2022
1 parent 51ecc50 commit 5fad7b1
Show file tree
Hide file tree
Showing 21 changed files with 2,634 additions and 2 deletions.
8 changes: 8 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 26 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,33 @@
# Stochastic Window Transformer for Image Restoration (NeurIPS 2022)
<b>Jie Xiao, <a href='https://xueyangfu.github.io'>Xueyang Fu</a>, Zheng-Jun Zha, Feng Wu</b>

> **Abstract:** *Thanks to the strong representation ability, transformers have attained impressive results for image restoration. However, existing transformers do not carefully take into account the particularities of image restoration. Basically, image restoration requires that the ideal approach should be invariant to translation of degradation, i.e., undesirable degradation should be removed irrespective of its position within the image. Moreover, local relationships play a vital role and should be faithfully exploited for recovering clean images. Nevertheless, most of transformers have resorted to either fixed local window based or global attention, which unfortunately breaks the translation invariance and further causes huge loss of local relationships. To address these issues, we propose an elegant stochastic window strategy for transformers. Specifically, we introduce the window partition with stochastic shift to replace the original fixed window partition for training and elaborate the layer expectation propagation algorithm to efficiently approximate the expectation of the induced stochastic transformer for testing. The stochastic window transformer can not only enjoy powerful representation but also maintain the desired property of translation invariance and locality. Experiments validate the stochastic window strategy consistently improves performance on various image restoration tasks (image deraining, denosing, and deblurring) by significant margins.*
## Method
![Stoformer](figs/method.png)
## The code will be available.
## Train
```
python train_Deblur.py --arch Stoformer --save_dir save_path --train_dir GoPro/train --val_dir GoPro/val --nepoch 600 --embed 32 --checkpoint 100 --optimizer adam --lr_initial 3e-4 --train_workers 4 --env _GoPro --gpu '0,1' --train_ps 256 --batch_size 8 --use_grad_clip
python train_Deblur.py --arch Fixformer --save_dir save_path --train_dir GoPro/train --val_dir GoPro/val --nepoch 600 --embed 32 --checkpoint 100 --optimizer adam --lr_initial 3e-4 --train_workers 4 --env _GoPro --gpu '0,1' --train_ps 256 --batch_size 8 --use_grad_clip
```
## Test
```
python test_Deblur.py --arch Stoformer --gpu '0,1' --input_dir GoPro/test/input --embed_dim 32 --result_dir result/GoPro++ --weights sto_model_path --batch_size 8 --crop_size 512 --overlap_size 32
python test_Deblur.py --arch Stoformer --gpu '0,1' --input_dir GoPro/test/input --embed_dim 32 --result_dir result/GoPro-+ --weights fix_model_path --batch_size 8 --crop_size 512 --overlap_size 32
python test_Deblur.py --arch Fixformer --gpu '0,1' --input_dir GoPro/test/input --embed_dim 32 --result_dir result/GoPro+- --weights sto_model_path --batch_size 8 --crop_size 512 --overlap_size 32
python test_Deblur.py --arch Fixformer --gpu '0,1' --input_dir GoPro/test/input --embed_dim 32 --result_dir result/GoPro-- --weights fix_model_path --batch_size 8 --crop_size 512 --overlap_size 32
```
## Pretrained Weight
- Deblur: <a href="https://drive.google.com/drive/folders/1SwnBd2VrWwjirDxl1iWVgnYxf_v_4kyI">stochastic_model and fix_model</a>
## Evaluation
- Deblur: <a href="evaluategopro.m">evaluategopro.m</a>
## Acknowledgement
We refer to [Uformer](https://github.com/ZhendongWang6/Uformer) and [Restormer](https://github.com/swz30/Restormer). Thanks for sharing.
## Citation
```
@inproceedings{xiao2022stochastic,
title={Stochastic Window Transformer for Image Restoration},
author={Xiao, Jie and Fu, Xueyang and Wu, Feng and Zha, Zheng-Jun},
booktitle={NeurIPS},
year={2022}
```
## Contact
Please contact us if there is any question([email protected]).
100 changes: 100 additions & 0 deletions dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import numpy as np
import os
from torch.utils.data import Dataset
import torch
from utils import is_png_file, load_img, Augment_RGB_torch, load_gray_img
import random
import cv2

augment = Augment_RGB_torch()
transforms_aug = [method for method in dir(augment) if callable(getattr(augment, method)) if not method.startswith('_')]

##################################################################################################
def padding(img_lq, gt_size=384):
h, w, _ = img_lq.shape

h_pad = max(0, gt_size - h)
w_pad = max(0, gt_size - w)

if h_pad == 0 and w_pad == 0:
return img_lq

img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
# print('img_lq', img_lq.shape, img_gt.shape)
if img_lq.ndim == 2:
img_lq = np.expand_dims(img_lq, axis=2)
return img_lq


class DataLoaderTrainGoPro(Dataset):
def __init__(self, rgb_dir, patchsize, target_transform=None):
super(DataLoaderTrainGoPro, self).__init__()

self.target_transform = target_transform

gt_dir = 'groundtruth'
input_dir = 'input'

clean_files = sorted(os.listdir(os.path.join(rgb_dir, gt_dir)))
input_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir)))

self.clean_filenames = [os.path.join(rgb_dir, gt_dir, x) for x in clean_files if is_png_file(x)]
self.input_filenames = [os.path.join(rgb_dir, input_dir, x) for x in input_files if is_png_file(x)]

self.tar_size = len(self.clean_filenames) # get the size of target
self.crop_size = patchsize

def __len__(self):
return self.tar_size

def __getitem__(self, index):
tar_index = index % self.tar_size
ps = self.crop_size

clean = torch.from_numpy(np.float32(load_img(self.clean_filenames[tar_index])))
input = torch.from_numpy(np.float32(load_img(self.input_filenames[tar_index])))

clean = clean.permute(2, 0, 1)
input = input.permute(2, 0, 1)

# Crop Input and Target
H = clean.shape[1]
W = clean.shape[2]

if H - ps == 0:
r = 0
c = 0
else:
r = np.random.randint(0, H - ps)
c = np.random.randint(0, W - ps)
clean = clean[:, r:r + ps, c:c + ps]
input = input[:, r:r + ps, c:c + ps]

apply_trans = transforms_aug[random.getrandbits(3)]

clean = getattr(augment, apply_trans)(clean)
input = getattr(augment, apply_trans)(input)

return [clean, input]

class DataLoaderTest(Dataset):
def __init__(self, input_dir):
super(DataLoaderTest, self).__init__()

noisy_files = sorted(os.listdir(input_dir))
self.noisy_filenames = [os.path.join(input_dir, x) for x in noisy_files if is_png_file(x)]

self.tar_size = len(self.noisy_filenames)

def __len__(self):
return self.tar_size

def __getitem__(self, index):
tar_index = index % self.tar_size

noisy = torch.from_numpy(np.float32(load_img(self.noisy_filenames[tar_index])))
noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1]

noisy = noisy.permute(2, 0, 1)

return noisy, noisy_filename
29 changes: 29 additions & 0 deletions evaluategopro.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
close all;clear all;

file_path = strcat('datapath', '/');
gt_path = strcat('groundtruth', '/');
path_list = [dir(strcat(file_path,'*.jpg')); dir(strcat(file_path,'*.PNG'))];
gt_list = [dir(strcat(gt_path,'*.jpg')); dir(strcat(gt_path,'*.png'))];
img_num = length(path_list);
h=waitbar(0, 'Processing!');
total_psnr = 0;
total_ssim = 0;
if img_num > 0
for j = 1:img_num
waitbar(j/img_num);
image_name = path_list(j).name;
gt_name = gt_list(j).name;
input = imread(strcat(file_path,image_name));
gt = imread(strcat(gt_path, gt_name));
ssim_val = ssim(input, gt);
psnr_val = psnr(input, gt);
total_ssim = total_ssim + ssim_val;
total_psnr = total_psnr + psnr_val;
end
end
qm_psnr = total_psnr / img_num;
qm_ssim = total_ssim / img_num;
close(h);
fprintf('For dataset PSNR: %f SSIM: %f\n', qm_psnr, qm_ssim);


Loading

0 comments on commit 5fad7b1

Please sign in to comment.