@@ -18,7 +18,7 @@
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d, \
-    use_fused_attn
+    use_fused_attn, NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
 from ._features_fx import register_notrace_module
 from ._manipulate import named_apply, checkpoint_seq
@@ -375,13 +375,23 @@ def __init__(
         self.stages = nn.Sequential(*stages)
 
         self.num_features = dims[-1]
-        self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
-        self.head = nn.Sequential(OrderedDict([
-            ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
-            ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
-            ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
-            ('drop', nn.Dropout(self.drop_rate)),
-            ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
+        if head_norm_first:
+            self.norm_pre = norm_layer(self.num_features)
+            self.head = ClassifierHead(
+                self.num_features,
+                num_classes,
+                pool_type=global_pool,
+                drop_rate=self.drop_rate,
+            )
+        else:
+            self.norm_pre = nn.Identity()
+            self.head = NormMlpClassifierHead(
+                self.num_features,
+                num_classes,
+                pool_type=global_pool,
+                drop_rate=self.drop_rate,
+                norm_layer=norm_layer,
+            )
 
         named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
 
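Note: the OrderedDict head is folded into a single module, with the norm placed on
one side or the other of the pooling. A rough sketch of the two orderings follows;
these are simplified stand-ins, not the actual timm ClassifierHead /
NormMlpClassifierHead (which also handle flatten/input_fmt options and an optional
hidden MLP):

    import torch.nn as nn

    class PoolDropFC(nn.Module):
        # ~ ClassifierHead path: the norm was already applied to the 2D
        # feature map by self.norm_pre, so the head only pools and classifies.
        def __init__(self, dim, num_classes, drop=0.):
            super().__init__()
            self.pool = nn.AdaptiveAvgPool2d(1)
            self.drop = nn.Dropout(drop)
            self.fc = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity()

        def forward(self, x, pre_logits=False):
            x = self.drop(self.pool(x).flatten(1))
            return x if pre_logits else self.fc(x)

    class PoolNormDropFC(nn.Module):
        # ~ NormMlpClassifierHead path: norm_pre is Identity, so the norm
        # sits inside the head, after pooling.
        def __init__(self, dim, num_classes, drop=0., norm_layer=nn.LayerNorm):
            super().__init__()
            self.pool = nn.AdaptiveAvgPool2d(1)
            self.norm = norm_layer(dim)
            self.drop = nn.Dropout(drop)
            self.fc = nn.Linear(dim, num_classes) if num_classes > 0 else nn.Identity()

        def forward(self, x, pre_logits=False):
            x = self.drop(self.norm(self.pool(x).flatten(1)))
            return x if pre_logits else self.fc(x)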
@@ -406,10 +416,7 @@ def get_classifier(self):
         return self.head.fc
 
     def reset_classifier(self, num_classes=0, global_pool=None):
-        if global_pool is not None:
-            self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-            self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
-        self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.head.reset(num_classes, global_pool)
 
     def forward_features(self, x):
         x = self.stem(x)
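Note: with both head classes exposing a common reset(num_classes, pool_type)
method, reset_classifier() collapses to a single delegating call. A hedged usage
sketch ('convnext_tiny' is only a stand-in model name; any timm model with this
head API behaves the same):

    import timm

    model = timm.create_model('convnext_tiny', pretrained=False, num_classes=1000)

    # Swap the classifier for a 10-class fine-tune and pick average pooling.
    model.reset_classifier(num_classes=10, global_pool='avg')

    # num_classes=0 replaces the final fc with nn.Identity, so the model
    # returns pooled features instead of logits.
    model.reset_classifier(num_classes=0)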
@@ -418,12 +425,7 @@ def forward_features(self, x):
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
-        # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
-        x = self.head.global_pool(x)
-        x = self.head.norm(x)
-        x = self.head.flatten(x)
-        x = self.head.drop(x)
-        return x if pre_logits else self.head.fc(x)
+        return self.head(x, pre_logits=True) if pre_logits else self.head(x)
 
     def forward(self, x):
         x = self.forward_features(x)
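Note: forward_head() now defers entirely to the head module, which implements the
pre_logits path itself, so the old TorchScript-unfriendly manual unrolling of the
Sequential head goes away. An illustrative feature-extraction sketch (the model
name and shapes are assumptions, not taken from this diff):

    import torch
    import timm

    model = timm.create_model('convnext_tiny', pretrained=False)
    x = torch.randn(1, 3, 224, 224)

    feats = model.forward_features(x)                    # unpooled feature map
    pooled = model.forward_head(feats, pre_logits=True)  # pooled features, no fc
    logits = model.forward_head(feats)                   # classification logits
    print(pooled.shape, logits.shape)                    # e.g. (1, C), (1, 1000)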