import torch
import torch.nn as nn
from monai.networks.blocks import Convolution
-from monai.networks.layers import Act
+from monai.networks.layers import Act, get_pool_layer


class MultiScalePatchDiscriminator(nn.Sequential):
@@ -38,6 +38,8 @@ class MultiScalePatchDiscriminator(nn.Sequential):
        spatial_dims: number of spatial dimensions (1D, 2D etc.)
        num_channels: number of filters in the first convolutional layer (doubled for each subsequent layer)
        in_channels: number of input channels
+        pooling_method: pooling method applied before every discriminator after the first.
+            If None, each discriminator instead uses num_layers_d multiplied by its 1-based index as its depth.
        out_channels: number of output channels in each discriminator
        kernel_size: kernel size of the convolution layers
        activation: activation layer type
@@ -52,10 +54,11 @@ class MultiScalePatchDiscriminator(nn.Sequential):
    def __init__(
        self,
        num_d: int,
-        num_layers_d: int,
+        num_layers_d: int | list[int],
        spatial_dims: int,
        num_channels: int,
        in_channels: int,
+        pooling_method: str | None = None,
        out_channels: int = 1,
        kernel_size: int = 4,
        activation: str | tuple = (Act.LEAKYRELU, {"negative_slope": 0.2}),
@@ -67,31 +70,67 @@ def __init__(
    ) -> None:
        super().__init__()
        self.num_d = num_d
+        if isinstance(num_layers_d, int) and pooling_method is None:
+            # without pooling, deepen each discriminator by its 1-based index
+            num_layers_d = [num_layers_d * i for i in range(1, num_d + 1)]
+        elif isinstance(num_layers_d, int) and pooling_method is not None:
+            # with pooling, every discriminator uses the same number of layers
+            num_layers_d = [num_layers_d] * num_d
        self.num_layers_d = num_layers_d
-        self.num_channels = num_channels
+        assert (
+            len(self.num_layers_d) == self.num_d
+        ), f"MultiScalePatchDiscriminator: num_d {num_d} must match the length of num_layers_d {num_layers_d}."
+
        self.padding = tuple([int((kernel_size - 1) / 2)] * spatial_dims)
+
+        if pooling_method is None:
+            pool = None
+        else:
+            pool = get_pool_layer(
+                (pooling_method, {"kernel_size": kernel_size, "stride": 2, "padding": self.padding}), spatial_dims=spatial_dims
+            )
+        self.num_channels = num_channels
        for i_ in range(self.num_d):
-            num_layers_d_i = self.num_layers_d * (i_ + 1)
+            num_layers_d_i = self.num_layers_d[i_]
            output_size = float(minimum_size_im) / (2**num_layers_d_i)
            if output_size < 1:
                raise AssertionError(
                    "Your image size is too small to take in up to %d discriminators with num_layers = %d. "
                    "Please reduce num_layers, reduce num_D or enter bigger images." % (i_, num_layers_d_i)
                )
-            subnet_d = PatchDiscriminator(
-                spatial_dims=spatial_dims,
-                num_channels=self.num_channels,
-                in_channels=in_channels,
-                out_channels=out_channels,
-                num_layers_d=num_layers_d_i,
-                kernel_size=kernel_size,
-                activation=activation,
-                norm=norm,
-                bias=bias,
-                padding=self.padding,
-                dropout=dropout,
-                last_conv_kernel_size=last_conv_kernel_size,
-            )
+            if i_ == 0 or pool is None:
+                subnet_d = PatchDiscriminator(
+                    spatial_dims=spatial_dims,
+                    num_channels=self.num_channels,
+                    in_channels=in_channels,
+                    out_channels=out_channels,
+                    num_layers_d=num_layers_d_i,
+                    kernel_size=kernel_size,
+                    activation=activation,
+                    norm=norm,
+                    bias=bias,
+                    padding=self.padding,
+                    dropout=dropout,
+                    last_conv_kernel_size=last_conv_kernel_size,
+                )
+            else:  # same depth as the first discriminator, with i_ pooling layers prepended to downsample the input
+                subnet_d = nn.Sequential(
+                    *[pool] * i_,
+                    PatchDiscriminator(
+                        spatial_dims=spatial_dims,
+                        num_channels=self.num_channels,
+                        in_channels=in_channels,
+                        out_channels=out_channels,
+                        num_layers_d=num_layers_d_i,
+                        kernel_size=kernel_size,
+                        activation=activation,
+                        norm=norm,
+                        bias=bias,
+                        padding=self.padding,
+                        dropout=dropout,
+                        last_conv_kernel_size=last_conv_kernel_size,
+                    ),
+                )

            self.add_module("discriminator_%d" % i_, subnet_d)
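For reference, a minimal usage sketch of the two configurations this change enables (not part of the diff). Parameter names follow the code above; the import path, the `minimum_size_im` argument and the forward return format are assumptions taken from the existing MONAI implementation.

```python
import torch

# Assumption: the class remains importable from monai.networks.nets, as in current MONAI.
from monai.networks.nets import MultiScalePatchDiscriminator

# Without pooling: num_layers_d=2 expands to [2, 4, 6], so each later scale
# downsamples further simply by being a deeper discriminator.
deeper_scales = MultiScalePatchDiscriminator(
    num_d=3,
    num_layers_d=2,
    spatial_dims=2,
    num_channels=8,
    in_channels=1,
    minimum_size_im=256,
)

# With pooling: all three discriminators keep 2 layers, and discriminator i
# is preceded by i average-pooling layers that shrink its input instead.
pooled_scales = MultiScalePatchDiscriminator(
    num_d=3,
    num_layers_d=2,
    spatial_dims=2,
    num_channels=8,
    in_channels=1,
    pooling_method="avg",
    minimum_size_im=256,
)

x = torch.rand(1, 1, 256, 256)
# forward follows the existing class: a list of per-scale outputs plus per-scale intermediate features
outputs, features = pooled_scales(x)
print([o.shape for o in outputs])
```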