Skip to content

Commit cc482aa

Browse files
authoredNov 29, 2024··
Fix DeepLabV3Plus encoder depth (#986)
* fix issue #377 * modify docstring for upsampling of DeepLabV3Plus * modify type hint and value check
1 parent d490cdf commit cc482aa

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed
 

‎segmentation_models_pytorch/decoders/deeplabv3/decoder.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,20 @@ def forward(self, *features):
7070
class DeepLabV3PlusDecoder(nn.Module):
7171
def __init__(
7272
self,
73-
encoder_channels: Sequence[int, ...],
73+
encoder_channels: Sequence[int],
74+
encoder_depth: Literal[3, 4, 5],
7475
out_channels: int,
7576
atrous_rates: Iterable[int],
7677
output_stride: Literal[8, 16],
7778
aspp_separable: bool,
7879
aspp_dropout: float,
7980
):
8081
super().__init__()
81-
if output_stride not in {8, 16}:
82+
if encoder_depth not in (3, 4, 5):
83+
raise ValueError(
84+
"Encoder depth should be 3, 4 or 5, got {}.".format(encoder_depth)
85+
)
86+
if output_stride not in (8, 16):
8287
raise ValueError(
8388
"Output stride should be 8 or 16, got {}.".format(output_stride)
8489
)
@@ -104,7 +109,14 @@ def __init__(
104109
scale_factor = 2 if output_stride == 8 else 4
105110
self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor)
106111

107-
highres_in_channels = encoder_channels[-4]
112+
if encoder_depth == 3 and output_stride == 8:
113+
self.highres_input_index = -2
114+
elif encoder_depth == 3 or encoder_depth == 4:
115+
self.highres_input_index = -3
116+
else:
117+
self.highres_input_index = -4
118+
119+
highres_in_channels = encoder_channels[self.highres_input_index]
108120
highres_out_channels = 48 # proposed by authors of paper
109121
self.block1 = nn.Sequential(
110122
nn.Conv2d(
@@ -128,7 +140,7 @@ def __init__(
128140
def forward(self, *features):
129141
aspp_features = self.aspp(features[-1])
130142
aspp_features = self.up(aspp_features)
131-
high_res_features = self.block1(features[-4])
143+
high_res_features = self.block1(features[self.highres_input_index])
132144
concat_features = torch.cat([aspp_features, high_res_features], dim=1)
133145
fused_features = self.block2(concat_features)
134146
return fused_features

‎segmentation_models_pytorch/decoders/deeplabv3/model.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ class DeepLabV3Plus(SegmentationModel):
129129
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
130130
**callable** and **None**.
131131
Default is **None**
132-
upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity
132+
upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity. In case
133+
**encoder_depth** and **encoder_output_stride** are 3 and 16 resp., set **upsampling** to 2 to preserve.
133134
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
134135
on top of encoder if **aux_params** is not **None** (default). Supported params:
135136
- classes (int): A number of classes
@@ -150,7 +151,7 @@ class DeepLabV3Plus(SegmentationModel):
150151
def __init__(
151152
self,
152153
encoder_name: str = "resnet34",
153-
encoder_depth: int = 5,
154+
encoder_depth: Literal[3, 4, 5] = 5,
154155
encoder_weights: Optional[str] = "imagenet",
155156
encoder_output_stride: Literal[8, 16] = 16,
156157
decoder_channels: int = 256,
@@ -177,6 +178,7 @@ def __init__(
177178

178179
self.decoder = DeepLabV3PlusDecoder(
179180
encoder_channels=self.encoder.out_channels,
181+
encoder_depth=encoder_depth,
180182
out_channels=decoder_channels,
181183
atrous_rates=decoder_atrous_rates,
182184
output_stride=encoder_output_stride,

0 commit comments

Comments
 (0)
Please sign in to comment.