diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..13566b8
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# Default ignored files
+/shelf/
+/workspace.xml
+# Editor-based HTTP Client requests
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/README.md b/README.md
index 37ffa4f..7968aa6 100644
--- a/README.md
+++ b/README.md
@@ -1,9 +1,33 @@
# Stochastic Window Transformer for Image Restoration (NeurIPS 2022)
Jie Xiao, Xueyang Fu, Zheng-Jun Zha, Feng Wu
-
> **Abstract:** *Thanks to the strong representation ability, transformers have attained impressive results for image restoration. However, existing transformers do not carefully take into account the particularities of image restoration. Basically, image restoration requires that the ideal approach should be invariant to translation of degradation, i.e., undesirable degradation should be removed irrespective of its position within the image. Moreover, local relationships play a vital role and should be faithfully exploited for recovering clean images. Nevertheless, most of transformers have resorted to either fixed local window based or global attention, which unfortunately breaks the translation invariance and further causes huge loss of local relationships. To address these issues, we propose an elegant stochastic window strategy for transformers. Specifically, we introduce the window partition with stochastic shift to replace the original fixed window partition for training and elaborate the layer expectation propagation algorithm to efficiently approximate the expectation of the induced stochastic transformer for testing. The stochastic window transformer can not only enjoy powerful representation but also maintain the desired property of translation invariance and locality. Experiments validate the stochastic window strategy consistently improves performance on various image restoration tasks (image deraining, denosing, and deblurring) by significant margins.*
## Method

