|
3 | 3 | #-------------------------------------#
|
4 | 4 | import datetime
|
5 | 5 | import os
|
| 6 | +from functools import partial |
6 | 7 |
|
7 | 8 | import numpy as np
|
8 | 9 | import torch
|
|
18 | 19 | weights_init)
|
19 | 20 | from utils.callbacks import EvalCallback, LossHistory
|
20 | 21 | from utils.dataloader import YoloDataset, yolo_dataset_collate
|
21 |
| -from utils.utils import get_anchors, get_classes, show_config |
| 22 | +from utils.utils import (get_anchors, get_classes, seed_everything, |
| 23 | + show_config, worker_init_fn) |
22 | 24 | from utils.utils_fit import fit_one_epoch
|
23 | 25 |
|
24 | 26 | '''
|
|
43 | 45 | # 没有GPU可以设置成False
|
44 | 46 | #---------------------------------#
|
45 | 47 | Cuda = True
|
| 48 | + #----------------------------------------------# |
| 49 | + # Seed 用于固定随机种子 |
| 50 | + # 使得每次独立训练都可以获得一样的结果 |
| 51 | + #----------------------------------------------# |
| 52 | + seed = 11 |
46 | 53 | #---------------------------------------------------------------------#
|
47 | 54 | # distributed 用于指定是否使用单机多卡分布式运行
|
48 | 55 | # 终端指令仅支持Ubuntu。CUDA_VISIBLE_DEVICES用于在Ubuntu下指定显卡。
|
|
265 | 272 | train_annotation_path = '2007_train.txt'
|
266 | 273 | val_annotation_path = '2007_val.txt'
|
267 | 274 |
|
| 275 | + seed_everything(seed) |
268 | 276 | #------------------------------------------------------#
|
269 | 277 | # 设置用到的显卡
|
270 | 278 | #------------------------------------------------------#
|
|
280 | 288 | else:
|
281 | 289 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
282 | 290 | local_rank = 0
|
| 291 | + rank = 0 |
283 | 292 |
|
284 | 293 | #----------------------------------------------------#
|
285 | 294 | # 获取classes和anchor
|
|
482 | 491 | shuffle = True
|
483 | 492 |
|
484 | 493 | gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
|
485 |
| - drop_last=True, collate_fn=yolo_dataset_collate, sampler=train_sampler) |
| 494 | + drop_last=True, collate_fn=yolo_dataset_collate, sampler=train_sampler, |
| 495 | + worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed)) |
486 | 496 | gen_val = DataLoader(val_dataset , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
|
487 |
| - drop_last=True, collate_fn=yolo_dataset_collate, sampler=val_sampler) |
| 497 | + drop_last=True, collate_fn=yolo_dataset_collate, sampler=val_sampler, |
| 498 | + worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed)) |
488 | 499 |
|
489 | 500 | #----------------------#
|
490 | 501 | # 记录eval的map曲线
|
|
533 | 544 | batch_size = batch_size // ngpus_per_node
|
534 | 545 |
|
535 | 546 | gen = DataLoader(train_dataset, shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
|
536 |
| - drop_last=True, collate_fn=yolo_dataset_collate, sampler=train_sampler) |
| 547 | + drop_last=True, collate_fn=yolo_dataset_collate, sampler=train_sampler, |
| 548 | + worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed)) |
537 | 549 | gen_val = DataLoader(val_dataset , shuffle = shuffle, batch_size = batch_size, num_workers = num_workers, pin_memory=True,
|
538 |
| - drop_last=True, collate_fn=yolo_dataset_collate, sampler=val_sampler) |
| 550 | + drop_last=True, collate_fn=yolo_dataset_collate, sampler=val_sampler, |
| 551 | + worker_init_fn=partial(worker_init_fn, rank=rank, seed=seed)) |
539 | 552 |
|
540 | 553 | UnFreeze_flag = True
|
541 | 554 |
|
|
0 commit comments