This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit 4bc610a

Merge pull request #479 from StijnvWijn/475-add-different-downsampling-methods-to-PatchGAN-discriminator

475 add different downsampling methods to patch gan discriminator

2 parents (ef6b7e6 + 2045a00), commit 4bc610a

File tree: 3 files changed (+138, -21 lines)

generative/networks/nets/patchgan_discriminator.py

Lines changed: 57 additions & 18 deletions
```diff
@@ -17,7 +17,7 @@
 import torch
 import torch.nn as nn
 from monai.networks.blocks import Convolution
-from monai.networks.layers import Act
+from monai.networks.layers import Act, get_pool_layer
 
 
 class MultiScalePatchDiscriminator(nn.Sequential):
@@ -38,6 +38,8 @@ class MultiScalePatchDiscriminator(nn.Sequential):
         spatial_dims: number of spatial dimensions (1D, 2D etc.)
         num_channels: number of filters in the first convolutional layer (double of the value is taken from then on)
         in_channels: number of input channels
+        pooling_method: pooling method to be applied before each discriminator after the first.
+            If None, the number of layers is multiplied by the number of discriminators.
         out_channels: number of output channels in each discriminator
         kernel_size: kernel size of the convolution layers
         activation: activation layer type
@@ -52,10 +54,11 @@ class MultiScalePatchDiscriminator(nn.Sequential):
     def __init__(
         self,
         num_d: int,
-        num_layers_d: int,
+        num_layers_d: int | list[int],
         spatial_dims: int,
         num_channels: int,
         in_channels: int,
+        pooling_method: str = None,
         out_channels: int = 1,
         kernel_size: int = 4,
         activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}),
@@ -67,31 +70,67 @@ def __init__(
     ) -> None:
         super().__init__()
         self.num_d = num_d
+        if isinstance(num_layers_d, int) and pooling_method is None:
+            # if pooling_method is None, calculate the number of layers for each discriminator by multiplying by the number of discriminators
+            num_layers_d = [num_layers_d * i for i in range(1, num_d + 1)]
+        elif isinstance(num_layers_d, int) and pooling_method is not None:
+            # if pooling_method is not None, the number of layers is the same for all discriminators
+            num_layers_d = [num_layers_d] * num_d
         self.num_layers_d = num_layers_d
-        self.num_channels = num_channels
+        assert (
+            len(self.num_layers_d) == self.num_d
+        ), f"MultiScalePatchDiscriminator: num_d {num_d} must match the number of num_layers_d. {num_layers_d}"
+
         self.padding = tuple([int((kernel_size - 1) / 2)] * spatial_dims)
+
+        if pooling_method is None:
+            pool = None
+        else:
+            pool = get_pool_layer(
+                (pooling_method, {"kernel_size": kernel_size, "stride": 2, 'padding': self.padding}), spatial_dims=spatial_dims
+            )
+        self.num_channels = num_channels
         for i_ in range(self.num_d):
-            num_layers_d_i = self.num_layers_d * (i_ + 1)
+            num_layers_d_i = self.num_layers_d[i_]
             output_size = float(minimum_size_im) / (2**num_layers_d_i)
             if output_size < 1:
                 raise AssertionError(
                     "Your image size is too small to take in up to %d discriminators with num_layers = %d."
                     "Please reduce num_layers, reduce num_D or enter bigger images." % (i_, num_layers_d_i)
                 )
-            subnet_d = PatchDiscriminator(
-                spatial_dims=spatial_dims,
-                num_channels=self.num_channels,
-                in_channels=in_channels,
-                out_channels=out_channels,
-                num_layers_d=num_layers_d_i,
-                kernel_size=kernel_size,
-                activation=activation,
-                norm=norm,
-                bias=bias,
-                padding=self.padding,
-                dropout=dropout,
-                last_conv_kernel_size=last_conv_kernel_size,
-            )
+            if i_ == 0 or pool is None:
+                subnet_d = PatchDiscriminator(
+                    spatial_dims=spatial_dims,
+                    num_channels=self.num_channels,
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    num_layers_d=num_layers_d_i,
+                    kernel_size=kernel_size,
+                    activation=activation,
+                    norm=norm,
+                    bias=bias,
+                    padding=self.padding,
+                    dropout=dropout,
+                    last_conv_kernel_size=last_conv_kernel_size,
+                )
+            else:
+                subnet_d = nn.Sequential(
+                    *[pool] * i_,
+                    PatchDiscriminator(
+                        spatial_dims=spatial_dims,
+                        num_channels=self.num_channels,
+                        in_channels=in_channels,
+                        out_channels=out_channels,
+                        num_layers_d=num_layers_d_i,
+                        kernel_size=kernel_size,
+                        activation=activation,
+                        norm=norm,
+                        bias=bias,
+                        padding=self.padding,
+                        dropout=dropout,
+                        last_conv_kernel_size=last_conv_kernel_size,
+                    ),
+                )
 
             self.add_module("discriminator_%d" % i_, subnet_d)
 
```

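For context, here is a minimal usage sketch of the changed interface (not part of this commit). It assumes the class is imported directly from the module above, that the remaining constructor arguments keep their defaults, and that the forward pass returns per-discriminator predictions together with their intermediate features, as the test cases below exercise.

```python
import torch

from generative.networks.nets.patchgan_discriminator import MultiScalePatchDiscriminator

# Without pooling, discriminator i is built with num_layers_d * (i + 1) strided
# convolution layers, so later discriminators downsample more aggressively.
multi_scale_strided = MultiScalePatchDiscriminator(
    num_d=2,
    num_layers_d=3,
    spatial_dims=2,
    num_channels=8,
    in_channels=1,
    minimum_size_im=256,
)

# With a pooling method, every discriminator keeps num_layers_d layers, but
# discriminator i first sees the input downsampled by i pooling operations.
multi_scale_pooled = MultiScalePatchDiscriminator(
    num_d=2,
    num_layers_d=3,
    spatial_dims=2,
    num_channels=8,
    in_channels=1,
    pooling_method="avg",
    minimum_size_im=256,
)

x = torch.rand(1, 1, 256, 256)
outputs, features = multi_scale_pooled(x)  # assumed return: predictions and intermediate features
print([o.shape for o in outputs])
```
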
tests/test_patch_gan.py

Lines changed: 80 additions & 2 deletions
```diff
@@ -58,6 +58,65 @@
     [(1, 1, 32, 64, 32), (1, 1, 4, 8, 4)],
     [4, 7],
 ]
+TEST_3D_POOL = [
+    {
+        "num_d": 2,
+        "num_layers_d": 3,
+        "spatial_dims": 3,
+        "num_channels": 8,
+        "in_channels": 3,
+        "out_channels": 1,
+        "kernel_size": 3,
+        "pooling_method": "max",
+        "activation": "LEAKYRELU",
+        "norm": "instance",
+        "bias": False,
+        "dropout": 0.1,
+        "minimum_size_im": 256,
+    },
+    torch.rand([1, 3, 256, 512, 256]),
+    [(1, 1, 32, 64, 32), (1, 1, 16, 32, 16)],
+    [4, 4],
+]
+TEST_2D_POOL = [
+    {
+        "num_d": 4,
+        "num_layers_d": 3,
+        "spatial_dims": 2,
+        "num_channels": 8,
+        "in_channels": 3,
+        "out_channels": 1,
+        "kernel_size": 3,
+        "pooling_method": "avg",
+        "activation": "LEAKYRELU",
+        "norm": "instance",
+        "bias": False,
+        "dropout": 0.1,
+        "minimum_size_im": 256,
+    },
+    torch.rand([1, 3, 256, 512]),
+    [(1, 1, 32, 64), (1, 1, 16, 32), (1, 1, 8, 16), (1, 1, 4, 8)],
+    [4, 4, 4, 4],
+]
+TEST_LAYER_LIST = [
+    {
+        "num_d": 3,
+        "num_layers_d": [3,4,5],
+        "spatial_dims": 2,
+        "num_channels": 8,
+        "in_channels": 3,
+        "out_channels": 1,
+        "kernel_size": 3,
+        "activation": "LEAKYRELU",
+        "norm": "instance",
+        "bias": False,
+        "dropout": 0.1,
+        "minimum_size_im": 256,
+    },
+    torch.rand([1, 3, 256, 512]),
+    [(1, 1, 32, 64), (1, 1, 16, 32), (1, 1, 8, 16)],
+    [4, 5, 6],
+]
 TEST_TOO_SMALL_SIZE = [
     {
         "num_d": 2,
@@ -74,9 +133,24 @@
         "minimum_size_im": 256,
     }
 ]
+TEST_MISMATCHED_NUM_LAYERS = [
+    {
+        "num_d": 5,
+        "num_layers_d": [3,4,5],
+        "spatial_dims": 2,
+        "num_channels": 8,
+        "in_channels": 3,
+        "out_channels": 1,
+        "kernel_size": 3,
+        "activation": "LEAKYRELU",
+        "norm": "instance",
+        "bias": False,
+        "dropout": 0.1,
+        "minimum_size_im": 256,
+    }
+]
 
-CASES = [TEST_2D, TEST_3D]
-
+CASES = [TEST_2D, TEST_3D, TEST_3D_POOL, TEST_2D_POOL, TEST_LAYER_LIST]
 
 class TestPatchGAN(unittest.TestCase):
     @parameterized.expand(CASES)
@@ -93,6 +167,10 @@ def test_too_small_shape(self):
         with self.assertRaises(AssertionError):
             MultiScalePatchDiscriminator(**TEST_TOO_SMALL_SIZE[0])
 
+    def test_mismatched_num_layers(self):
+        with self.assertRaises(AssertionError):
+            MultiScalePatchDiscriminator(**TEST_MISMATCHED_NUM_LAYERS[0])
+
     def test_script(self):
         net = MultiScalePatchDiscriminator(
             num_d=2,
```

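As a quick sanity check on the expected shapes in the new pooling cases (a sketch, not part of the test suite): assuming each stride-2 convolution layer and each pooling step halves the spatial size, the TEST_3D_POOL expectations follow directly.

```python
# Back-of-the-envelope check for TEST_3D_POOL: discriminator i is preceded by
# i pooling steps and then applies num_layers_d stride-2 convolution layers.
input_shape = (256, 512, 256)
num_layers_d = 3

for i in range(2):  # num_d = 2
    downsampling_steps = num_layers_d + i
    print(tuple(s // 2**downsampling_steps for s in input_shape))
# (32, 64, 32)
# (16, 32, 16)
```
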
tutorials/generative/2d_spade_gan/2d_spade_vae.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -357,4 +357,4 @@ def feature_loss(input_features_disc_fake, input_features_disc_real, lambda_feat
 # + [markdown] pycharm={"name": "#%%"}
 # **Conclusion**: from early on, the network shows the capability to discern between the different semantic layers. To achieve good image quality, more images and training time are needed (to avoid the overfitting seen in some loss plots of the previous example), as well as thorough optimisation, such as establishing an adversarial schedule that makes sure the discriminator and the generator are trained only when their performance does not exceed a certain limit.
 #
-# -
+# -
```
