-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathtrain.py
54 lines (52 loc) · 2.45 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import argparse
import os
import random
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from networks.vit_seg_modeling import VisionTransformer as ViT_seg
from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg
from trainer import trainer_dataset
import os
from networks.vit_seg_modeling_L2HNet import L2HNet
# Command-line interface for training the ViT segmentation model with an
# L2HNet backbone. Parsed at import time; `args` is read by the __main__ block.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='Chesapeake', help='experiment_name')
parser.add_argument('--max_epochs', type=int, default=100, help='maximum epoch number to train')
parser.add_argument('--batch_size', type=int, default=16, help='batch_size per gpu')
parser.add_argument('--base_lr', type=float, default=0.01, help='segmentation network learning rate')
parser.add_argument('--seed', type=int, default=1234, help='random seed')
parser.add_argument('--CNN_width', type=int, default=64, help='L2HNet_width_size, default is 64: light mode. Set to 128: normal mode')
# Required: the value is used unconditionally as the snapshot directory, and
# without it the script would fail later with an opaque TypeError.
parser.add_argument('--savepath', type=str, required=True,
                    help='output (snapshot) directory passed to the trainer; created if missing')
# Optional: when omitted, any CUDA_VISIBLE_DEVICES already in the environment is kept.
parser.add_argument('--gpu', type=str, help='Select GPU number to train')
args = parser.parse_args()
# Must be set before the first CUDA call (the model is moved to the GPU later).
# Assigning None to os.environ raises TypeError, so only override when --gpu
# was actually supplied.
if args.gpu is not None:
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
if __name__ == "__main__":
    # Entry point: seed the RNGs, resolve per-dataset settings, build the
    # ViT-B_16 segmentation model with an L2HNet backbone and start training.
    vit_patches_size = 16
    img_size = 224

    # Favour speed over bit-exact reproducibility: cuDNN autotuning may pick
    # non-deterministic kernels even though all RNGs below are seeded.
    cudnn.benchmark = True
    cudnn.deterministic = False
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    dataset_name = args.dataset
    # Create a config for your own dataset here: 'list_dir' is the path of the
    # *.csv file, 'num_classes' the number of segmentation classes.
    dataset_config = {
        'Chesapeake': {  # default dataset as an example
            # NOTE(review): machine-specific absolute path; adjust for your setup.
            'list_dir': '/home/ashelee/project_TransUNet/list/NY_raw.csv',
            'num_classes': 17,
        },
    }
    if dataset_name not in dataset_config:
        # Fail with a usage message instead of a bare KeyError below.
        parser.error(
            f"unknown --dataset {dataset_name!r}; add it to dataset_config "
            f"(known: {', '.join(sorted(dataset_config))})"
        )

    # Linear learning-rate scaling relative to a reference batch size of 24.
    # It only applies when batch_size is a multiple of 6 other than 24, so the
    # default batch_size of 16 leaves base_lr unchanged.
    if args.batch_size != 24 and args.batch_size % 6 == 0:
        args.base_lr *= args.batch_size / 24

    args.num_classes = dataset_config[dataset_name]['num_classes']
    args.list_dir = dataset_config[dataset_name]['list_dir']
    args.is_pretrain = True  # presumably consumed by trainer_dataset — TODO confirm

    snapshot_path = args.savepath
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(snapshot_path, exist_ok=True)

    config_vit = CONFIGS_ViT_seg["ViT-B_16"]
    config_vit.n_classes = args.num_classes
    # Number of ViT patches along each spatial axis (224 // 16 = 14).
    grid_size = img_size // vit_patches_size
    config_vit.patches.grid = (grid_size, grid_size)
    net = ViT_seg(
        config_vit,
        backbone=L2HNet(width=args.CNN_width),
        img_size=img_size,
        num_classes=config_vit.n_classes,
    ).cuda()
    # Initialise from the pretrained weights file referenced by the config.
    net.load_from(weights=np.load(config_vit.pretrained_path))
    trainer_dataset(args, net, snapshot_path)