Commit bbe7983

Update EdgeNeXt to use ClassifierHead as per ConvNeXt (#2051)
* Update edgenext.py
1 parent 711c5de

1 file changed (+20 -18 lines)
timm/models/edgenext.py
@@ -18,7 +18,7 @@
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d, \
-    use_fused_attn
+    use_fused_attn, NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_module
 from ._manipulate import named_apply, checkpoint_seq
@@ -375,13 +375,23 @@ def __init__(
         self.stages = nn.Sequential(*stages)
 
         self.num_features = dims[-1]
-        self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
-        self.head = nn.Sequential(OrderedDict([
-            ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
-            ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
-            ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
-            ('drop', nn.Dropout(self.drop_rate)),
-            ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
+        if head_norm_first:
+            self.norm_pre = norm_layer(self.num_features)
+            self.head = ClassifierHead(
+                self.num_features,
+                num_classes,
+                pool_type=global_pool,
+                drop_rate=self.drop_rate,
+            )
+        else:
+            self.norm_pre = nn.Identity()
+            self.head = NormMlpClassifierHead(
+                self.num_features,
+                num_classes,
+                pool_type=global_pool,
+                drop_rate=self.drop_rate,
+                norm_layer=norm_layer,
+            )
 
         named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
 
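For context, a minimal sketch of how the two head modules behave in isolation. The constructor arguments mirror the diff above; the channel width 512, pool_type='avg', and input shape are arbitrary example values, and the pool-then-norm ordering matches the old nn.Sequential head being replaced:

import torch
from timm.layers import ClassifierHead, NormMlpClassifierHead, LayerNorm2d

x = torch.randn(2, 512, 7, 7)  # NCHW feature map, arbitrary example shape

# head_norm_first=True path: the norm stays outside the head (self.norm_pre),
# so the head itself is just pool -> drop -> fc.
head_a = ClassifierHead(512, 1000, pool_type='avg', drop_rate=0.)

# head_norm_first=False path: the norm lives inside the head, applied to the
# pooled features before the final fc, as in the old OrderedDict head.
head_b = NormMlpClassifierHead(512, 1000, pool_type='avg', drop_rate=0., norm_layer=LayerNorm2d)

print(head_a(x).shape, head_b(x).shape)  # both torch.Size([2, 1000])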
@@ -406,10 +416,7 @@ def get_classifier(self):
         return self.head.fc
 
     def reset_classifier(self, num_classes=0, global_pool=None):
-        if global_pool is not None:
-            self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-            self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
-        self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.head.reset(num_classes, global_pool)
 
     def forward_features(self, x):
         x = self.stem(x)
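A usage sketch of the simplified reset path. It assumes a timm build containing this change; 'edgenext_small' is one of the EdgeNeXt variants registered in timm, and the input size is an arbitrary example:

import timm
import torch

model = timm.create_model('edgenext_small', num_classes=1000)
model.reset_classifier(10, global_pool='avg')  # now a one-line delegate to head.reset()
out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 10])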
@@ -418,12 +425,7 @@ def forward_features(self, x):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
-        # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
-        x = self.head.global_pool(x)
-        x = self.head.norm(x)
-        x = self.head.flatten(x)
-        x = self.head.drop(x)
-        return x if pre_logits else self.head.fc(x)
+        return self.head(x, pre_logits=True) if pre_logits else self.head(x)
 
     def forward(self, x):
         x = self.forward_features(x)
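And a sketch of the pre_logits path, which now simply forwards the flag through to the head module instead of unrolling it step by step (same assumptions as the example above):

import timm
import torch

model = timm.create_model('edgenext_small')
feats = model.forward_features(torch.randn(1, 3, 256, 256))  # NCHW feature map
pooled = model.forward_head(feats, pre_logits=True)  # pooled features, fc skipped
logits = model.forward_head(feats)                   # full classification output
print(pooled.shape, logits.shape)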
