@@ -70,15 +70,20 @@ def forward(self, *features):
70
70
class DeepLabV3PlusDecoder (nn .Module ):
71
71
def __init__ (
72
72
self ,
73
- encoder_channels : Sequence [int , ...],
73
+ encoder_channels : Sequence [int ],
74
+ encoder_depth : Literal [3 , 4 , 5 ],
74
75
out_channels : int ,
75
76
atrous_rates : Iterable [int ],
76
77
output_stride : Literal [8 , 16 ],
77
78
aspp_separable : bool ,
78
79
aspp_dropout : float ,
79
80
):
80
81
super ().__init__ ()
81
- if output_stride not in {8 , 16 }:
82
+ if encoder_depth not in (3 , 4 , 5 ):
83
+ raise ValueError (
84
+ "Encoder depth should be 3, 4 or 5, got {}." .format (encoder_depth )
85
+ )
86
+ if output_stride not in (8 , 16 ):
82
87
raise ValueError (
83
88
"Output stride should be 8 or 16, got {}." .format (output_stride )
84
89
)
@@ -104,7 +109,14 @@ def __init__(
104
109
scale_factor = 2 if output_stride == 8 else 4
105
110
self .up = nn .UpsamplingBilinear2d (scale_factor = scale_factor )
106
111
107
- highres_in_channels = encoder_channels [- 4 ]
112
+ if encoder_depth == 3 and output_stride == 8 :
113
+ self .highres_input_index = - 2
114
+ elif encoder_depth == 3 or encoder_depth == 4 :
115
+ self .highres_input_index = - 3
116
+ else :
117
+ self .highres_input_index = - 4
118
+
119
+ highres_in_channels = encoder_channels [self .highres_input_index ]
108
120
highres_out_channels = 48 # proposed by authors of paper
109
121
self .block1 = nn .Sequential (
110
122
nn .Conv2d (
@@ -128,7 +140,7 @@ def __init__(
128
140
def forward (self , * features ):
129
141
aspp_features = self .aspp (features [- 1 ])
130
142
aspp_features = self .up (aspp_features )
131
- high_res_features = self .block1 (features [- 4 ])
143
+ high_res_features = self .block1 (features [self . highres_input_index ])
132
144
concat_features = torch .cat ([aspp_features , high_res_features ], dim = 1 )
133
145
fused_features = self .block2 (concat_features )
134
146
return fused_features
0 commit comments