@@ -247,7 +247,6 @@ def conv_with_general_padding(lhs: Array, rhs: Array,
247247 rhs_dilation = rhs_dilation , precision = precision ,
248248 preferred_element_type = preferred_element_type )
249249
250-
251250def _conv_transpose_padding (k , s , padding ):
252251 """Calculate before and after padding for a dim of transposed convolution.
253252
@@ -268,12 +267,15 @@ def _conv_transpose_padding(k, s, padding):
268267 elif padding == 'VALID' :
269268 pad_len = k + s - 2 + max (k - s , 0 )
270269 pad_a = k - 1
270+ elif isinstance (padding , tuple ):
271+ pads = tuple (k - p - 1 for p in padding )
272+ pad_a = pads [0 ]
273+ pad_len = sum (pads )
271274 else :
272- raise ValueError ('Padding mode must be `SAME` or `VALID`.' )
275+ raise ValueError (f"Invalid padding mode: { padding } " )
273276 pad_b = pad_len - pad_a
274277 return pad_a , pad_b
275278
276-
277279def _flip_axes (x , axes ):
278280 """Flip ndarray 'x' along each axis specified in axes tuple."""
279281 for axis in axes :
@@ -297,9 +299,11 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
297299 lhs: a rank `n+2` dimensional input array.
298300 rhs: a rank `n+2` dimensional array of kernel weights.
299301 strides: sequence of `n` integers, sets fractional stride.
300- padding: 'SAME', 'VALID' will set as transpose of corresponding forward
301- conv, or a sequence of `n` integer 2-tuples describing before-and-after
302- padding for each `n` spatial dimension.
302+ padding: 'SAME', 'VALID', or a sequence of `n` integer 2-tuples describing before-and-after
303+ padding for each spatial dimension in the corresponding forward conv. This effectively adds
304+ `dilation * (kernel_size - 1) - padding` zero padding to each side
305+ of the input so that `conv_transpose` becomes the gradient of `conv` when given the same padding
306+ and stride arguments.
303307 rhs_dilation: `None`, or a sequence of `n` integers, giving the
304308 dilation factor to apply in each spatial dimension of `rhs`. RHS dilation
305309 is also known as atrous convolution.
@@ -342,14 +346,13 @@ def conv_transpose(lhs: Array, rhs: Array, strides: Sequence[int],
342346 k_sdims = k_shape [2 :]
343347 # Calculate correct output shape given padding and strides.
344348 pads : str | Sequence [tuple [int , int ]]
345- if isinstance (padding , str ) and padding in {'SAME' , 'VALID' }:
346- if rhs_dilation is None :
347- rhs_dilation = (1 ,) * (rhs .ndim - 2 )
348- effective_k_size = map (lambda k , r : core .dilate_dim (k , r ), k_sdims , rhs_dilation )
349- pads = [_conv_transpose_padding (k , s , padding )
350- for k ,s in zip (effective_k_size , strides )]
351- else :
352- pads = padding
349+ if rhs_dilation is None :
350+ rhs_dilation = (1 ,) * (rhs .ndim - 2 )
351+ effective_k_size = map (lambda k , r : core .dilate_dim (k , r ), k_sdims , rhs_dilation )
352+ if isinstance (padding , str ):
353+ padding = [padding ] * len (strides )
354+ pads = [_conv_transpose_padding (k , s , p )
355+ for k ,s ,p in zip (effective_k_size , strides , padding )]
353356 if transpose_kernel :
354357 # flip spatial dims and swap input / output channel axes
355358 rhs = _flip_axes (rhs , np .array (dn .rhs_spec )[2 :])
0 commit comments