Commit 1bbb2f2

committed
update seed
1 parent 691a074 commit 1bbb2f2

File tree

4 files changed: +57 -11 lines changed


nets/attention.py: +11 -3

@@ -97,16 +97,24 @@ def __init__(self, channel, reduction=16):
         self.sigmoid_w = nn.Sigmoid()
 
     def forward(self, x):
+        # batch_size, c, h, w
         _, _, h, w = x.size()
 
+        # batch_size, c, h, w => batch_size, c, h, 1 => batch_size, c, 1, h
         x_h = torch.mean(x, dim = 3, keepdim = True).permute(0, 1, 3, 2)
+        # batch_size, c, h, w => batch_size, c, 1, w
         x_w = torch.mean(x, dim = 2, keepdim = True)
-
+
+        # batch_size, c, 1, w cat batch_size, c, 1, h => batch_size, c, 1, w + h
+        # batch_size, c, 1, w + h => batch_size, c / r, 1, w + h
         x_cat_conv_relu = self.relu(self.bn(self.conv_1x1(torch.cat((x_h, x_w), 3))))
-
+
+        # batch_size, c / r, 1, w + h => batch_size, c / r, 1, h and batch_size, c / r, 1, w
         x_cat_conv_split_h, x_cat_conv_split_w = x_cat_conv_relu.split([h, w], 3)
-
+
+        # batch_size, c / r, 1, h => batch_size, c / r, h, 1 => batch_size, c, h, 1
         s_h = self.sigmoid_h(self.F_h(x_cat_conv_split_h.permute(0, 1, 3, 2)))
+        # batch_size, c / r, 1, w => batch_size, c, 1, w
         s_w = self.sigmoid_w(self.F_w(x_cat_conv_split_w))
 
         out = x * s_h.expand_as(x) * s_w.expand_as(x)
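
The shape annotations added above can be sanity-checked outside the module with a short standalone snippet. This is a minimal sketch using dummy values (batch_size=2, c=64, h=w=32 are illustrative assumptions, not from this commit):

import torch

# Dummy input matching the annotated shapes: batch_size=2, c=64, h=w=32 (illustrative values).
x = torch.randn(2, 64, 32, 32)
_, _, h, w = x.size()

x_h = torch.mean(x, dim=3, keepdim=True).permute(0, 1, 3, 2)   # (2, 64, 1, 32): pooled over w, h moved to the last axis
x_w = torch.mean(x, dim=2, keepdim=True)                       # (2, 64, 1, 32): pooled over h

x_cat = torch.cat((x_h, x_w), 3)                               # (2, 64, 1, 64): the 1 x (h + w) strip fed to conv_1x1
x_split_h, x_split_w = x_cat.split([h, w], 3)                  # (2, 64, 1, 32) each, recovered by split([h, w], 3)

print(x_h.shape, x_w.shape, x_cat.shape, x_split_h.shape, x_split_w.shape)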

requirements.txt: +4 -3

@@ -1,9 +1,10 @@
+torch
+torchvision
+tensorboard
 scipy==1.2.1
 numpy==1.17.0
 matplotlib==3.1.2
 opencv_python==4.1.2.30
-torch==1.2.0
-torchvision==0.4.0
 tqdm==4.60.0
 Pillow==8.2.0
-h5py==2.10.0
+h5py==2.10.0

train.py: +18 -5

@@ -3,6 +3,7 @@
 #-------------------------------------#
 import datetime
 import os
+from functools import partial
 
 import numpy as np
 import torch
@@ -18,7 +19,8 @@
                            weights_init)
 from utils.callbacks import EvalCallback, LossHistory
 from utils.dataloader import YoloDataset, yolo_dataset_collate
-from utils.utils import get_anchors, get_classes, show_config
+from utils.utils import (get_anchors, get_classes, seed_everything,
+                         show_config, worker_init_fn)
 from utils.utils_fit import fit_one_epoch
 
 '''
@@ -43,6 +45,11 @@
     #   set this to False if no GPU is available
     #---------------------------------#
     Cuda = True
+    #----------------------------------------------#
+    #   Seed    fixes the random seed so that every
+    #           independent training run gives the same result
+    #----------------------------------------------#
+    seed = 11
     #---------------------------------------------------------------------#
     #   distributed     specifies whether to use single-machine multi-GPU distributed training
     #                   the terminal launch command is only supported on Ubuntu; CUDA_VISIBLE_DEVICES selects the GPUs under Ubuntu
@@ -265,6 +272,7 @@
     train_annotation_path = '2007_train.txt'
     val_annotation_path = '2007_val.txt'
 
+    seed_everything(seed)
     #------------------------------------------------------#
     #   set the GPUs to use
     #------------------------------------------------------#
@@ -280,6 +288,7 @@
     else:
         device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         local_rank = 0
+        rank = 0
 
     #----------------------------------------------------#
     #   get classes and anchors
@@ -482,9 +491,11 @@
             shuffle = True
 
         gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
-                                    drop_last=True, collate_fn=yolo_dataset_collate, sampler=train_sampler)
+                                    drop_last=True, collate_fn=yolo_dataset_collate, sampler=train_sampler,
+                                    worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed))
         gen_val = DataLoader(val_dataset , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
-                                    drop_last=True, collate_fn=yolo_dataset_collate, sampler=val_sampler)
+                                    drop_last=True, collate_fn=yolo_dataset_collate, sampler=val_sampler,
+                                    worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed))
 
         #----------------------#
         #   record the eval mAP curve
@@ -533,9 +544,11 @@
                     batch_size = batch_size // ngpus_per_node
 
                 gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
-                                    drop_last=True, collate_fn=yolo_dataset_collate, sampler=train_sampler)
+                                    drop_last=True, collate_fn=yolo_dataset_collate, sampler=train_sampler,
+                                    worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed))
                 gen_val = DataLoader(val_dataset , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
-                                    drop_last=True, collate_fn=yolo_dataset_collate, sampler=val_sampler)
+                                    drop_last=True, collate_fn=yolo_dataset_collate, sampler=val_sampler,
+                                    worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed))
 
                 UnFreeze_flag = True
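
For reference, the partial(worker_init_fn, rank=rank, seed=seed) wiring works because DataLoader calls the function with only the worker id, while rank and seed are pre-bound. A minimal sketch with a toy dataset (ToyDataset and the printed message are illustrative assumptions, not code from this repository):

from functools import partial

import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    # Hypothetical stand-in for YoloDataset, only used to exercise the DataLoader wiring.
    def __len__(self):
        return 8

    def __getitem__(self, index):
        return torch.tensor(index)


def worker_init_fn(worker_id, rank, seed):
    # Same signature as the helper added in utils/utils.py: DataLoader supplies worker_id,
    # while rank and seed arrive pre-bound via functools.partial.
    print(f"worker {worker_id} seeded with rank + seed = {rank + seed}")


if __name__ == "__main__":
    rank, seed = 0, 11
    loader = DataLoader(ToyDataset(), batch_size=2, num_workers=2,
                        worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed))
    for _ in loader:
        pass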

utils/utils.py: +24

@@ -1,4 +1,7 @@
+import random
+
 import numpy as np
+import torch
 from PIL import Image
 
 #---------------------------------------------------------#
@@ -57,6 +60,27 @@ def get_lr(optimizer):
     for param_group in optimizer.param_groups:
         return param_group['lr']
 
+#---------------------------------------------------#
+#   set the random seed
+#---------------------------------------------------#
+def seed_everything(seed=11):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+#---------------------------------------------------#
+#   set the seed for DataLoader workers
+#---------------------------------------------------#
+def worker_init_fn(worker_id, rank, seed):
+    worker_seed = rank + seed
+    random.seed(worker_seed)
+    np.random.seed(worker_seed)
+    torch.manual_seed(worker_seed)
+
 def preprocess_input(image):
     image /= 255.0
     return image
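
A quick way to check the new helper: a minimal sketch, assuming it is run from the repository root so that utils.utils is importable; the values are illustrative:

import random

import numpy as np
import torch

from utils.utils import seed_everything

# Seed twice with the same value and confirm the first draw from each RNG repeats.
seed_everything(11)
first = (random.random(), np.random.rand(), torch.rand(1).item())

seed_everything(11)
second = (random.random(), np.random.rand(), torch.rand(1).item())

# Identical tuples => Python, NumPy and PyTorch RNGs are all reset the same way.
assert first == second
print(first)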
