Add Swin Transformer Backbone #2298
Conversation
@AdonaiVera - are you still working on this?

Hi @abheesht17

Awesome, thank you! :)

/gemini review
Code Review

The code introduces a Swin Transformer backbone network, including core layers and unit tests. The changes look good overall, but there are some opportunities to improve efficiency and readability, particularly in the `SwinTransformerBackbone` and `SwinTransformerBlock` classes. Addressing these points will enhance the code's maintainability and performance.
```python
def reshape_and_norm(tensor, norm_layer=self.norm_layers[i]):
    shape = ops.shape(tensor)
    B = shape[0]
    L = shape[1]
    C = shape[2]
    H_float = ops.sqrt(ops.cast(L, x.dtype))
    H = ops.cast(H_float, "int32")
    W = H
    tensor = ops.reshape(tensor, (B, H, W, C))
    return norm_layer(tensor)

x_reshaped = keras.layers.Lambda(reshape_and_norm)(x)
features.append(x_reshaped)
```
The `reshape_and_norm` function is defined inside the loop, which means it is re-defined for each stage. This is inefficient; it should be defined once outside the loop. Also, the spatial dimensions are available from the model configuration, so it is better to use them instead of deriving the reshape dimensions with `sqrt` and casts, which is error-prone and less readable. Finally, it is better to call the function directly and return the reshaped tensor instead of wrapping it in a `Lambda` layer.
Suggested change:

```python
def reshape_and_norm(tensor, norm_layer):
    shape = ops.shape(tensor)
    B = shape[0]
    L = shape[1]
    C = shape[2]
    H = image_shape[0] // self.patch_size
    W = image_shape[1] // self.patch_size
    tensor = ops.reshape(tensor, (B, H, W, C))
    return norm_layer(tensor)

# Forward pass
features = []
for i, stage in enumerate(self.stages):
    x = stage(x)
    x_reshaped = reshape_and_norm(x, self.norm_layers[i])
    features.append(x_reshaped)
```
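As a quick sanity check of this point, a plain numpy sketch (with hypothetical shapes, not the PR's actual code) shows that deriving `H` and `W` from the image configuration gives the same grid as the `sqrt`-based version only because the feature map is square:

```python
import numpy as np

# Hypothetical config: 224x224 input, patch_size=4 -> a 56x56 token grid.
image_shape, patch_size = (224, 224), 4
H = image_shape[0] // patch_size
W = image_shape[1] // patch_size

B, C = 2, 96
tokens = np.zeros((B, H * W, C), dtype="float32")  # (B, L, C) patch tokens

# Static reshape using the known H and W -- no sqrt/cast round-trip needed.
feature_map = tokens.reshape(B, H, W, C)
assert feature_map.shape == (2, 56, 56, 96)

# The sqrt-based variant only agrees when the grid is square.
assert int(np.sqrt(tokens.shape[1])) == H == W
```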
```python
img_mask = np.zeros((1, H, W, 1), dtype=np.int32)
cnt = 0
h_slices = [
    (0, H // 2),
    (H // 2, H - self.shift_size),
    (H - self.shift_size, H),
]
w_slices = [
    (0, W // 2),
    (W // 2, W - self.shift_size),
    (W - self.shift_size, W),
]
for h in h_slices:
    for w in w_slices:
        img_mask[:, h[0] : h[1], w[0] : w[1], :] = cnt
        cnt += 1
img_mask = ops.convert_to_tensor(img_mask)
```
The `img_mask` is initialized with `np.zeros` and then converted to a tensor. It is better to build it directly with Keras ops to avoid unnecessary conversions and potential compatibility issues across backends. Also, the integer division (e.g. `H // 2`) might lead to incorrect mask creation if `H` or `W` is not even; consider `ops.cast(H / 2, 'int32')` to make the behavior explicit for both even and odd dimensions.
Suggested change:

```python
img_mask = ops.zeros((1, H, W, 1), dtype="int32")
cnt = 0
h_slices = [
    (0, ops.cast(H / 2, "int32")),
    (ops.cast(H / 2, "int32"), H - self.shift_size),
    (H - self.shift_size, H),
]
w_slices = [
    (0, ops.cast(W / 2, "int32")),
    (ops.cast(W / 2, "int32"), W - self.shift_size),
    (W - self.shift_size, W),
]
for h in h_slices:
    for w in w_slices:
        img_mask = ops.slice_update(
            img_mask,
            [0, h[0], w[0], 0],
            ops.ones((1, h[1] - h[0], w[1] - w[0], 1), dtype="int32") * cnt,
        )
        cnt += 1
# The final ops.convert_to_tensor call is no longer needed.
```
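For intuition, the same slicing logic can be run eagerly in plain numpy (hypothetical `H = W = 8` and `shift_size = 2`, chosen only for illustration) to confirm that it labels nine distinct regions of the feature map:

```python
import numpy as np

H = W = 8
shift_size = 2  # hypothetical shift for the shifted-window pass

img_mask = np.zeros((1, H, W, 1), dtype=np.int32)
cnt = 0
h_slices = [(0, H // 2), (H // 2, H - shift_size), (H - shift_size, H)]
w_slices = [(0, W // 2), (W // 2, W - shift_size), (W - shift_size, W)]
for h in h_slices:
    for w in w_slices:
        img_mask[:, h[0]:h[1], w[0]:w[1], :] = cnt
        cnt += 1

# 3 x 3 slice pairs -> nine labelled regions.
assert cnt == 9
assert sorted(np.unique(img_mask)) == list(range(9))
```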
```python
attn_mask = ops.expand_dims(mask_windows, 1) - ops.expand_dims(
    mask_windows, 2
)
attn_mask = ops.where(attn_mask != 0, -100.0, 0.0)
```
The attention mask is created by subtracting two `ops.expand_dims` views of `mask_windows` and then comparing the difference to 0. This can be simplified by comparing the expanded tensors directly and using the boolean result to build the attention mask, which avoids the intermediate subtraction and the separate `ops.where` call, improving readability.
Suggested change:

```python
attn_mask = ops.cast(
    ops.expand_dims(mask_windows, 1) != ops.expand_dims(mask_windows, 2),
    dtype="float32",
) * -100.0
```
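The equivalence of the two formulations can be checked with a small numpy sketch (the region ids below are hypothetical, standing in for real `mask_windows` values):

```python
import numpy as np

# Hypothetical per-window region ids: (num_windows, window_area).
mask_windows = np.array([[0, 0, 1, 1],
                         [0, 1, 2, 2]])

# Subtraction + where (current code).
diff = mask_windows[:, None, :] - mask_windows[:, :, None]
attn_a = np.where(diff != 0, -100.0, 0.0)

# Direct boolean comparison (suggested form).
attn_b = (
    mask_windows[:, None, :] != mask_windows[:, :, None]
).astype("float32") * -100.0

assert attn_a.shape == (2, 4, 4)
assert np.array_equal(attn_a, attn_b)
```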
```python
pad_values = ((0, 0), (0, H % 2), (0, W % 2), (0, 0))
x = ops.pad(x, pad_values)
```
Padding is applied with `ops.pad`. It's important to ensure that the padding behavior is the same on all backends. Consider passing an explicit `mode="constant"` with a constant value of 0 to avoid relying on default padding behavior:

```python
x = ops.pad(x, pad_values, mode="constant", constant_values=0)
```
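A numpy sketch (hypothetical odd spatial dims, illustrative only) shows the intended effect of this zero padding before patch merging:

```python
import numpy as np

H, W = 7, 9  # hypothetical odd spatial dims before patch merging
x = np.ones((1, H, W, 96), dtype="float32")

# Pad bottom/right so both spatial dims become even, filling with zeros.
pad_values = ((0, 0), (0, H % 2), (0, W % 2), (0, 0))
x_padded = np.pad(x, pad_values, mode="constant", constant_values=0)

assert x_padded.shape == (1, 8, 10, 96)
assert x_padded[0, -1, 0, 0] == 0.0  # newly added row is zero-filled
assert x_padded[0, 0, 0, 0] == 1.0   # original values untouched
```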
Add Swin Transformer Backbone

This PR adds `SwinTransformerBackbone`, based on the paper Swin Transformer: Hierarchical Vision Transformer using Shifted Windows. It is the first step to support Grounding DINO in `keras-hub`, where Swin is the main image encoder.

Related Work

This idea has been discussed in previous threads:

- `keras-cv` issue #2114 — community request to support Grounding DINO, which requires Swin.
- `keras-hub` issue #2117 — a request to support Swin-UNETR, a model originally designed for 3D medical image segmentation using Swin Transformers for effective feature extraction.

Included in this PR

- `SwinTransformerBackbone` model
- Core layers (`PatchEmbedding`, `WindowAttention`, etc.)

🧪 Current Status

This PR is still a draft. I'm finishing:

- `float16` and `float32`

Let me know if this direction makes sense, or if you have any comments or suggestions. Thanks!
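For readers unfamiliar with the paper, the window partitioning that the backbone's attention layers operate on can be sketched in a few lines of numpy (names and shapes here are illustrative, not the PR's actual implementation):

```python
import numpy as np

def window_partition(x, window_size):
    """Split (B, H, W, C) feature maps into non-overlapping windows.

    Returns (num_windows * B, window_size, window_size, C).
    """
    B, H, W, C = x.shape
    x = x.reshape(
        B, H // window_size, window_size, W // window_size, window_size, C
    )
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, C)

feat = np.arange(2 * 8 * 8 * 3, dtype="float32").reshape(2, 8, 8, 3)
windows = window_partition(feat, window_size=4)
assert windows.shape == (8, 4, 4, 3)  # 4 windows per image, batch of 2

# A shifted pass simply rolls the feature map before partitioning.
shifted = np.roll(feat, shift=(-2, -2), axis=(1, 2))
assert window_partition(shifted, 4).shape == (8, 4, 4, 3)
```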