-## The code will be available.
+## Train
+```
+python train_Deblur.py --arch Stoformer --save_dir save_path --train_dir GoPro/train --val_dir GoPro/val --nepoch 600 --embed 32 --checkpoint 100 --optimizer adam --lr_initial 3e-4 --train_workers 4 --env _GoPro --gpu '0,1' --train_ps 256 --batch_size 8 --use_grad_clip
+python train_Deblur.py --arch Fixformer --save_dir save_path --train_dir GoPro/train --val_dir GoPro/val --nepoch 600 --embed 32 --checkpoint 100 --optimizer adam --lr_initial 3e-4 --train_workers 4 --env _GoPro --gpu '0,1' --train_ps 256 --batch_size 8 --use_grad_clip
+```
+## Test
+```
+python test_Deblur.py --arch Stoformer --gpu '0,1' --input_dir GoPro/test/input --embed_dim 32 --result_dir result/GoPro++ --weights sto_model_path --batch_size 8 --crop_size 512 --overlap_size 32
+python test_Deblur.py --arch Stoformer --gpu '0,1' --input_dir GoPro/test/input --embed_dim 32 --result_dir result/GoPro-+ --weights fix_model_path --batch_size 8 --crop_size 512 --overlap_size 32
+python test_Deblur.py --arch Fixformer --gpu '0,1' --input_dir GoPro/test/input --embed_dim 32 --result_dir result/GoPro+- --weights sto_model_path --batch_size 8 --crop_size 512 --overlap_size 32
+python test_Deblur.py --arch Fixformer --gpu '0,1' --input_dir GoPro/test/input --embed_dim 32 --result_dir result/GoPro-- --weights fix_model_path --batch_size 8 --crop_size 512 --overlap_size 32
+```
+## Pretrained Weight
+- Deblur: stochastic_model and fix_model
+## Evaluation
+- Deblur: evaluategopro.m
+## Acknowledgement
+We refer to [Uformer](https://github.com/ZhendongWang6/Uformer) and [Restormer](https://github.com/swz30/Restormer). Thanks for sharing.
+## Citation
+```
+@inproceedings{xiao2022stochastic,
+ title={Stochastic Window Transformer for Image Restoration},
+ author={Xiao, Jie and Fu, Xueyang and Wu, Feng and Zha, Zheng-Jun},
+ booktitle={NeurIPS},
+ year={2022}
+```
## Contact
Please contact us if there is any question(ustchbxj@mail.ustc.edu.cn).
\ No newline at end of file
diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000..4619ca8
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,100 @@
+import numpy as np
+import os
+from torch.utils.data import Dataset
+import torch
+from utils import is_png_file, load_img, Augment_RGB_torch, load_gray_img
+import random
+import cv2
+
+augment = Augment_RGB_torch()
+transforms_aug = [method for method in dir(augment) if callable(getattr(augment, method)) if not method.startswith('_')]
+
+##################################################################################################
+def padding(img_lq, gt_size=384):
+ h, w, _ = img_lq.shape
+
+ h_pad = max(0, gt_size - h)
+ w_pad = max(0, gt_size - w)
+
+ if h_pad == 0 and w_pad == 0:
+ return img_lq
+
+ img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
+ # print('img_lq', img_lq.shape, img_gt.shape)
+ if img_lq.ndim == 2:
+ img_lq = np.expand_dims(img_lq, axis=2)
+ return img_lq
+
+
+class DataLoaderTrainGoPro(Dataset):
+ def __init__(self, rgb_dir, patchsize, target_transform=None):
+ super(DataLoaderTrainGoPro, self).__init__()
+
+ self.target_transform = target_transform
+
+ gt_dir = 'groundtruth'
+ input_dir = 'input'
+
+ clean_files = sorted(os.listdir(os.path.join(rgb_dir, gt_dir)))
+ input_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir)))
+
+ self.clean_filenames = [os.path.join(rgb_dir, gt_dir, x) for x in clean_files if is_png_file(x)]
+ self.input_filenames = [os.path.join(rgb_dir, input_dir, x) for x in input_files if is_png_file(x)]
+
+ self.tar_size = len(self.clean_filenames) # get the size of target
+ self.crop_size = patchsize
+
+ def __len__(self):
+ return self.tar_size
+
+ def __getitem__(self, index):
+ tar_index = index % self.tar_size
+ ps = self.crop_size
+
+ clean = torch.from_numpy(np.float32(load_img(self.clean_filenames[tar_index])))
+ input = torch.from_numpy(np.float32(load_img(self.input_filenames[tar_index])))
+
+ clean = clean.permute(2, 0, 1)
+ input = input.permute(2, 0, 1)
+
+ # Crop Input and Target
+ H = clean.shape[1]
+ W = clean.shape[2]
+
+ if H - ps == 0:
+ r = 0
+ c = 0
+ else:
+ r = np.random.randint(0, H - ps)
+ c = np.random.randint(0, W - ps)
+ clean = clean[:, r:r + ps, c:c + ps]
+ input = input[:, r:r + ps, c:c + ps]
+
+ apply_trans = transforms_aug[random.getrandbits(3)]
+
+ clean = getattr(augment, apply_trans)(clean)
+ input = getattr(augment, apply_trans)(input)
+
+ return [clean, input]
+
+class DataLoaderTest(Dataset):
+ def __init__(self, input_dir):
+ super(DataLoaderTest, self).__init__()
+
+ noisy_files = sorted(os.listdir(input_dir))
+ self.noisy_filenames = [os.path.join(input_dir, x) for x in noisy_files if is_png_file(x)]
+
+ self.tar_size = len(self.noisy_filenames)
+
+ def __len__(self):
+ return self.tar_size
+
+ def __getitem__(self, index):
+ tar_index = index % self.tar_size
+
+ noisy = torch.from_numpy(np.float32(load_img(self.noisy_filenames[tar_index])))
+ noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1]
+
+ noisy = noisy.permute(2, 0, 1)
+
+ return noisy, noisy_filename
\ No newline at end of file
diff --git a/evaluategopro.m b/evaluategopro.m
new file mode 100644
index 0000000..6da88aa
--- /dev/null
+++ b/evaluategopro.m
@@ -0,0 +1,29 @@
+close all;clear all;
+
+file_path = strcat('datapath', '/');
+gt_path = strcat('groundtruth', '/');
+path_list = [dir(strcat(file_path,'*.jpg')); dir(strcat(file_path,'*.PNG'))];
+gt_list = [dir(strcat(gt_path,'*.jpg')); dir(strcat(gt_path,'*.png'))];
+img_num = length(path_list);
+h=waitbar(0, 'Processing!');
+total_psnr = 0;
+total_ssim = 0;
+if img_num > 0
+ for j = 1:img_num
+ waitbar(j/img_num);
+ image_name = path_list(j).name;
+ gt_name = gt_list(j).name;
+ input = imread(strcat(file_path,image_name));
+ gt = imread(strcat(gt_path, gt_name));
+ ssim_val = ssim(input, gt);
+ psnr_val = psnr(input, gt);
+ total_ssim = total_ssim + ssim_val;
+ total_psnr = total_psnr + psnr_val;
+ end
+end
+qm_psnr = total_psnr / img_num;
+qm_ssim = total_ssim / img_num;
+close(h);
+fprintf('For dataset PSNR: %f SSIM: %f\n', qm_psnr, qm_ssim);
+
+
diff --git a/fixformer.py b/fixformer.py
new file mode 100644
index 0000000..f6df0bd
--- /dev/null
+++ b/fixformer.py
@@ -0,0 +1,683 @@
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+import torch.nn.functional as F
+from einops import rearrange, repeat
+import math
+import random
+import argparse
+import options
+
+
+class SELayer(nn.Module):
+ def __init__(self, channel, reduction=16):
+ super(SELayer, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool1d(1)
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction, bias=False),
+ nn.ReLU(inplace=True),
+ nn.Linear(channel // reduction, channel, bias=False),
+ nn.Sigmoid()
+ )
+
+ def forward(self, x): # x: [B, N, C]
+ x = torch.transpose(x, 1, 2) # [B, C, N]
+ b, c, _ = x.size()
+ y = self.avg_pool(x).view(b, c)
+ y = self.fc(y).view(b, c, 1)
+ x = x * y.expand_as(x)
+ x = torch.transpose(x, 1, 2) # [B, N, C]
+ return x
+
+
+######## Embedding for q,k,v ########
+
+class LinearProjection(nn.Module):
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0., bias=True, Train=True):
+ super(LinearProjection, self).__init__()
+ inner_dim = dim_head * heads
+ self.heads = heads
+ self.train=Train
+ self.to_q = nn.Linear(dim, inner_dim, bias=bias)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=bias)
+ self.dim = dim
+ self.inner_dim = inner_dim
+
+ def forward(self, x, attn_kv=None):
+ B_, N, C = x.shape
+
+ attn_kv = x if attn_kv is None else attn_kv
+ q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
+ kv = self.to_kv(attn_kv).reshape(B_, N, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
+ q = q[0]
+ k, v = kv[0], kv[1]
+ return q, k, v
+
+
+########### feed-forward network #############
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super(Mlp, self).__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+ self.in_features = in_features
+ self.hidden_features = hidden_features
+ self.out_features = out_features
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class LeFF(nn.Module):
+ def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU, drop=0.):
+ super(LeFF, self).__init__()
+ self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim),
+ act_layer())
+ self.dwconv = nn.Sequential(
+ nn.Conv2d(hidden_dim, hidden_dim, groups=hidden_dim, kernel_size=3, stride=1, padding=1),
+ act_layer())
+ self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim))
+ self.dim = dim
+ self.hidden_dim = hidden_dim
+
+ def forward(self, x):
+ # bs x hw x c
+ bs, hw, c = x.size()
+ hh = round(math.sqrt(hw))
+ ww = round(math.sqrt(hw))
+
+ x = self.linear1(x)
+
+ # spatial restore
+ x = rearrange(x, ' b (h w) (c) -> b c h w ', h=hh, w=ww)
+ # bs,hidden_dim,32x32
+
+ x = self.dwconv(x)
+
+ # flaten
+ x = rearrange(x, ' b c h w -> b (h w) c', h=hh, w=ww)
+
+ x = self.linear2(x)
+
+ return x
+
+
+########### window operation#############
+def window_partition(x, win_size, dilation_rate=1):
+ B, H, W, C = x.shape
+ if dilation_rate != 1:
+ x = x.permute(0, 3, 1, 2) # B, C, H, W
+ assert type(dilation_rate) is int, 'dilation_rate should be a int'
+ x = F.unfold(x, kernel_size=win_size, dilation=dilation_rate, padding=4 * (dilation_rate - 1),
+ stride=win_size) # B, C*Wh*Ww, H/Wh*W/Ww
+ windows = x.permute(0, 2, 1).contiguous().view(-1, C, win_size, win_size) # B' ,C ,Wh ,Ww
+ windows = windows.permute(0, 2, 3, 1).contiguous() # B' ,Wh ,Ww ,C
+ else:
+ x = x.view(B, H // win_size, win_size, W // win_size, win_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C) # B' ,Wh ,Ww ,C
+ return windows
+
+
+def window_reverse(windows, win_size, H, W, dilation_rate=1):
+ # B' ,Wh ,Ww ,C
+ B = int(windows.shape[0] / (H * W / win_size / win_size))
+ x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1)
+ if dilation_rate != 1:
+ x = windows.permute(0, 5, 3, 4, 1, 2).contiguous() # B, C*Wh*Ww, H/Wh*W/Ww
+ x = F.fold(x, (H, W), kernel_size=win_size, dilation=dilation_rate, padding=4 * (dilation_rate - 1),
+ stride=win_size)
+ else:
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+# Downsample Block
+class Downsample(nn.Module):
+ def __init__(self, in_channel, out_channel):
+ super(Downsample, self).__init__()
+ self.conv = nn.Sequential(
+ nn.Conv2d(in_channel, out_channel, kernel_size=4, stride=2, padding=1),
+ )
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+
+ def forward(self, x):
+ B, L, C = x.shape
+ H = round(math.sqrt(L))
+ W = round(math.sqrt(L))
+ x = x.transpose(1, 2).contiguous().view(B, C, H, W)
+ out = self.conv(x).flatten(2).transpose(1, 2).contiguous() # B H*W C
+ return out
+
+
+# Upsample Block
+class Upsample(nn.Module):
+ def __init__(self, in_channel, out_channel):
+ super(Upsample, self).__init__()
+ self.deconv = nn.Sequential(
+ nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2),
+ )
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+
+ def forward(self, x):
+ B, L, C = x.shape
+ H = round(math.sqrt(L))
+ W = round(math.sqrt(L))
+ x = x.transpose(1, 2).contiguous().view(B, C, H, W)
+ out = self.deconv(x).flatten(2).transpose(1, 2).contiguous() # B H*W C
+ return out
+
+
+# Input Projection
+class InputProj(nn.Module):
+ def __init__(self, in_channel=3, out_channel=64, kernel_size=3, stride=1, norm_layer=None, act_layer=nn.LeakyReLU):
+ super(InputProj, self).__init__()
+ self.proj = nn.Sequential(
+ nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size // 2),
+ act_layer(inplace=True)
+ )
+ if norm_layer is not None:
+ self.norm = norm_layer(out_channel)
+ else:
+ self.norm = None
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+
+ def forward(self, x):
+ x = self.proj(x).flatten(2).transpose(1, 2).contiguous() # B H*W C
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+
+# Output Projection
+class OutputProj(nn.Module):
+ def __init__(self, in_channel=64, out_channel=3, kernel_size=3, stride=1, norm_layer=None, act_layer=None):
+ super(OutputProj, self).__init__()
+ self.proj = nn.Sequential(
+ nn.Conv2d(in_channel, 3, kernel_size=3, stride=1, padding=1)
+ )
+ if act_layer is not None:
+ self.proj.add_module(act_layer(inplace=True))
+ if norm_layer is not None:
+ self.norm = norm_layer(out_channel)
+ else:
+ self.norm = None
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+
+
+ def forward(self, x):
+ B, L, C = x.shape
+ H = round(math.sqrt(L))
+ W = round(math.sqrt(L))
+ x = x.transpose(1, 2).view(B, C, H, W)
+ x = self.proj(x)
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+
+########### StoTransformer #############
+class FixTransformerBlock(nn.Module):
+ def __init__(self, dim, input_resolution, num_heads, win_size=8, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., proj_drop=0.,drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, stride=1, token_mlp='leff',
+ se_layer=False):
+ super(FixTransformerBlock, self).__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.win_size = win_size
+ self.random_shift = False
+ self.stride=stride
+ self.mlp_ratio = mlp_ratio
+ self.token_mlp = token_mlp
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+ self.shift_size = shift_size
+
+ Twin_size = win_size + 1
+ self.Twin_size = Twin_size
+ self.pad_size = Twin_size // 2 # padding size
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * win_size - 1) * (2 * win_size - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.win_size) # [0,...,Wh-1]
+ coords_w = torch.arange(self.win_size) # [0,...,Ww-1]
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.win_size- 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.win_size - 1
+ relative_coords[:, :, 0] *= 2 * self.win_size - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.attn_drop = nn.Dropout(attn_drop)
+
+ self.norm1 = norm_layer(dim)
+
+ self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
+ self.to_kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
+
+ self.softmax = nn.Softmax(dim=-1)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.proj = nn.Linear(dim, dim)
+ self.se_layer = SELayer(dim) if se_layer else nn.Identity()
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer,
+ drop=drop) if token_mlp == 'ffn' else LeFF(dim, mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+
+ def attention(self, q, k, v, attn_mask=None):
+ B_, h, N_, C_ = q.shape
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.win_size * self.win_size, self.win_size * self.win_size, -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ ratio = attn.size(-1) // relative_position_bias.size(-1)
+ relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d=ratio)
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if attn_mask is not None:
+ nW = attn_mask.shape[0] # [nW, N_, N_]
+ mask = repeat(attn_mask, 'nW m n -> nW m (n d)', d=1) # [nW, N_, N_]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N_, N_ * 1) + mask.unsqueeze(1).unsqueeze(
+ 0) # [1, nW, 1, N_, N_]
+ # [B, nW, nh, N_, N_]
+ attn = attn.view(-1, self.num_heads, N_, N_ * 1)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ y = (attn @ v).transpose(1, 2).reshape(B_, N_, h*C_)
+ y = self.proj(y)
+ return y
+
+ def forward(self, x, mask=None):
+ B, L, C = x.shape
+
+ H = round(math.sqrt(L))
+ W = round(math.sqrt(L))
+
+ shortcut = x
+ x = self.norm1(x)
+ q = self.to_q(x) #[B, L, C]
+ kv = self.to_kv(x)
+
+ q = rearrange(q, 'b (h w) c -> b h w c', h=H)
+ kv = rearrange(kv, 'b (h w) c -> b h w c', h=H)
+
+ x = x.view(B, H, W, C)
+
+ if mask != None:
+ input_mask = F.interpolate(mask, size=(H, W)).permute(0, 2, 3, 1)
+ input_mask_windows = window_partition(input_mask, self.win_size) # nW, win_size, win_size, 1
+ attn_mask = input_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size
+ attn_mask = attn_mask.unsqueeze(2) * attn_mask.unsqueeze(1) # nW, win_size*win_size, win_size*win_size
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ else:
+ attn_mask = None
+
+ ## Stochastic shift window
+ H_offset = self.shift_size
+ W_offset = self.shift_size
+
+ shift_mask = torch.zeros((1, H, W, 1)).type_as(x)
+
+ if H_offset > 0:
+ h_slices = (slice(0, -self.win_size),
+ slice(-self.win_size, -H_offset),
+ slice(-H_offset, None))
+ else:
+ h_slices = (slice(0, None),)
+ if W_offset > 0:
+ w_slices = (slice(0, -self.win_size),
+ slice(-self.win_size, -W_offset),
+ slice(-W_offset, None))
+ else:
+ w_slices = (slice(0, None),)
+
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ shift_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ shift_mask_windows = window_partition(shift_mask, self.win_size) # nW, win_size, win_size, 1
+ shift_mask_windows = shift_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size
+ shift_attn_mask = shift_mask_windows.unsqueeze(1) - shift_mask_windows.unsqueeze(
+ 2) # nW, win_size*win_size, win_size*win_size
+ shift_attn_mask = shift_attn_mask.masked_fill(shift_attn_mask != 0, float(-100.0)).masked_fill(
+ shift_attn_mask == 0, float(0.0))
+ attn_mask = attn_mask + shift_attn_mask if attn_mask is not None else shift_attn_mask #[nW, N_,N_]
+
+ # cyclic shift
+ shifted_q = torch.roll(q, shifts=(-H_offset, -W_offset), dims=(1, 2))
+ shifted_kv = torch.roll(kv, shifts=(-H_offset, -W_offset), dims=(1, 2))
+
+ # partition windows
+ q_windows = window_partition(shifted_q, self.win_size) # nW*B, win_size, win_size, C N*C->C
+ q_windows = q_windows.view(-1, self.win_size * self.win_size, C) # nW*B, win_size*win_size, C
+ B_, N_, C_ = q_windows.shape
+ q_windows = q_windows.reshape(B_, N_, self.num_heads, C_ // self.num_heads).permute(0, 2, 1, 3)
+
+ kv_windows = window_partition(shifted_kv, self.win_size) # nW*B, win_size, win_size, 2C
+ kv_windows = kv_windows.view(-1, self.win_size * self.win_size, 2 * C)
+ kv_windows = kv_windows.reshape(B_, N_, 2, self.num_heads, C_ // self.num_heads).permute(2, 0, 3, 1, 4)
+ k_windows, v_windows = kv_windows[0], kv_windows[1]
+
+ attn_windows = self.attention(q_windows, k_windows, v_windows, attn_mask)
+
+ attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C)
+ x = window_reverse(attn_windows, self.win_size, H, W) # B H' W' C
+
+ x = torch.roll(x, shifts=(H_offset, W_offset), dims=(1, 2))
+
+ x = x.view(B, H * W, C)
+
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+########### Basic layer of Stoformer ################
+class BasicFixformerLayer(nn.Module):
+ def __init__(self, dim, output_dim, input_resolution, depth, num_heads, win_size,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False,
+ token_mlp='leff', se_layer=False):
+
+ super(BasicFixformerLayer, self).__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+ self.hwratio = 1.0
+ self.random_shift = False
+ # build blocks
+ self.blocks = nn.ModuleList([
+ FixTransformerBlock(dim=dim, input_resolution=input_resolution,
+ num_heads=num_heads, win_size=win_size,
+ shift_size = 0 if (i % 2 == 0) else win_size//2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer, token_mlp=token_mlp,
+ se_layer=se_layer)
+ for i in range(depth)])
+
+
+ def forward(self, x, mask=None):
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x, mask)
+ return x
+
+
+class Fixformer(nn.Module):
+ def __init__(self, img_size=128, in_chans=3,
+ embed_dim=32, depths=[1, 2, 8, 8, 2, 8, 8, 2, 1], num_heads=[1, 2, 4, 8, 16, 16, 8, 4, 2],
+ win_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm, patch_norm=True,
+ use_checkpoint=False, token_mlp='leff', se_layer=False,
+ dowsample=Downsample, upsample=Upsample, **kwargs):
+ super(Fixformer, self).__init__()
+
+ self.num_enc_layers = len(depths) // 2
+ self.num_dec_layers = len(depths) // 2
+ self.embed_dim = embed_dim
+ self.patch_norm = patch_norm
+ self.mlp_ratio = mlp_ratio
+ self.mlp = token_mlp
+ self.win_size = win_size
+ self.reso = img_size
+ self.pos_drop = nn.Dropout(p=drop_rate)
+ self.hwratio = 1.0
+ # stochastic depth
+ enc_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[:self.num_enc_layers]))]
+ conv_dpr = [drop_path_rate] * depths[4]
+ dec_dpr = enc_dpr[::-1]
+ self.random_shift = False
+ # build layers
+
+ # Input/Output
+ self.input_proj = InputProj(in_channel=in_chans, out_channel=embed_dim, kernel_size=3, stride=1,
+ act_layer=nn.LeakyReLU)
+ self.output_proj = OutputProj(in_channel=2 * embed_dim, out_channel=in_chans, kernel_size=3, stride=1)
+
+ # Encoder
+ self.encoderlayer_0 = BasicFixformerLayer(dim=embed_dim,
+ output_dim=embed_dim,
+ input_resolution=(img_size,
+ img_size),
+ depth=depths[0],
+ num_heads=num_heads[0],
+ win_size=win_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=enc_dpr[sum(depths[:0]):sum(depths[:1])],
+ norm_layer=norm_layer,
+ use_checkpoint=use_checkpoint,
+ token_mlp=token_mlp,
+ se_layer=se_layer)
+ self.dowsample_0 = dowsample(embed_dim, embed_dim * 2)
+
+ self.encoderlayer_1 = BasicFixformerLayer(dim=embed_dim * 2,
+ output_dim=embed_dim * 2,
+ input_resolution=(img_size // 2,
+ img_size // 2),
+ depth=depths[1],
+ num_heads=num_heads[1],
+ win_size=win_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])],
+ norm_layer=norm_layer,
+ use_checkpoint=use_checkpoint,
+ token_mlp=token_mlp,
+ se_layer=se_layer)
+ self.dowsample_1 = dowsample(embed_dim * 2, embed_dim * 4)
+
+ self.encoderlayer_2 = BasicFixformerLayer(dim=embed_dim * 4,
+ output_dim=embed_dim * 4,
+ input_resolution=(img_size // (2 ** 2),
+ img_size // (2 ** 2)),
+ depth=depths[2],
+ num_heads=num_heads[2],
+ win_size=win_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])],
+ norm_layer=norm_layer,
+ use_checkpoint=use_checkpoint,
+ token_mlp=token_mlp,
+ se_layer=se_layer)
+ self.dowsample_2 = dowsample(embed_dim * 4, embed_dim * 8)
+
+ self.encoderlayer_3 = BasicFixformerLayer(dim=embed_dim * 8,
+ output_dim=embed_dim * 8,
+ input_resolution=(img_size // (2 ** 3),
+ img_size // (2 ** 3)),
+ depth=depths[3],
+ num_heads=num_heads[3],
+ win_size=win_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=enc_dpr[sum(depths[:3]):sum(depths[:4])],
+ norm_layer=norm_layer,
+ use_checkpoint=use_checkpoint,
+ token_mlp=token_mlp,
+ se_layer=se_layer)
+ self.dowsample_3 = dowsample(embed_dim * 8, embed_dim * 16)
+
+ # Bottleneck
+ self.conv = BasicFixformerLayer(dim=embed_dim * 16,
+ output_dim=embed_dim * 16,
+ input_resolution=(img_size // (2 ** 4),
+ img_size // (2 ** 4)),
+ depth=depths[4],
+ num_heads=num_heads[4],
+ win_size=win_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=conv_dpr,
+ norm_layer=norm_layer,
+ use_checkpoint=use_checkpoint,
+ token_mlp=token_mlp, se_layer=se_layer)
+ # Decoder
+ self.upsample_0 = upsample(embed_dim * 16, embed_dim * 8)
+ self.decoderlayer_0 = BasicFixformerLayer(dim=embed_dim * 16,
+ output_dim=embed_dim * 16,
+ input_resolution=(img_size // (2 ** 3),
+ img_size // (2 ** 3)),
+ depth=depths[5],
+ num_heads=num_heads[5],
+ win_size=win_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dec_dpr[:depths[5]],
+ norm_layer=norm_layer,
+ use_checkpoint=use_checkpoint,
+ token_mlp=token_mlp,
+ se_layer=se_layer)
+
+ self.upsample_1 = upsample(embed_dim * 16, embed_dim * 4)
+ self.decoderlayer_1 = BasicFixformerLayer(dim=embed_dim * 8,
+ output_dim=embed_dim * 8,
+ input_resolution=(img_size // (2 ** 2),
+ img_size // (2 ** 2)),
+ depth=depths[6],
+ num_heads=num_heads[6],
+ win_size=win_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dec_dpr[sum(depths[5:6]):sum(depths[5:7])],
+ norm_layer=norm_layer,
+ use_checkpoint=use_checkpoint,
+ token_mlp=token_mlp,
+ se_layer=se_layer)
+
+ self.upsample_2 = upsample(embed_dim * 8, embed_dim * 2)
+ self.decoderlayer_2 = BasicFixformerLayer(dim=embed_dim * 4,
+ output_dim=embed_dim * 4,
+ input_resolution=(img_size // 2,
+ img_size // 2),
+ depth=depths[7],
+ num_heads=num_heads[7],
+ win_size=win_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dec_dpr[sum(depths[5:7]):sum(depths[5:8])],
+ norm_layer=norm_layer,
+ use_checkpoint=use_checkpoint,
+ token_mlp=token_mlp,
+ se_layer=se_layer)
+
+ self.upsample_3 = upsample(embed_dim * 4, embed_dim)
+ self.decoderlayer_3 = BasicFixformerLayer(dim=embed_dim * 2,
+ output_dim=embed_dim * 2,
+ input_resolution=(img_size,
+ img_size),
+ depth=depths[8],
+ num_heads=num_heads[8],
+ win_size=win_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dec_dpr[sum(depths[5:8]):sum(depths[5:9])],
+ norm_layer=norm_layer,
+ use_checkpoint=use_checkpoint,
+ token_mlp=token_mlp,
+ se_layer=se_layer)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {'relative_position_bias_table'}
+
+ def forward(self, x, mask=None):
+ # Input Projection
+ y = self.input_proj(x)
+ y = self.pos_drop(y)
+ # Encoder
+ conv0 = self.encoderlayer_0(y, mask=mask) #128x128 32
+ pool0 = self.dowsample_0(conv0)
+ conv1 = self.encoderlayer_1(pool0, mask=mask) #64x64 64
+ pool1 = self.dowsample_1(conv1)
+ conv2 = self.encoderlayer_2(pool1, mask=mask) #32x32 128
+ pool2 = self.dowsample_2(conv2)
+ conv3 = self.encoderlayer_3(pool2, mask=mask) #16x16 256
+ pool3 = self.dowsample_3(conv3)
+
+ # Bottleneck
+ conv4 = self.conv(pool3, mask=mask) #8x8 512
+
+ # Decoder
+ up0 = self.upsample_0(conv4) #16x16 256
+ deconv0 = torch.cat([up0, conv3], -1) #16x16 512
+ deconv0 = self.decoderlayer_0(deconv0, mask=mask) #16x16 512
+
+ up1 = self.upsample_1(deconv0) #32x32 128
+ deconv1 = torch.cat([up1, conv2], -1) #32x32 256
+ deconv1 = self.decoderlayer_1(deconv1, mask=mask) #32x32 256
+
+ up2 = self.upsample_2(deconv1) #64x64 64
+ deconv2 = torch.cat([up2, conv1], -1) #64x64 128
+ deconv2 = self.decoderlayer_2(deconv2, mask=mask) #64x64 128
+
+ up3 = self.upsample_3(deconv2) #128x128 32
+ deconv3 = torch.cat([up3, conv0], -1) #128x128 64
+ deconv3 = self.decoderlayer_3(deconv3, mask=mask)
+
+ # Output Projection
+ y = self.output_proj(deconv3)
+ return x + y
diff --git a/losses.py b/losses.py
new file mode 100644
index 0000000..63725e0
--- /dev/null
+++ b/losses.py
@@ -0,0 +1,49 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def tv_loss(x, beta=0.5, reg_coeff=5):
+ '''Calculates TV loss for an image `x`.
+
+ Args:
+ x: image, torch.Variable of torch.Tensor
+ beta: See https://arxiv.org/abs/1412.0035 (fig. 2) to see effect of `beta`
+ '''
+ dh = torch.pow(x[:, :, :, 1:] - x[:, :, :, :-1], 2)
+ dw = torch.pow(x[:, :, 1:, :] - x[:, :, :-1, :], 2)
+ a, b, c, d = x.shape
+ return reg_coeff * (torch.sum(torch.pow(dh[:, :, :-1] + dw[:, :, :, :-1], beta)) / (a * b * c * d))
+
+
+class TVLoss(nn.Module):
+ def __init__(self, tv_loss_weight=1):
+ super(TVLoss, self).__init__()
+ self.tv_loss_weight = tv_loss_weight
+
+ def forward(self, x):
+ batch_size = x.size()[0]
+ h_x = x.size()[2]
+ w_x = x.size()[3]
+ count_h = self.tensor_size(x[:, :, 1:, :])
+ count_w = self.tensor_size(x[:, :, :, 1:])
+ h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
+ w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
+ return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size
+
+ @staticmethod
+ def tensor_size(t):
+ return t.size()[1] * t.size()[2] * t.size()[3]
+
+
+class CharbonnierLoss(nn.Module):
+ """Charbonnier Loss"""
+
+ def __init__(self, eps=1e-3):
+ super(CharbonnierLoss, self).__init__()
+ self.eps = eps
+
+ def forward(self, x, y):
+ diff = x - y
+ loss = torch.mean(torch.sqrt((diff * diff) + (self.eps * self.eps)))
+ return loss
diff --git a/options.py b/options.py
new file mode 100644
index 0000000..25725c3
--- /dev/null
+++ b/options.py
@@ -0,0 +1,52 @@
+import os
+import torch
+
+
+class Options():
+ """docstring for Options"""
+
+ def __init__(self):
+ pass
+
+ def init(self, parser):
+ parser.add_argument('--batch_size', type=int, default=8, help='batch size')
+ parser.add_argument('--nepoch', type=int, default=250, help='training epochs')
+ parser.add_argument('--train_workers', type=int, default=4, help='train_dataloader workers')
+ parser.add_argument('--eval_workers', type=int, default=1, help='eval_dataloader workers')
+ parser.add_argument('--optimizer', type=str, default='adam', help='optimizer for training')
+ parser.add_argument('--lr_initial', type=float, default=0.0002, help='initial learning rate')
+ parser.add_argument('--LR_MIN', type=float, default=1e-6)
+ parser.add_argument('--thre', type=int, default=50)
+ parser.add_argument('--weight_decay', type=float, default=0.0, help='weight decay')
+ parser.add_argument('--gpu', type=str, default='0,1', help='GPUs')
+ parser.add_argument('--arch', type=str, default='Stoformer', help='archtechture')
+
+ parser.add_argument('--save_dir', type=str, default='', help='save dir')
+ parser.add_argument('--save_images', action='store_true', default=False)
+ parser.add_argument('--env', type=str, default='_', help='env')
+ parser.add_argument('--checkpoint', type=int, default=50, help='checkpoint')
+
+ parser.add_argument('--norm_layer', type=str, default='nn.LayerNorm', help='normalize layer in transformer')
+ parser.add_argument('--embed_dim', type=int, default=32, help='dim of emdeding features')
+ parser.add_argument('--win_size', type=int, default=8, help='window size of self-attention')
+ parser.add_argument('--token_projection', type=str, default='linear', help='linear/conv token projection')
+ parser.add_argument('--token_mlp', type=str, default='leff', help='ffn/leff token mlp')
+ parser.add_argument('--att_se', action='store_true', default=False, help='se after sa')
+
+ parser.add_argument('--noiselevel', type=float, default=50)
+ parser.add_argument('--use_grad_clip', action='store_true', default=False)
+
+ # args for training
+ parser.add_argument('--train_ps', type=int, default=128, help='patch size of training sample')
+ parser.add_argument('--train_dir', type=str, default='', help='dir of train data')
+ parser.add_argument('--val_dir', type=str, default='', help='dir of train data')
+ parser.add_argument('--random_start', type=int, default=0, help='epochs for random shift')
+
+ # args for testing
+ parser.add_argument('--weights', type=str, default='', help='Path to trained weights')
+ parser.add_argument('--test_workers', type=int, default=1, help='number of test works')
+ parser.add_argument('--input_dir', type=str, default='', help='Directory of validation images')
+ parser.add_argument('--result_dir', type=str, default='', help='Directory for results')
+ parser.add_argument('--crop_size', type=int, default=256, help='crop size for testing')
+ parser.add_argument('--overlap_size', type=int, default=30, help='overlap size for testing')
+ return parser
diff --git a/stoformer.py b/stoformer.py
new file mode 100644
index 0000000..82b9b92
--- /dev/null
+++ b/stoformer.py
@@ -0,0 +1,747 @@
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+import torch.nn.functional as F
+from einops import rearrange, repeat
+import math
+import random
+import argparse
+import options
+
+
+class SELayer(nn.Module):
+ def __init__(self, channel, reduction=16):
+ super(SELayer, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool1d(1)
+ self.fc = nn.Sequential(
+ nn.Linear(channel, channel // reduction, bias=False),
+ nn.ReLU(inplace=True),
+ nn.Linear(channel // reduction, channel, bias=False),
+ nn.Sigmoid()
+ )
+
+ def forward(self, x): # x: [B, N, C]
+ x = torch.transpose(x, 1, 2) # [B, C, N]
+ b, c, _ = x.size()
+ y = self.avg_pool(x).view(b, c)
+ y = self.fc(y).view(b, c, 1)
+ x = x * y.expand_as(x)
+ x = torch.transpose(x, 1, 2) # [B, N, C]
+ return x
+
+
+######## Embedding for q,k,v ########
+
+class LinearProjection(nn.Module):
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0., bias=True, Train=True):
+ super(LinearProjection, self).__init__()
+ inner_dim = dim_head * heads
+ self.heads = heads
+ self.train=Train
+ self.to_q = nn.Linear(dim, inner_dim, bias=bias)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=bias)
+ self.dim = dim
+ self.inner_dim = inner_dim
+
+ def forward(self, x, attn_kv=None):
+ B_, N, C = x.shape
+
+ attn_kv = x if attn_kv is None else attn_kv
+ q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
+ kv = self.to_kv(attn_kv).reshape(B_, N, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
+ q = q[0]
+ k, v = kv[0], kv[1]
+ return q, k, v
+
+
+########### feed-forward network #############
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super(Mlp, self).__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+ self.in_features = in_features
+ self.hidden_features = hidden_features
+ self.out_features = out_features
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+class LeFF(nn.Module):
+ def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU, drop=0.):
+ super(LeFF, self).__init__()
+ self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim),
+ act_layer())
+ self.dwconv = nn.Sequential(
+ nn.Conv2d(hidden_dim, hidden_dim, groups=hidden_dim, kernel_size=3, stride=1, padding=1),
+ act_layer())
+ self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim))
+ self.dim = dim
+ self.hidden_dim = hidden_dim
+
+ def forward(self, x):
+ # bs x hw x c
+ bs, hw, c = x.size()
+ hh = round(math.sqrt(hw))
+ ww = round(math.sqrt(hw))
+
+ x = self.linear1(x)
+
+ # spatial restore
+ x = rearrange(x, ' b (h w) (c) -> b c h w ', h=hh, w=ww)
+ # bs,hidden_dim,32x32
+
+ x = self.dwconv(x)
+
+ # flaten
+ x = rearrange(x, ' b c h w -> b (h w) c', h=hh, w=ww)
+
+ x = self.linear2(x)
+
+ return x
+
+
+########### window operation#############
+def window_partition(x, win_size, dilation_rate=1):
+ B, H, W, C = x.shape
+ if dilation_rate != 1:
+ x = x.permute(0, 3, 1, 2) # B, C, H, W
+ assert type(dilation_rate) is int, 'dilation_rate should be a int'
+ x = F.unfold(x, kernel_size=win_size, dilation=dilation_rate, padding=4 * (dilation_rate - 1),
+ stride=win_size) # B, C*Wh*Ww, H/Wh*W/Ww
+ windows = x.permute(0, 2, 1).contiguous().view(-1, C, win_size, win_size) # B' ,C ,Wh ,Ww
+ windows = windows.permute(0, 2, 3, 1).contiguous() # B' ,Wh ,Ww ,C
+ else:
+ x = x.view(B, H // win_size, win_size, W // win_size, win_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C) # B' ,Wh ,Ww ,C
+ return windows
+
+
+def window_reverse(windows, win_size, H, W, dilation_rate=1):
+ # B' ,Wh ,Ww ,C
+ B = int(windows.shape[0] / (H * W / win_size / win_size))
+ x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1)
+ if dilation_rate != 1:
+ x = windows.permute(0, 5, 3, 4, 1, 2).contiguous() # B, C*Wh*Ww, H/Wh*W/Ww
+ x = F.fold(x, (H, W), kernel_size=win_size, dilation=dilation_rate, padding=4 * (dilation_rate - 1),
+ stride=win_size)
+ else:
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+# Downsample Block
+class Downsample(nn.Module):
+ def __init__(self, in_channel, out_channel):
+ super(Downsample, self).__init__()
+ self.conv = nn.Sequential(
+ nn.Conv2d(in_channel, out_channel, kernel_size=4, stride=2, padding=1),
+ )
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+
+ def forward(self, x):
+ B, L, C = x.shape
+ H = round(math.sqrt(L))
+ W = round(math.sqrt(L))
+ x = x.transpose(1, 2).contiguous().view(B, C, H, W)
+ out = self.conv(x).flatten(2).transpose(1, 2).contiguous() # B H*W C
+ return out
+
+
+# Upsample Block
+class Upsample(nn.Module):
+ def __init__(self, in_channel, out_channel):
+ super(Upsample, self).__init__()
+ self.deconv = nn.Sequential(
+ nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2),
+ )
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+
+ def forward(self, x):
+ B, L, C = x.shape
+ H = round(math.sqrt(L))
+ W = round(math.sqrt(L))
+ x = x.transpose(1, 2).contiguous().view(B, C, H, W)
+ out = self.deconv(x).flatten(2).transpose(1, 2).contiguous() # B H*W C
+ return out
+
+
+# Input Projection
+class InputProj(nn.Module):
+ def __init__(self, in_channel=3, out_channel=64, kernel_size=3, stride=1, norm_layer=None, act_layer=nn.LeakyReLU):
+ super(InputProj, self).__init__()
+ self.proj = nn.Sequential(
+ nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size // 2),
+ act_layer(inplace=True)
+ )
+ if norm_layer is not None:
+ self.norm = norm_layer(out_channel)
+ else:
+ self.norm = None
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+
+ def forward(self, x):
+
+ x = self.proj(x).flatten(2).transpose(1, 2).contiguous() # B H*W C
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+
+# Output Projection
+class OutputProj(nn.Module):
+ def __init__(self, in_channel=64, out_channel=3, kernel_size=3, stride=1, norm_layer=None, act_layer=None):
+ super(OutputProj, self).__init__()
+ self.proj = nn.Sequential(
+ nn.Conv2d(in_channel, 3, kernel_size=3, stride=1, padding=1)
+ )
+ if act_layer is not None:
+ self.proj.add_module(act_layer(inplace=True))
+ if norm_layer is not None:
+ self.norm = norm_layer(out_channel)
+ else:
+ self.norm = None
+ self.in_channel = in_channel
+ self.out_channel = out_channel
+
+ def forward(self, x):
+ B, L, C = x.shape
+ H = round(math.sqrt(L))
+ W = round(math.sqrt(L))
+ x = x.transpose(1, 2).view(B, C, H, W)
+ x = self.proj(x)
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+
+########### StoTransformer #############
+class StoTransformerBlock(nn.Module):
+ def __init__(self, dim, input_resolution, num_heads, win_size=8,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., proj_drop=0.,drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, stride=1, token_mlp='leff',
+ se_layer=False):
+ super(StoTransformerBlock, self).__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.win_size = win_size
+ self.stride=stride
+ self.mlp_ratio = mlp_ratio
+ self.token_mlp = token_mlp
+ self.head_dim = dim // num_heads
+ self.scale = self.head_dim ** -0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * win_size - 1) * (2 * win_size - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.win_size) # [0,...,Wh-1]
+ coords_w = torch.arange(self.win_size) # [0,...,Ww-1]
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.win_size- 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.win_size - 1
+ relative_coords[:, :, 0] *= 2 * self.win_size - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.attn_drop = nn.Dropout(attn_drop)
+
+ self.norm1 = norm_layer(dim)
+
+ self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
+ self.to_kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
+
+ self.softmax = nn.Softmax(dim=-1)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ self.proj = nn.Linear(dim, dim)
+ self.se_layer = SELayer(dim) if se_layer else nn.Identity()
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer,
+ drop=drop) if token_mlp == 'ffn' else LeFF(dim, mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+
+ def attention(self, q, k, v, attn_mask=None):
+ B_, h, N_, C_ = q.shape
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.win_size * self.win_size, self.win_size * self.win_size, -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ ratio = attn.size(-1) // relative_position_bias.size(-1)
+ relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d=ratio)
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if attn_mask is not None:
+ nW = attn_mask.shape[0] # [nW, N_, N_]
+ mask = repeat(attn_mask, 'nW m n -> nW m (n d)', d=1) # [nW, N_, N_]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N_, N_ * 1) + mask.unsqueeze(1).unsqueeze(
+ 0) # [1, nW, 1, N_, N_]
+ # [B, nW, nh, N_, N_]
+ attn = attn.view(-1, self.num_heads, N_, N_ * 1)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ y = (attn @ v).transpose(1, 2).reshape(B_, N_, h*C_)
+ y = self.proj(y)
+ return y
+
+ def forward(self, x, mask=None):
+ B, L, C = x.shape
+
+ H = round(math.sqrt(L))
+ W = round(math.sqrt(L))
+
+ shortcut = x
+ x = self.norm1(x)
+ q = self.to_q(x) #[B, L, C]
+ kv = self.to_kv(x)
+
+ q = rearrange(q, 'b (h w) c -> b h w c', h=H)
+ kv = rearrange(kv, 'b (h w) c -> b h w c', h=H)
+
+ x = x.view(B, H, W, C)
+
+ if self.training:
+ if mask != None:
+ input_mask = F.interpolate(mask, size=(H, W)).permute(0, 2, 3, 1)
+ input_mask_windows = window_partition(input_mask, self.win_size) # nW, win_size, win_size, 1
+ attn_mask = input_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size
+ attn_mask = attn_mask.unsqueeze(2) * attn_mask.unsqueeze(1) # nW, win_size*win_size, win_size*win_size
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ else:
+ attn_mask = None
+
+ ## Stochastic shift window
+ H_offset = random.randint(0, self.win_size - 1)
+ W_offset = random.randint(0, self.win_size - 1)
+
+ shift_mask = torch.zeros((1, H, W, 1)).type_as(x)
+
+ if H_offset > 0:
+ h_slices = (slice(0, -self.win_size),
+ slice(-self.win_size, -H_offset),
+ slice(-H_offset, None))
+ else:
+ h_slices = (slice(0, None),)
+ if W_offset > 0:
+ w_slices = (slice(0, -self.win_size),
+ slice(-self.win_size, -W_offset),
+ slice(-W_offset, None))
+ else:
+ w_slices = (slice(0, None),)
+
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ shift_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ shift_mask_windows = window_partition(shift_mask, self.win_size) # nW, win_size, win_size, 1
+ shift_mask_windows = shift_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size
+ shift_attn_mask = shift_mask_windows.unsqueeze(1) - shift_mask_windows.unsqueeze(
+ 2) # nW, win_size*win_size, win_size*win_size
+ shift_attn_mask = shift_attn_mask.masked_fill(shift_attn_mask != 0, float(-100.0)).masked_fill(
+ shift_attn_mask == 0, float(0.0))
+ attn_mask = attn_mask + shift_attn_mask if attn_mask is not None else shift_attn_mask #[nW, N_,N_]
+
+ # cyclic shift
+ shifted_q = torch.roll(q, shifts=(-H_offset, -W_offset), dims=(1, 2))
+ shifted_kv = torch.roll(kv, shifts=(-H_offset, -W_offset), dims=(1, 2))
+
+ # partition windows
+ q_windows = window_partition(shifted_q, self.win_size) # nW*B, win_size, win_size, C N*C->C
+ q_windows = q_windows.view(-1, self.win_size * self.win_size, C) # nW*B, win_size*win_size, C
+ B_, N_, C_ = q_windows.shape
+ q_windows = q_windows.reshape(B_, N_, self.num_heads, C_ // self.num_heads).permute(0, 2, 1, 3)
+
+ kv_windows = window_partition(shifted_kv, self.win_size) # nW*B, win_size, win_size, 2C
+ kv_windows = kv_windows.view(-1, self.win_size * self.win_size, 2 * C)
+ kv_windows = kv_windows.reshape(B_, N_, 2, self.num_heads, C_ // self.num_heads).permute(2, 0, 3, 1, 4)
+ k_windows, v_windows = kv_windows[0], kv_windows[1]
+
+ attn_windows = self.attention(q_windows, k_windows, v_windows, attn_mask)
+
+ attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C)
+ x = window_reverse(attn_windows, self.win_size, H, W) # B H' W' C
+
+ x = torch.roll(x, shifts=(H_offset, W_offset), dims=(1, 2))
+
+ x = x.view(B, H * W, C)
+ del attn_mask
+
+ else:
+ avg = torch.zeros((B, H*W, C)).cuda()
+ NUM = 0
+ for H_offset in range(self.win_size):
+ for W_offset in range(self.win_size):
+ if mask != None:
+ input_mask = F.interpolate(mask, size=(H, W)).permute(0, 2, 3, 1)
+ input_mask_windows = window_partition(input_mask, self.win_size) # nW, win_size, win_size, 1
+ attn_mask = input_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size
+ attn_mask = attn_mask.unsqueeze(2) * attn_mask.unsqueeze(
+ 1) # nW, win_size*win_size, win_size*win_size
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0,
+ float(0.0))
+ else:
+ attn_mask = None
+ shift_mask = torch.zeros((1, H, W, 1)).type_as(x)
+
+ if H_offset > 0:
+ h_slices = (slice(0, -self.win_size),
+ slice(-self.win_size, -H_offset),
+ slice(-H_offset, None))
+ else:
+ h_slices = (slice(0, None),)
+
+ if W_offset > 0:
+ w_slices = (slice(0, -self.win_size),
+ slice(-self.win_size, -W_offset),
+ slice(-W_offset, None))
+ else:
+ w_slices = (slice(0, None),)
+
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ shift_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ shift_mask_windows = window_partition(shift_mask, self.win_size) # nW, win_size, win_size, 1
+ shift_mask_windows = shift_mask_windows.view(-1,
+ self.win_size * self.win_size) # nW, win_size*win_size
+ shift_attn_mask = shift_mask_windows.unsqueeze(1) - shift_mask_windows.unsqueeze(
+ 2) # nW, win_size*win_size, win_size*win_size
+ shift_attn_mask = shift_attn_mask.masked_fill(shift_attn_mask != 0, float(-100.0)).masked_fill(
+ shift_attn_mask == 0, float(0.0))
+ attn_mask = attn_mask + shift_attn_mask if attn_mask is not None else shift_attn_mask # [nW, N_,N_]
+
+ shifted_q = torch.roll(q, shifts=(-H_offset, -W_offset), dims=(1, 2))
+ shifted_kv = torch.roll(kv, shifts=(-H_offset, -W_offset), dims=(1, 2))
+
+ # partition windows
+ q_windows = window_partition(shifted_q, self.win_size) # nW*B, win_size, win_size, C N*C->C
+ q_windows = q_windows.view(-1, self.win_size * self.win_size, C) # nW*B, win_size*win_size, C
+ B_, N_, C_ = q_windows.shape
+ q_windows = q_windows.reshape(B_, N_, self.num_heads, C_ // self.num_heads).permute(0, 2, 1, 3)
+
+ kv_windows = window_partition(shifted_kv, self.win_size) # nW*B, win_size, win_size, 2C
+ kv_windows = kv_windows.view(-1, self.win_size * self.win_size, 2*C)
+ kv_windows = kv_windows.reshape(B_, N_, 2, self.num_heads, C_ // self.num_heads).permute(2, 0, 3, 1, 4)
+ k_windows, v_windows = kv_windows[0], kv_windows[1]
+
+ attn_windows = self.attention(q_windows, k_windows, v_windows, attn_mask)
+
+ attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C)
+ shifted_x = window_reverse(attn_windows, self.win_size, H, W) # B H' W' C
+ # reverse cyclic shift
+ y = torch.roll(shifted_x, shifts=(H_offset, W_offset), dims=(1, 2))
+
+ y = y.view(B, H * W, C)
+ avg = NUM/(NUM+1)*avg + y/(NUM+1)
+ NUM += 1
+ del attn_mask
+ x = avg
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+ return x
+
+
+########### Basic layer of Stoformer ################
+class BasicStoformerLayer(nn.Module):
+ def __init__(self, dim, output_dim, input_resolution, depth, num_heads, win_size,
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False,
+ token_mlp='leff', se_layer=False):
+
+ super(BasicStoformerLayer, self).__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+ # build blocks
+ self.blocks = nn.ModuleList([
+ StoTransformerBlock(dim=dim, input_resolution=input_resolution,
+ num_heads=num_heads, win_size=win_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop, attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer, token_mlp=token_mlp,
+ se_layer=se_layer)
+ for i in range(depth)])
+
+
+ def forward(self, x, mask=None):
+ for blk in self.blocks:
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x)
+ else:
+ x = blk(x, mask)
+ return x
+
+
+class Stoformer(nn.Module):
+ def __init__(self, img_size=128, in_chans=3,
+ embed_dim=32, depths=[1, 2, 8, 8, 2, 8, 8, 2, 1], num_heads=[1, 2, 4, 8, 16, 16, 8, 4, 2],
+ win_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
+ drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm, patch_norm=True,
+ use_checkpoint=False, token_mlp='leff', se_layer=False,
+ dowsample=Downsample, upsample=Upsample, **kwargs):
+ super(Stoformer, self).__init__()
+
+ self.num_enc_layers = len(depths) // 2
+ self.num_dec_layers = len(depths) // 2
+ self.embed_dim = embed_dim
+ self.patch_norm = patch_norm
+ self.mlp_ratio = mlp_ratio
+ self.mlp = token_mlp
+ self.win_size = win_size
+ self.reso = img_size
+ self.pos_drop = nn.Dropout(p=drop_rate)
+ # stochastic depth
+ enc_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[:self.num_enc_layers]))]
+ conv_dpr = [drop_path_rate] * depths[4]
+ dec_dpr = enc_dpr[::-1]
+ # build layers
+
+ # Input/Output
+ self.input_proj = InputProj(in_channel=in_chans, out_channel=embed_dim, kernel_size=3, stride=1,
+ act_layer=nn.LeakyReLU)
+ self.output_proj = OutputProj(in_channel=2 * embed_dim, out_channel=in_chans, kernel_size=3, stride=1)
+
+ # Encoder
+ self.encoderlayer_0 = BasicStoformerLayer(dim=embed_dim,
+ output_dim=embed_dim,
+ input_resolution=(img_size,
+ img_size),
+ depth=depths[0],
+ num_heads=num_heads[0],
+ win_size=win_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=enc_dpr[sum(depths[:0]):sum(depths[:1])],
+ norm_layer=norm_layer,
+ use_checkpoint=use_checkpoint,
+ token_mlp=token_mlp,
+ se_layer=se_layer)
+ self.dowsample_0 = dowsample(embed_dim, embed_dim * 2)
+
+ self.encoderlayer_1 = BasicStoformerLayer(dim=embed_dim * 2,
+ output_dim=embed_dim * 2,
+ input_resolution=(img_size // 2,
+ img_size // 2),
+ depth=depths[1],
+ num_heads=num_heads[1],
+ win_size=win_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])],
+ norm_layer=norm_layer,
+ use_checkpoint=use_checkpoint,
+ token_mlp=token_mlp,
+ se_layer=se_layer)
+ self.dowsample_1 = dowsample(embed_dim * 2, embed_dim * 4)
+
+ self.encoderlayer_2 = BasicStoformerLayer(dim=embed_dim * 4,
+ output_dim=embed_dim * 4,
+ input_resolution=(img_size // (2 ** 2),
+ img_size // (2 ** 2)),
+ depth=depths[2],
+ num_heads=num_heads[2],
+ win_size=win_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])],
+ norm_layer=norm_layer,
+ use_checkpoint=use_checkpoint,
+ token_mlp=token_mlp,
+ se_layer=se_layer)
+ self.dowsample_2 = dowsample(embed_dim * 4, embed_dim * 8)
+
+ self.encoderlayer_3 = BasicStoformerLayer(dim=embed_dim * 8,
+ output_dim=embed_dim * 8,
+ input_resolution=(img_size // (2 ** 3),
+ img_size // (2 ** 3)),
+ depth=depths[3],
+ num_heads=num_heads[3],
+ win_size=win_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=enc_dpr[sum(depths[:3]):sum(depths[:4])],
+ norm_layer=norm_layer,
+ use_checkpoint=use_checkpoint,
+ token_mlp=token_mlp,
+ se_layer=se_layer)
+ self.dowsample_3 = dowsample(embed_dim * 8, embed_dim * 16)
+ # Bottleneck
+ self.conv = BasicStoformerLayer(dim=embed_dim * 16,
+ output_dim=embed_dim * 16,
+ input_resolution=(img_size // (2 ** 4),
+ img_size // (2 ** 4)),
+ depth=depths[4],
+ num_heads=num_heads[4],
+ win_size=win_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=conv_dpr,
+ norm_layer=norm_layer,
+ use_checkpoint=use_checkpoint,
+ token_mlp=token_mlp, se_layer=se_layer)
+ # Decoder
+ self.upsample_0 = upsample(embed_dim * 16, embed_dim * 8)
+ self.decoderlayer_0 = BasicStoformerLayer(dim=embed_dim * 16,
+ output_dim=embed_dim * 16,
+ input_resolution=(img_size // (2 ** 3),
+ img_size // (2 ** 3)),
+ depth=depths[5],
+ num_heads=num_heads[5],
+ win_size=win_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dec_dpr[:depths[5]],
+ norm_layer=norm_layer,
+ use_checkpoint=use_checkpoint,
+ token_mlp=token_mlp,
+ se_layer=se_layer)
+
+ self.upsample_1 = upsample(embed_dim * 16, embed_dim * 4)
+ self.decoderlayer_1 = BasicStoformerLayer(dim=embed_dim * 8,
+ output_dim=embed_dim * 8,
+ input_resolution=(img_size // (2 ** 2),
+ img_size // (2 ** 2)),
+ depth=depths[6],
+ num_heads=num_heads[6],
+ win_size=win_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dec_dpr[sum(depths[5:6]):sum(depths[5:7])],
+ norm_layer=norm_layer,
+ use_checkpoint=use_checkpoint,
+ token_mlp=token_mlp,
+ se_layer=se_layer)
+
+ self.upsample_2 = upsample(embed_dim * 8, embed_dim * 2)
+ self.decoderlayer_2 = BasicStoformerLayer(dim=embed_dim * 4,
+ output_dim=embed_dim * 4,
+ input_resolution=(img_size // 2,
+ img_size // 2),
+ depth=depths[7],
+ num_heads=num_heads[7],
+ win_size=win_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dec_dpr[sum(depths[5:7]):sum(depths[5:8])],
+ norm_layer=norm_layer,
+ use_checkpoint=use_checkpoint,
+ token_mlp=token_mlp,
+ se_layer=se_layer)
+
+ self.upsample_3 = upsample(embed_dim * 4, embed_dim)
+ self.decoderlayer_3 = BasicStoformerLayer(dim=embed_dim * 2,
+ output_dim=embed_dim * 2,
+ input_resolution=(img_size,
+ img_size),
+ depth=depths[8],
+ num_heads=num_heads[8],
+ win_size=win_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate, attn_drop=attn_drop_rate,
+ drop_path=dec_dpr[sum(depths[5:8]):sum(depths[5:9])],
+ norm_layer=norm_layer,
+ use_checkpoint=use_checkpoint,
+ token_mlp=token_mlp,
+ se_layer=se_layer)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {'relative_position_bias_table'}
+
+ def forward(self, x, mask=None):
+ # Input Projection
+ y = self.input_proj(x)
+ y = self.pos_drop(y)
+ # Encoder
+ conv0 = self.encoderlayer_0(y, mask=mask) #128x128 32
+ pool0 = self.dowsample_0(conv0)
+ conv1 = self.encoderlayer_1(pool0, mask=mask) #64x64 64
+ pool1 = self.dowsample_1(conv1)
+ conv2 = self.encoderlayer_2(pool1, mask=mask) #32x32 128
+ pool2 = self.dowsample_2(conv2)
+ conv3 = self.encoderlayer_3(pool2, mask=mask) #16x16 256
+ pool3 = self.dowsample_3(conv3)
+
+ # Bottleneck
+ conv4 = self.conv(pool3, mask=mask) #8x8 512
+
+ # Decoder
+ up0 = self.upsample_0(conv4) #16x16 256
+ deconv0 = torch.cat([up0, conv3], -1) #16x16 512
+ deconv0 = self.decoderlayer_0(deconv0, mask=mask) #16x16 512
+
+ up1 = self.upsample_1(deconv0) #32x32 128
+ deconv1 = torch.cat([up1, conv2], -1) #32x32 256
+ deconv1 = self.decoderlayer_1(deconv1, mask=mask) #32x32 256
+
+ up2 = self.upsample_2(deconv1) #64x64 64
+ deconv2 = torch.cat([up2, conv1], -1) #64x64 128
+ deconv2 = self.decoderlayer_2(deconv2, mask=mask) #64x64 128
+
+ up3 = self.upsample_3(deconv2) #128x128 32
+ deconv3 = torch.cat([up3, conv0], -1) #128x128 64
+ deconv3 = self.decoderlayer_3(deconv3, mask=mask)
+
+ # Output Projection
+ y = self.output_proj(deconv3)
+ return x + y
diff --git a/test_Deblur.py b/test_Deblur.py
new file mode 100644
index 0000000..d2d8ee7
--- /dev/null
+++ b/test_Deblur.py
@@ -0,0 +1,44 @@
+import numpy as np
+import os, sys, math
+import argparse
+from tqdm import tqdm
+import torch
+from torch.utils.data import DataLoader
+from utils.loader import get_test_data
+from utils.image_utils import splitimage, mergeimage
+import utils
+import options
+
+args = options.Options().init(argparse.ArgumentParser(description='image debluring')).parse_args()
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
+
+
+if __name__ == '__main__':
+ utils.mkdir(args.result_dir)
+ model_restoration= utils.get_arch(args)
+ model_restoration = torch.nn.DataParallel(model_restoration)
+ utils.load_checkpoint(model_restoration, args.weights)
+ print("===>Testing using weights: ", args.weights)
+ model_restoration.cuda()
+ model_restoration.eval()
+ inp_dir = args.input_dir
+ test_dataset = get_test_data(inp_dir)
+ test_loader = DataLoader(dataset=test_dataset, batch_size=args.batch_size, shuffle=False,
+ pin_memory=True, drop_last=False, num_workers=args.test_workers)
+ result_dir = args.result_dir
+ os.makedirs(result_dir, exist_ok=True)
+
+ with torch.no_grad():
+ for input_, file_ in tqdm(test_loader):
+ input_ = input_.cuda()
+ B, C, H, W = input_.shape
+ split_data, starts = splitimage(input_, crop_size=args.crop_size, overlap_size=args.overlap_size)
+ for i, data in enumerate(split_data):
+ split_data[i] = model_restoration(data).cpu()
+ restored = mergeimage(split_data, starts, crop_size = args.crop_size, resolution=(B, C, H, W))
+ restored = torch.clamp(restored, 0, 1).permute(0, 2, 3, 1).numpy()
+ for j in range(B):
+ restored_ = restored[j]
+ save_file = os.path.join(result_dir, file_[j])
+ utils.save_img(save_file, np.uint8(np.around(restored_*255)))
diff --git a/train_Deblur.py b/train_Deblur.py
new file mode 100644
index 0000000..63b51c3
--- /dev/null
+++ b/train_Deblur.py
@@ -0,0 +1,190 @@
+import os
+import sys
+
+# add dir
+dir_name = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(os.path.join(dir_name, './auxiliary/'))
+print(dir_name)
+
+import argparse
+import options
+
+######### parser ###########
+opt = options.Options().init(argparse.ArgumentParser(description='image debluring')).parse_args()
+print(opt)
+
+import utils
+
+######### Set GPUs ###########
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
+import torch
+
+torch.backends.cudnn.benchmark = True
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import random
+import time
+import numpy as np
+from einops import rearrange, repeat
+import datetime
+
+from losses import CharbonnierLoss
+
+from warmup_scheduler import CosineAnnealingWithRestartsLR, CosineAnnealingRestartCyclicLR
+from timm.utils import NativeScaler
+
+from utils.loader import get_deblur_training_data
+
+def print_network(net):
+ num_params = 0
+ for param in net.parameters():
+ num_params += param.numel()
+ print('Total number of parameters: %d' % num_params)
+
+######### Logs dir ###########
+log_dir = os.path.join(dir_name, 'log', opt.arch + opt.env)
+if not os.path.exists(log_dir):
+ os.makedirs(log_dir)
+logname = os.path.join(log_dir, datetime.datetime.now().isoformat() + '.txt')
+print("Now time is : ", datetime.datetime.now().isoformat())
+result_dir = os.path.join(log_dir, 'results')
+model_dir = opt.save_dir
+utils.mkdir(result_dir)
+utils.mkdir(model_dir)
+
+# ######### Set Seeds ###########
+random.seed(1234)
+np.random.seed(1234)
+torch.manual_seed(1234)
+torch.cuda.manual_seed_all(1234)
+
+######### Model ###########
+model_restoration = utils.get_arch(opt)
+print_network(model_restoration)
+
+with open(logname, 'a') as f:
+ f.write(str(opt) + '\n')
+ f.write(str(model_restoration) + '\n')
+
+######### Optimizer ###########
+start_epoch = 1
+if opt.optimizer.lower() == 'adam':
+ optimizer = optim.Adam(model_restoration.parameters(), lr=opt.lr_initial, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=opt.weight_decay)
+elif opt.optimizer.lower() == 'adamw':
+ optimizer = optim.AdamW(model_restoration.parameters(), lr=opt.lr_initial, betas=(0.9, 0.999), eps=1e-8,
+ weight_decay=opt.weight_decay)
+else:
+ raise Exception("Error optimizer...")
+
+scheduler = CosineAnnealingRestartCyclicLR(optimizer, periods=[opt.thre, opt.nepoch-opt.thre], restart_weights=[1, 1] ,eta_mins=[opt.lr_initial, opt.LR_MIN])
+
+######### DataParallel ###########
+model_restoration = torch.nn.DataParallel(model_restoration)
+model_restoration.cuda()
+
+######### Loss ###########
+criterion = CharbonnierLoss().cuda()
+
+######### DataLoader ###########
+print('===> Loading datasets')
+train_dataset = get_deblur_training_data(opt.train_dir, opt.train_ps)
+
+train_loader = DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True,
+ num_workers=opt.train_workers, pin_memory=True, drop_last=False)
+val_dataset = get_deblur_training_data(opt.val_dir, opt.train_ps)
+val_loader = DataLoader(dataset=val_dataset, batch_size=opt.batch_size, shuffle=False,
+ num_workers=opt.eval_workers, pin_memory=True, drop_last=False)
+len_trainset = train_dataset.__len__()
+len_valset = val_dataset.__len__()
+
+print("Sizeof training set: ", len_trainset, ", sizeof validation set: ", len_valset)
+
+######### train ###########
+print('===> Start Epoch {} End Epoch {}'.format(start_epoch, opt.nepoch))
+best_psnr = 0
+best_epoch = 0
+best_iter = 0
+
+loss_scaler = NativeScaler()
+torch.cuda.empty_cache()
+global_step = 0
+
+eval_now = len(train_loader)
+print("Eval Freq: ", eval_now)
+for epoch in range(start_epoch, opt.nepoch + 1):
+ epoch_start_time = time.time()
+ epoch_loss = 0
+
+ model_restoration.train()
+ for i, data in enumerate(train_loader, 0):
+ global_step += 1
+ model_restoration.zero_grad()
+ optimizer.zero_grad()
+
+ target = data[0].cuda()
+ input_ = data[1].cuda()
+
+ restored = model_restoration(input_)
+ restored = torch.clamp(restored, 0, 1)
+ loss = criterion(restored, target)
+
+ loss.backward()
+ if opt.use_grad_clip:
+ torch.nn.utils.clip_grad_norm_(model_restoration.parameters(), 0.01)
+ optimizer.step()
+ epoch_loss += loss.item()
+
+ if global_step % 50 == 0:
+ print("Epoch: %d, LearningRate: %.6f, global step: %d, loss: %.4f, time:%.4f" %(epoch, scheduler.get_lr()[0],
+ global_step, epoch_loss, time.time() - epoch_start_time))
+ epoch_loss = 0.0
+ epoch_start_time = time.time()
+
+ #### Evaluation ####
+ if (i + 1) % eval_now == 0 and i > 0:
+ with torch.no_grad():
+ model_restoration.eval()
+ psnr_val_rgb = []
+ for ii, data_val in enumerate((val_loader), 0):
+ target = data_val[0].cuda()
+ input_ = data_val[1].cuda()
+ with torch.cuda.amp.autocast():
+ restored = model_restoration(input_)
+ restored = torch.clamp(restored, 0, 1)
+ psnr_val_rgb.append(utils.batch_PSNR(restored, target, False).item())
+ psnr_val_rgb = sum(psnr_val_rgb) / len_valset
+
+ if psnr_val_rgb > best_psnr:
+ best_psnr = psnr_val_rgb
+ best_epoch = epoch
+ best_iter = i
+ torch.save({'epoch': epoch,
+ 'state_dict': model_restoration.state_dict(),
+ 'optimizer': optimizer.state_dict()
+ }, os.path.join(model_dir, "model_best.pth"))
+
+ print(
+ "[Ep %d it %d\t PSNR SIDD: %.4f\t] ---- [best_Ep_SIDD %d best_it_SIDD %d Best_PSNR_SIDD %.4f] " % (
+ epoch, i, psnr_val_rgb, best_epoch, best_iter, best_psnr))
+ with open(logname, 'a') as f:
+ f.write(
+ "[Ep %d it %d\t PSNR SIDD: %.4f\t] ---- [best_Ep_SIDD %d best_it_SIDD %d Best_PSNR_SIDD %.4f] " \
+ % (epoch, i, psnr_val_rgb, best_epoch, best_iter, best_psnr) + '\n')
+ model_restoration.train()
+ torch.cuda.empty_cache()
+ scheduler.step()
+
+ torch.save({'epoch': epoch,
+ 'state_dict': model_restoration.state_dict(),
+ 'optimizer': optimizer.state_dict()
+ }, os.path.join(model_dir, "model_latest.pth"))
+
+ if epoch % opt.checkpoint == 0:
+ torch.save({'epoch': epoch,
+ 'state_dict': model_restoration.state_dict(),
+ 'optimizer': optimizer.state_dict()
+ }, os.path.join(model_dir, "model_epoch_{}.pth".format(epoch)))
+print("Now time is : ", datetime.datetime.now().isoformat())
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000..2452954
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,4 @@
+from .dir_utils import *
+from .dataset_utils import *
+from .image_utils import *
+from .model_utils import *
diff --git a/utils/antialias.py b/utils/antialias.py
new file mode 100644
index 0000000..4a45a6c
--- /dev/null
+++ b/utils/antialias.py
@@ -0,0 +1,125 @@
+# Copyright (c) 2019, Adobe Inc. All rights reserved.
+#
+# This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
+# 4.0 International Public License. To view a copy of this license, visit
+# https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.
+
+
+
+######## https://github.com/adobe/antialiased-cnns/blob/master/models_lpf/__init__.py
+
+
+
+import torch
+import torch.nn.parallel
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+
+class Downsample(nn.Module):
+ def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
+ super(Downsample, self).__init__()
+ self.filt_size = filt_size
+ self.pad_off = pad_off
+ self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))]
+ self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes]
+ self.stride = stride
+ self.off = int((self.stride-1)/2.)
+ self.channels = channels
+
+ # print('Filter size [%i]'%filt_size)
+ if(self.filt_size==1):
+ a = np.array([1.,])
+ elif(self.filt_size==2):
+ a = np.array([1., 1.])
+ elif(self.filt_size==3):
+ a = np.array([1., 2., 1.])
+ elif(self.filt_size==4):
+ a = np.array([1., 3., 3., 1.])
+ elif(self.filt_size==5):
+ a = np.array([1., 4., 6., 4., 1.])
+ elif(self.filt_size==6):
+ a = np.array([1., 5., 10., 10., 5., 1.])
+ elif(self.filt_size==7):
+ a = np.array([1., 6., 15., 20., 15., 6., 1.])
+
+ filt = torch.Tensor(a[:,None]*a[None,:])
+ filt = filt/torch.sum(filt)
+ self.register_buffer('filt', filt[None,None,:,:].repeat((self.channels,1,1,1)))
+
+ self.pad = get_pad_layer(pad_type)(self.pad_sizes)
+
+ def forward(self, inp):
+ if(self.filt_size==1):
+ if(self.pad_off==0):
+ return inp[:,:,::self.stride,::self.stride]
+ else:
+ return self.pad(inp)[:,:,::self.stride,::self.stride]
+ else:
+ return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
+
+def get_pad_layer(pad_type):
+ if(pad_type in ['refl','reflect']):
+ PadLayer = nn.ReflectionPad2d
+ elif(pad_type in ['repl','replicate']):
+ PadLayer = nn.ReplicationPad2d
+ elif(pad_type=='zero'):
+ PadLayer = nn.ZeroPad2d
+ else:
+ print('Pad type [%s] not recognized'%pad_type)
+ return PadLayer
+
+
+class Downsample1D(nn.Module):
+ def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
+ super(Downsample1D, self).__init__()
+ self.filt_size = filt_size
+ self.pad_off = pad_off
+ self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
+ self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
+ self.stride = stride
+ self.off = int((self.stride - 1) / 2.)
+ self.channels = channels
+
+ # print('Filter size [%i]' % filt_size)
+ if(self.filt_size == 1):
+ a = np.array([1., ])
+ elif(self.filt_size == 2):
+ a = np.array([1., 1.])
+ elif(self.filt_size == 3):
+ a = np.array([1., 2., 1.])
+ elif(self.filt_size == 4):
+ a = np.array([1., 3., 3., 1.])
+ elif(self.filt_size == 5):
+ a = np.array([1., 4., 6., 4., 1.])
+ elif(self.filt_size == 6):
+ a = np.array([1., 5., 10., 10., 5., 1.])
+ elif(self.filt_size == 7):
+ a = np.array([1., 6., 15., 20., 15., 6., 1.])
+
+ filt = torch.Tensor(a)
+ filt = filt / torch.sum(filt)
+ self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1)))
+
+ self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes)
+
+ def forward(self, inp):
+ if(self.filt_size == 1):
+ if(self.pad_off == 0):
+ return inp[:, :, ::self.stride]
+ else:
+ return self.pad(inp)[:, :, ::self.stride]
+ else:
+ return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
+
+
+def get_pad_layer_1d(pad_type):
+ if(pad_type in ['refl', 'reflect']):
+ PadLayer = nn.ReflectionPad1d
+ elif(pad_type in ['repl', 'replicate']):
+ PadLayer = nn.ReplicationPad1d
+ elif(pad_type == 'zero'):
+ PadLayer = nn.ZeroPad1d
+ else:
+ print('Pad type [%s] not recognized' % pad_type)
+ return PadLayer
\ No newline at end of file
diff --git a/utils/bundle_submissions.py b/utils/bundle_submissions.py
new file mode 100644
index 0000000..24e7047
--- /dev/null
+++ b/utils/bundle_submissions.py
@@ -0,0 +1,104 @@
+ # Author: Tobias Plötz, TU Darmstadt (tobias.ploetz@visinf.tu-darmstadt.de)
+
+ # This file is part of the implementation as described in the CVPR 2017 paper:
+ # Tobias Plötz and Stefan Roth, Benchmarking Denoising Algorithms with Real Photographs.
+ # Please see the file LICENSE.txt for the license governing this code.
+
+
+import numpy as np
+import scipy.io as sio
+import os
+import h5py
+
+def bundle_submissions_raw(submission_folder,session):
+ '''
+ Bundles submission data for raw denoising
+ submission_folder Folder where denoised images reside
+ Output is written to /bundled/. Please submit
+ the content of this folder.
+ '''
+
+ out_folder = os.path.join(submission_folder, session)
+ # out_folder = os.path.join(submission_folder, "bundled/")
+ try:
+ os.mkdir(out_folder)
+ except:pass
+
+ israw = True
+ eval_version="1.0"
+
+ for i in range(50):
+ Idenoised = np.zeros((20,), dtype=np.object)
+ for bb in range(20):
+ filename = '%04d_%02d.mat'%(i+1,bb+1)
+ s = sio.loadmat(os.path.join(submission_folder,filename))
+ Idenoised_crop = s["Idenoised_crop"]
+ Idenoised[bb] = Idenoised_crop
+ filename = '%04d.mat'%(i+1)
+ sio.savemat(os.path.join(out_folder, filename),
+ {"Idenoised": Idenoised,
+ "israw": israw,
+ "eval_version": eval_version},
+ )
+
+def bundle_submissions_srgb(submission_folder,session):
+ '''
+ Bundles submission data for sRGB denoising
+
+ submission_folder Folder where denoised images reside
+ Output is written to /bundled/. Please submit
+ the content of this folder.
+ '''
+ out_folder = os.path.join(submission_folder, session)
+ # out_folder = os.path.join(submission_folder, "bundled/")
+ try:
+ os.mkdir(out_folder)
+ except:pass
+ israw = False
+ eval_version="1.0"
+
+ for i in range(50):
+ Idenoised = np.zeros((20,), dtype=np.object)
+ for bb in range(20):
+ filename = '%04d_%02d.mat'%(i+1,bb+1)
+ s = sio.loadmat(os.path.join(submission_folder,filename))
+ Idenoised_crop = s["Idenoised_crop"]
+ Idenoised[bb] = Idenoised_crop
+ filename = '%04d.mat'%(i+1)
+ sio.savemat(os.path.join(out_folder, filename),
+ {"Idenoised": Idenoised,
+ "israw": israw,
+ "eval_version": eval_version},
+ )
+
+
+
+def bundle_submissions_srgb_v1(submission_folder,session):
+ '''
+ Bundles submission data for sRGB denoising
+
+ submission_folder Folder where denoised images reside
+ Output is written to /bundled/. Please submit
+ the content of this folder.
+ '''
+ out_folder = os.path.join(submission_folder, session)
+ # out_folder = os.path.join(submission_folder, "bundled/")
+ try:
+ os.mkdir(out_folder)
+ except:pass
+ israw = False
+ eval_version="1.0"
+
+ for i in range(50):
+ Idenoised = np.zeros((20,), dtype=np.object)
+ for bb in range(20):
+ filename = '%04d_%d.mat'%(i+1,bb+1)
+ s = sio.loadmat(os.path.join(submission_folder,filename))
+ Idenoised_crop = s["Idenoised_crop"]
+ Idenoised[bb] = Idenoised_crop
+ filename = '%04d.mat'%(i+1)
+ sio.savemat(os.path.join(out_folder, filename),
+ {"Idenoised": Idenoised,
+ "israw": israw,
+ "eval_version": eval_version},
+ )
\ No newline at end of file
diff --git a/utils/dataset_utils.py b/utils/dataset_utils.py
new file mode 100644
index 0000000..3bfbb63
--- /dev/null
+++ b/utils/dataset_utils.py
@@ -0,0 +1,49 @@
+import torch
+import os
+
+### rotate and flip
+class Augment_RGB_torch:
+ def __init__(self):
+ pass
+ def transform0(self, torch_tensor):
+ return torch_tensor
+ def transform1(self, torch_tensor):
+ torch_tensor = torch.rot90(torch_tensor, k=1, dims=[-1,-2])
+ return torch_tensor
+ def transform2(self, torch_tensor):
+ torch_tensor = torch.rot90(torch_tensor, k=2, dims=[-1,-2])
+ return torch_tensor
+ def transform3(self, torch_tensor):
+ torch_tensor = torch.rot90(torch_tensor, k=3, dims=[-1,-2])
+ return torch_tensor
+ def transform4(self, torch_tensor):
+ torch_tensor = torch_tensor.flip(-2)
+ return torch_tensor
+ def transform5(self, torch_tensor):
+ torch_tensor = (torch.rot90(torch_tensor, k=1, dims=[-1,-2])).flip(-2)
+ return torch_tensor
+ def transform6(self, torch_tensor):
+ torch_tensor = (torch.rot90(torch_tensor, k=2, dims=[-1,-2])).flip(-2)
+ return torch_tensor
+ def transform7(self, torch_tensor):
+ torch_tensor = (torch.rot90(torch_tensor, k=3, dims=[-1,-2])).flip(-2)
+ return torch_tensor
+
+
+### mix two images
+class MixUp_AUG:
+ def __init__(self):
+ self.dist = torch.distributions.beta.Beta(torch.tensor([1.2]), torch.tensor([1.2]))
+
+ def aug(self, rgb_gt, rgb_noisy):
+ bs = rgb_gt.size(0)
+ indices = torch.randperm(bs)
+ rgb_gt2 = rgb_gt[indices]
+ rgb_noisy2 = rgb_noisy[indices]
+
+ lam = self.dist.rsample((bs,1)).view(-1,1,1,1).cuda()
+
+ rgb_gt = lam * rgb_gt + (1-lam) * rgb_gt2
+ rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2
+
+ return rgb_gt, rgb_noisy
diff --git a/utils/dir_utils.py b/utils/dir_utils.py
new file mode 100644
index 0000000..2afff38
--- /dev/null
+++ b/utils/dir_utils.py
@@ -0,0 +1,18 @@
+import os
+#from natsort import natsorted
+from glob import glob
+
+def mkdirs(paths):
+ if isinstance(paths, list) and not isinstance(paths, str):
+ for path in paths:
+ mkdir(path)
+ else:
+ mkdir(paths)
+
+def mkdir(path):
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+def get_last_path(path, session):
+ x = natsorted(glob(os.path.join(path,'*%s'%session)))[-1]
+ return x
\ No newline at end of file
diff --git a/utils/image_utils.py b/utils/image_utils.py
new file mode 100644
index 0000000..b1877dc
--- /dev/null
+++ b/utils/image_utils.py
@@ -0,0 +1,121 @@
+import torch
+import numpy as np
+import pickle
+import cv2
+import math
+
+def is_numpy_file(filename):
+ return any(filename.endswith(extension) for extension in [".npy"])
+
+def is_image_file(filename):
+ return any(filename.endswith(extension) for extension in [".jpg"])
+
+def is_png_file(filename):
+ return any(filename.endswith(extension) for extension in [".png", '.jpg'])
+
+def is_pkl_file(filename):
+ return any(filename.endswith(extension) for extension in [".pkl"])
+
+def load_pkl(filename_):
+ with open(filename_, 'rb') as f:
+ ret_dict = pickle.load(f)
+ return ret_dict
+
+def save_dict(dict_, filename_):
+ with open(filename_, 'wb') as f:
+ pickle.dump(dict_, f)
+
+def load_npy(filepath):
+ img = np.load(filepath)
+ return img
+
+def load_img(filepath):
+ img = cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
+ img = img.astype(np.float32)
+ img = img/255.
+ return img
+
+def load_img2(filepath):
+ img = cv2.cvtColor(cv2.imread(filepath, -1), cv2.COLOR_BGR2RGB)
+ img = img.astype(np.float32)
+ img = img/65535.
+ return img
+
+def load_gray_img(filepath):
+ img = np.expand_dims(cv2.imread(filepath, cv2.IMREAD_GRAYSCALE), axis=2)
+ img = img.astype(np.float32)
+ img = img/255.
+ return img
+
+def save_gray_img(filepath, img):
+ cv2.imwrite(filepath, img)
+
+def save_img(filepath, img):
+ cv2.imwrite(filepath, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
+
+def myPSNR(tar_img, prd_img):
+ imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1)
+ rmse = (imdff**2).mean().sqrt()
+ ps = 20*torch.log10(1/rmse)
+ return ps
+
+
+def batch_PSNR(img1, img2, average=True):
+ PSNR = []
+ for im1, im2 in zip(img1, img2):
+ psnr = myPSNR(im1, im2)
+ PSNR.append(psnr)
+ return sum(PSNR)/len(PSNR) if average else sum(PSNR)
+
+def YPSNR(tar_img, prd_img):
+ tar_img = cv2.cvtColor(tar_img, cv2.COLOR_RGB2YUV)[0]/255.0
+ prd_img = cv2.cvtColor(prd_img, cv2.COLOR_RGB2YUV)[0]/255.0
+ #tar_img = tar_img/255.0
+ #prd_img = prd_img/255.5
+ #imdff = np.clip(prd_img,0,1) - np.clip(tar_img,0,1)
+ imdff = prd_img - tar_img
+ rmse = np.sqrt(np.mean(imdff**2))
+ ps = 20*np.log10(1/rmse)
+ return ps
+
+def splitimage(imgtensor, crop_size=128, overlap_size=64):
+ _, C, H, W = imgtensor.shape
+ hstarts = [x for x in range(0, H, crop_size - overlap_size)]
+ while hstarts and hstarts[-1] + crop_size >= H:
+ hstarts.pop()
+ hstarts.append(H - crop_size)
+ wstarts = [x for x in range(0, W, crop_size - overlap_size)]
+ while wstarts and wstarts[-1] + crop_size >= W:
+ wstarts.pop()
+ wstarts.append(W - crop_size)
+ starts = []
+ split_data = []
+ for hs in hstarts:
+ for ws in wstarts:
+ cimgdata = imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size]
+ starts.append((hs, ws))
+ split_data.append(cimgdata)
+ return split_data, starts
+
+def get_scoremap(H, W, C, B=1, is_mean=True):
+ center_h = H / 2
+ center_w = W / 2
+
+ score = torch.ones((B, C, H, W))
+ if not is_mean:
+ for h in range(H):
+ for w in range(W):
+ score[:, :, h, w] = 1.0 / (math.sqrt((h - center_h) ** 2 + (w - center_w) ** 2 + 1e-6))
+ return score
+
+def mergeimage(split_data, starts, crop_size = 128, resolution=(1, 3, 128, 128)):
+ B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3]
+ tot_score = torch.zeros((B, C, H, W))
+ merge_img = torch.zeros((B, C, H, W))
+ scoremap = get_scoremap(crop_size, crop_size, C, B=B, is_mean=True)
+ for simg, cstart in zip(split_data, starts):
+ hs, ws = cstart
+ merge_img[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap * simg
+ tot_score[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap
+ merge_img = merge_img / tot_score
+ return merge_img
\ No newline at end of file
diff --git a/utils/loader.py b/utils/loader.py
new file mode 100644
index 0000000..b6ae123
--- /dev/null
+++ b/utils/loader.py
@@ -0,0 +1,10 @@
+import os
+from dataset import DataLoaderTrainGoPro, DataLoaderTest
+
+def get_deblur_training_data(rgb_dir, patchsize):
+ assert os.path.exists(rgb_dir)
+ return DataLoaderTrainGoPro(rgb_dir, patchsize, None)
+
+def get_test_data(input_dir):
+ return DataLoaderTest(input_dir)
+
diff --git a/utils/model_utils.py b/utils/model_utils.py
new file mode 100644
index 0000000..2020ea6
--- /dev/null
+++ b/utils/model_utils.py
@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+import os
+from collections import OrderedDict
+
+def freeze(model):
+ for p in model.parameters():
+ p.requires_grad=False
+
+def unfreeze(model):
+ for p in model.parameters():
+ p.requires_grad=True
+
+def is_frozen(model):
+ x = [p.requires_grad for p in model.parameters()]
+ return not all(x)
+
+def save_checkpoint(model_dir, state, session):
+ epoch = state['epoch']
+ model_out_path = os.path.join(model_dir,"model_epoch_{}_{}.pth".format(epoch,session))
+ torch.save(state, model_out_path)
+
+def load_checkpoint(model, weights):
+ checkpoint = torch.load(weights)
+ try:
+ model.load_state_dict(checkpoint["state_dict"])
+ except:
+ state_dict = checkpoint["state_dict"]
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ name = k[7:] if 'module.' in k else k
+ new_state_dict[name] = v
+ model.load_state_dict(new_state_dict)
+
+
+def load_checkpoint_multigpu(model, weights):
+ checkpoint = torch.load(weights)
+ state_dict = checkpoint["state_dict"]
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ name = k[7:]
+ new_state_dict[name] = v
+ model.load_state_dict(new_state_dict)
+
+def load_start_epoch(weights):
+ checkpoint = torch.load(weights)
+ epoch = checkpoint["epoch"]
+ return epoch
+
+def load_optim(optimizer, weights):
+ checkpoint = torch.load(weights)
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ for p in optimizer.param_groups: lr = p['lr']
+ return lr
+
+def get_arch(opt):
+ from stoformer import Stoformer
+ from fixformer import Fixformer
+ arch = opt.arch
+ print('You choose '+arch+'...')
+ if arch.lower() == "stoformer":
+ model_restoration = Stoformer(img_size=opt.train_ps, embed_dim=opt.embed_dim, win_size=opt.win_size,
+ token_mlp=opt.token_mlp)
+ elif arch.lower() == "fixformer":
+ model_restoration = Fixformer(img_size=opt.train_ps, embed_dim=opt.embed_dim, win_size=opt.win_size,
+ token_mlp=opt.token_mlp)
+ else:
+ raise Exception("Arch error!")
+
+ return model_restoration
\ No newline at end of file
diff --git a/warmup_scheduler/__init__.py b/warmup_scheduler/__init__.py
new file mode 100644
index 0000000..a0051fb
--- /dev/null
+++ b/warmup_scheduler/__init__.py
@@ -0,0 +1,2 @@
+
+from warmup_scheduler.scheduler import GradualWarmupScheduler, CosineAnnealingWithRestartsLR, CosineAnnealingRestartCyclicLR
diff --git a/warmup_scheduler/run.py b/warmup_scheduler/run.py
new file mode 100644
index 0000000..a1dbf3d
--- /dev/null
+++ b/warmup_scheduler/run.py
@@ -0,0 +1,24 @@
+import torch
+from torch.optim.lr_scheduler import StepLR, ExponentialLR
+from torch.optim.sgd import SGD
+
+from warmup_scheduler import GradualWarmupScheduler
+
+
+if __name__ == '__main__':
+ model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
+ optim = SGD(model, 0.1)
+
+ # scheduler_warmup is chained with schduler_steplr
+ scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1)
+ scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr)
+
+ # this zero gradient update is needed to avoid a warning message, issue #8.
+ optim.zero_grad()
+ optim.step()
+
+ for epoch in range(1, 20):
+ scheduler_warmup.step(epoch)
+ print(epoch, optim.param_groups[0]['lr'])
+
+ optim.step() # backward pass (update network)
diff --git a/warmup_scheduler/scheduler.py b/warmup_scheduler/scheduler.py
new file mode 100644
index 0000000..cde5c2a
--- /dev/null
+++ b/warmup_scheduler/scheduler.py
@@ -0,0 +1,179 @@
+from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+import math
+
+class GradualWarmupScheduler(_LRScheduler):
+ """ Gradually warm-up(increasing) learning rate in optimizer.
+ Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
+ total_epoch: target learning rate is reached at total_epoch, gradually
+ after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
+ """
+
+ def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
+ self.multiplier = multiplier
+ if self.multiplier < 1.:
+ raise ValueError('multiplier should be greater thant or equal to 1.')
+ self.total_epoch = total_epoch
+ self.after_scheduler = after_scheduler
+ self.finished = False
+ super(GradualWarmupScheduler, self).__init__(optimizer)
+
+ def get_lr(self):
+ if self.last_epoch > self.total_epoch:
+ if self.after_scheduler:
+ if not self.finished:
+ self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
+ self.finished = True
+ return self.after_scheduler.get_lr()
+ return [base_lr * self.multiplier for base_lr in self.base_lrs]
+
+ if self.multiplier == 1.0:
+ return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
+ else:
+ return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
+
+ def step_ReduceLROnPlateau(self, metrics, epoch=None):
+ if epoch is None:
+ epoch = self.last_epoch + 1
+ self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
+ if self.last_epoch <= self.total_epoch:
+ warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
+ for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
+ param_group['lr'] = lr
+ else:
+ if epoch is None:
+ self.after_scheduler.step(metrics, None)
+ else:
+ self.after_scheduler.step(metrics, epoch - self.total_epoch)
+
+ def step(self, epoch=None, metrics=None):
+ if type(self.after_scheduler) != ReduceLROnPlateau:
+ if self.finished and self.after_scheduler:
+ if epoch is None:
+ self.after_scheduler.step(None)
+ else:
+ self.after_scheduler.step(epoch - self.total_epoch)
+ else:
+ return super(GradualWarmupScheduler, self).step(epoch)
+ else:
+ self.step_ReduceLROnPlateau(metrics, epoch)
+
+
+class CosineAnnealingWithRestartsLR(_LRScheduler):
+ r"""Set the learning rate of each parameter group using a cosine annealing
+ schedule, where :math:`\eta_{max}` is set to the initial lr and
+ :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
+ .. math::
+ \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
+ \cos(\frac{T_{cur}}{T_{max}}\pi))
+ When last_epoch=-1, sets initial lr as lr.
+ It has been proposed in
+ `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this only
+ implements the cosine annealing part of SGDR, and not the restarts.
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ T_max (int): Maximum number of iterations.
+ eta_min (float): Minimum learning rate. Default: 0.
+ last_epoch (int): The index of last epoch. Default: -1.
+ .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
+ https://arxiv.org/abs/1608.03983
+ """
+
+ def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, T_mult=1):
+ self.T_max = T_max
+ self.T_mult = T_mult
+ self.next_restart = T_max
+ self.eta_min = eta_min
+ self.restarts = 0
+ self.last_restart = 0
+ super(
+ CosineAnnealingWithRestartsLR,
+ self).__init__(
+ optimizer,
+ last_epoch)
+
+ def get_lr(self):
+ self.Tcur = self.last_epoch - self.last_restart
+ if self.Tcur >= self.next_restart:
+ self.next_restart *= self.T_mult
+ self.last_restart = self.last_epoch
+
+ return [(self.eta_min +
+ (base_lr -
+ self.eta_min) *
+ (1 +
+ math.cos(math.pi *
+ self.Tcur /
+ self.next_restart)) /
+ 2) for base_lr in self.base_lrs]
+
+
+def get_position_from_periods(iteration, cumulative_period):
+ """Get the position from a period list.
+ It will return the index of the right-closest number in the period list.
+ For example, the cumulative_period = [100, 200, 300, 400],
+ if iteration == 50, return 0;
+ if iteration == 210, return 2;
+ if iteration == 300, return 2.
+ Args:
+ iteration (int): Current iteration.
+ cumulative_period (list[int]): Cumulative period list.
+ Returns:
+ int: The position of the right-closest number in the period list.
+ """
+ for i, period in enumerate(cumulative_period):
+ if iteration <= period:
+ return i
+
+
+class CosineAnnealingRestartCyclicLR(_LRScheduler):
+ """ Cosine annealing with restarts learning rate scheme.
+ An example of config:
+ periods = [10, 10, 10, 10]
+ restart_weights = [1, 0.5, 0.5, 0.5]
+ eta_min=1e-7
+ It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
+ scheduler will restart with the weights in restart_weights.
+ Args:
+ optimizer (torch.nn.optimizer): Torch optimizer.
+ periods (list): Period for each cosine anneling cycle.
+ restart_weights (list): Restart weights at each restart iteration.
+ Default: [1].
+ eta_min (float): The mimimum lr. Default: 0.
+ last_epoch (int): Used in _LRScheduler. Default: -1.
+ """
+
+ def __init__(self,
+ optimizer,
+ periods,
+ restart_weights=(1,),
+ eta_mins=(0,),
+ last_epoch=-1):
+ self.periods = periods
+ self.restart_weights = restart_weights
+ self.eta_mins = eta_mins
+ assert (len(self.periods) == len(self.restart_weights)
+ ), 'periods and restart_weights should have the same length.'
+ self.cumulative_period = [
+ sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
+ ]
+ super(CosineAnnealingRestartCyclicLR, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ idx = get_position_from_periods(self.last_epoch,
+ self.cumulative_period)
+ current_weight = self.restart_weights[idx]
+ nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
+ current_period = self.periods[idx]
+ eta_min = self.eta_mins[idx]
+
+ return [
+ eta_min + current_weight * 0.5 * (base_lr - eta_min) *
+ (1 + math.cos(math.pi * (
+ (self.last_epoch - nearest_restart) / current_period)))
+ for base_lr in self.base_lrs
+ ]
\ No newline at end of file