Skip to content

Commit d0b5217

Browse files
committed
update Simplenet constructors docstring and remove commented codes
1 parent 851c141 commit d0b5217

File tree

1 file changed

+112
-38
lines changed
  • ImageNet/training_scripts/imagenet_training/timm/models

1 file changed

+112
-38
lines changed

ImageNet/training_scripts/imagenet_training/timm/models/simplenet.py

Lines changed: 112 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,12 @@
1515
Official Pythorch impl at https://github.com/Coderx7/SimpleNet_Pytorch
1616
Seyyed Hossein Hasanpour
1717
"""
18-
# import os
1918
import math
2019

2120
import torch
2221
import torch.nn as nn
2322
import torch.nn.functional as F
2423

25-
# from torch.hub import download_url_to_file
26-
2724
from typing import Union, Tuple, List, Dict, Any, cast, Optional
2825

2926
from .helpers import build_model_with_cfg
@@ -79,10 +76,10 @@ def _cfg(url="", **kwargs):
7976
"simplenetv1_5m_m2": _cfg(
8077
url="https://github.com/Coderx7/SimpleNet_Pytorch/releases/download/v1.0.0-alpha/simv1_5m_m2-324ba7cc.pth"
8178
),
82-
"simplenetv1_m1_9m": _cfg(
79+
"simplenetv1_9m_m1": _cfg(
8380
url="https://github.com/Coderx7/SimpleNet_Pytorch/releases/download/v1.0.0-alpha/simv1_9m_m1-00000000.pth"
8481
),
85-
"simplenetv1_m2_9m": _cfg(
82+
"simplenetv1_9m_m2": _cfg(
8683
url="https://github.com/Coderx7/SimpleNet_Pytorch/releases/download/v1.0.0-alpha/simv1_9m_m2-00000000.pth"
8784
),
8885
}
@@ -106,15 +103,21 @@ def __init__(
106103
):
107104
"""Instantiates a SimpleNet model. SimpleNet is comprised of the most basic building blocks of a CNN architecture.
108105
It uses basic principles to maximize the network performance both in terms of feature representation and speed without
109-
resorting to complex design or operators.
106+
resorting to complex design or operators.
110107
111108
Args:
112109
num_classes (int, optional): number of classes. Defaults to 1000.
113110
in_chans (int, optional): number of input channels. Defaults to 3.
114111
scale (float, optional): scale of the architecture width. Defaults to 1.0.
115112
network_idx (int, optional): the network index indicating the 5 million or 8 million version(0 and 1 respectively). Defaults to 0.
116-
mode (int, optional): stride mode of the architecture. specifies how fast the input shrinks.
117-
you can choose between 0 and 4. (note for imagenet use 1-4). Defaults to 2.
113+
mode (int, optional): stride mode of the architecture. specifies how fast the input shrinks.
114+
This is used for larger input sizes such as the 224x224 in imagenet training where the
115+
input size incurs a lot of overhead if not downsampled properly.
116+
you can choose between 0 meaning no change and 4. where each number denotes a specific
117+
downsampling strategy. For imagenet use 1-4.
118+
the larger the stride mode, the higher accuracy and the slower
119+
the network gets. stride mode 1 is the fastest and achives very good accuracy.
120+
Defaults to 2.
118121
drop_rates (Dict[int,float], optional): custom drop out rates specified per layer.
119122
each rate should be paired with the corrosponding layer index(pooling and cnn layers are counted only). Defaults to {}.
120123
"""
@@ -333,22 +336,23 @@ def set_grad_checkpointing(self, enable=True):
333336
def get_classifier(self):
334337
return self.classifier
335338

336-
def reset_classifier(self, num_classes, network_idx=0, scale=1.0):
339+
def reset_classifier(self, num_classes: int):
337340
self.num_classes = num_classes
338-
self.classifier = nn.Linear(round(self.cfg[self.networks[network_idx]][-1][1] * scale), num_classes)
341+
self.classifier = nn.Linear(round(self.cfg[self.networks[self.network_idx]][-1][1] * self.scale), num_classes)
339342

340343
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
341-
x = self.features(x)
342-
x = F.max_pool2d(x, kernel_size=x.size()[2:])
343-
x = F.dropout2d(x, self.last_dropout_rate, training=self.training)
344-
x = x.view(x.size(0), -1)
345-
return x
344+
return self.features(x)
346345

347346
def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
348-
x = self.forward_features(x)
349-
return x if pre_logits else self.classifier(x)
347+
x_ = self.forward_features(x)
348+
if pre_logits:
349+
return x
350+
else:
351+
x = F.max_pool2d(x, kernel_size=x.size()[2:])
352+
x = F.dropout2d(x, self.last_dropout_rate, training=self.training)
353+
x = x.view(x.size(0), -1)
354+
return self.classifier(x)
350355

351-
!Test this after this change, and update the pure pytorch version and cifar10 versions as well- test classification test extensively again
352356
def _gen_simplenet(
353357
model_variant: str = "simplenetv1_m2",
354358
num_classes: int = 1000,
@@ -371,26 +375,28 @@ def _gen_simplenet(
371375
**kwargs,
372376
)
373377
model = build_model_with_cfg(SimpleNet, model_variant, pretrained, **model_args)
374-
375-
# model = SimpleNet(num_classes, in_chans, scale=scale, network_idx=network_idx, mode=mode, drop_rates=drop_rates)
376-
# if pretrained:
377-
# cfg = default_cfgs.get(model_variant, None)
378-
# if cfg is None:
379-
# raise Exception(f"Unknown model variant ('{model_variant}') specified!")
380-
# url = cfg["url"]
381-
# checkpoint_filename = url.split("/")[-1]
382-
# checkpoint_path = f"tmp/{checkpoint_filename}"
383-
# print(f"saving in checkpoint_path:{checkpoint_path}")
384-
# if not os.path.exists(checkpoint_path):
385-
# os.makedirs("tmp")
386-
# download_url_to_file(url, checkpoint_path)
387-
# checkpoint = torch.load(checkpoint_path, map_location="cpu",)
388-
# model.load_state_dict(checkpoint)
389378
return model
390379

391380

392381
@register_model
393382
def simplenet(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
383+
"""Generic simplenet model builder. by default it returns `simplenetv1_5m_m2` model
384+
but specifying different arguments such as `netidx`, `scale` or `mode` will result in
385+
the corrosponding network variant.
386+
387+
when pretrained is specified, if the combination of settings resemble any known variants
388+
specified in the `default_cfg`, their respective pretrained weights will be loaded, otherwise
389+
an exception will be thrown denoting Unknown model variant being specified.
390+
391+
Args:
392+
pretrained (bool, optional): loads the model with pretrained weights only if the model is a known variant specified in default_cfg. Defaults to False.
393+
394+
Raises:
395+
Exception: if pretrained is used with an unknown/custom model variant and exception is raised.
396+
397+
Returns:
398+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
399+
"""
394400
num_classes = kwargs.get("num_classes", 1000)
395401
in_chans = kwargs.get("in_chans", 3)
396402
scale = kwargs.get("scale", 1.0)
@@ -414,11 +420,7 @@ def simplenet(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
414420
config = f"small_m{mode}_05"
415421
else:
416422
config = f"m{mode}_{scale:.2f}".replace(".", "")
417-
418-
if network_idx == 0:
419-
model_variant = f"simplenetv1_{config}"
420-
else:
421-
model_variant = f"simplenetv1_{config}"
423+
model_variant = f"simplenetv1_{config}"
422424

423425
cfg = default_cfgs.get(model_variant, None)
424426
if cfg is None:
@@ -477,55 +479,127 @@ def remove_network_settings(kwargs: Dict[str,Any]) -> Dict[str,Any]:
477479
# imagenet models
478480
@register_model
479481
def simplenetv1_small_m1_05(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
482+
"""Creates a small variant of simplenetv1_5m, with 1.5m parameters. This uses m1 stride mode
483+
which makes it the fastest variant available.
484+
485+
Args:
486+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
487+
488+
Returns:
489+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
490+
"""
480491
model_variant = "simplenetv1_small_m1_05"
481492
model_args = remove_network_settings(kwargs)
482493
return _gen_simplenet(model_variant, scale=0.5, network_idx=0, mode=1, pretrained=pretrained, **model_args)
483494

484495

485496
@register_model
486497
def simplenetv1_small_m2_05(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
498+
"""Creates a second small variant of simplenetv1_5m, with 1.5m parameters. This uses m2 stride mode
499+
which makes it the second fastest variant available.
500+
501+
Args:
502+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
503+
504+
Returns:
505+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
506+
"""
487507
model_variant = "simplenetv1_small_m2_05"
488508
model_args = remove_network_settings(kwargs)
489509
return _gen_simplenet(model_variant, scale=0.5, network_idx=0, mode=2, pretrained=pretrained, **model_args)
490510

491511

492512
@register_model
493513
def simplenetv1_small_m1_075(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
514+
"""Creates a third small variant of simplenetv1_5m, with 3m parameters. This uses m1 stride mode
515+
which makes it the third fastest variant available.
516+
517+
Args:
518+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
519+
520+
Returns:
521+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
522+
"""
494523
model_variant = "simplenetv1_small_m1_075"
495524
model_args = remove_network_settings(kwargs)
496525
return _gen_simplenet(model_variant, scale=0.75, network_idx=0, mode=1, pretrained=pretrained, **model_args)
497526

498527

499528
@register_model
500529
def simplenetv1_small_m2_075(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
530+
"""Creates a forth small variant of simplenetv1_5m, with 3m parameters. This uses m2 stride mode
531+
which makes it the forth fastest variant available.
532+
533+
Args:
534+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
535+
536+
Returns:
537+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
538+
"""
501539
model_variant = "simplenetv1_small_m2_075"
502540
model_args = remove_network_settings(kwargs)
503541
return _gen_simplenet(model_variant, scale=0.75, network_idx=0, mode=2, pretrained=pretrained, **model_args)
504542

505543

506544
@register_model
507545
def simplenetv1_5m_m1(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
546+
"""Creates the base simplement model known as simplenetv1_5m, with 5m parameters. This variant uses m1 stride mode
547+
which makes it a fast and performant model.
548+
549+
Args:
550+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
551+
552+
Returns:
553+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
554+
"""
508555
model_variant = "simplenetv1_5m_m1"
509556
model_args = remove_network_settings(kwargs)
510557
return _gen_simplenet(model_variant, scale=1.0, network_idx=0, mode=1, pretrained=pretrained, **model_args)
511558

512559

513560
@register_model
514561
def simplenetv1_5m_m2(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
562+
"""Creates the base simplement model known as simplenetv1_5m, with 5m parameters. This variant uses m2 stride mode
563+
which makes it a bit more performant model compared to the m1 variant of the same variant at the expense of a bit slower inference.
564+
565+
Args:
566+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
567+
568+
Returns:
569+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
570+
"""
515571
model_variant = "simplenetv1_5m_m2"
516572
model_args = remove_network_settings(kwargs)
517573
return _gen_simplenet(model_variant, scale=1.0, network_idx=0, mode=2, pretrained=pretrained, **model_args)
518574

519575

520576
@register_model
521577
def simplenetv1_9m_m1(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
578+
"""Creates a variant of the simplenetv1_5m, with 9m parameters. This variant uses m1 stride mode
579+
which makes it run faster.
580+
581+
Args:
582+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
583+
584+
Returns:
585+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
586+
"""
522587
model_variant = "simplenetv1_9m_m1"
523588
model_args = remove_network_settings(kwargs)
524589
return _gen_simplenet(model_variant, scale=1.0, network_idx=1, mode=1, pretrained=pretrained, **model_args)
525590

526591

527592
@register_model
528593
def simplenetv1_9m_m2(pretrained: bool = False, **kwargs: Any) -> SimpleNet:
594+
"""Creates a variant of the simplenetv1_5m, with 9m parameters. This variant uses m2 stride mode
595+
which makes it a bit more performant model compared to the m1 variant of the same variant at the expense of a bit slower inference.
596+
597+
Args:
598+
pretrained (bool, optional): loads the model with pretrained weights. Defaults to False.
599+
600+
Returns:
601+
SimpleNet: a SimpleNet model instance is returned upon successful instantiation.
602+
"""
529603
model_variant = "simplenetv1_9m_m2"
530604
model_args = remove_network_settings(kwargs)
531605
return _gen_simplenet(model_variant, scale=1.0, network_idx=1, mode=2, pretrained=pretrained, **model_args)

0 commit comments

Comments
 (0)