Skip to content

Commit 2247993

Browse files
committed
增加shufflenet和mobilenet
1 parent d0b21a5 commit 2247993

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

main.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@ def get_net(net_name, weight_path=None):
4040
net = models.densenet121(pretrained=pretrain)
4141
elif net_name in ['inception']:
4242
net = models.inception_v3(pretrained=pretrain)
43+
elif net_name in ['mobilenet_v2']:
44+
net = models.mobilenet_v2(pretrained=pretrain)
45+
elif net_name in ['shufflenet_v2']:
46+
net = models.shufflenet_v2_x1_0(pretrained=pretrain)
4347
else:
4448
raise ValueError('invalid network name:{}'.format(net_name))
4549
# 加载指定路径的权重参数

0 commit comments

Comments
 (0)