-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathconfig.py
29 lines (23 loc) · 1.1 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
class Config(object):
"""
配置类
"""
# 设备 ####################################################################
# DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# DEVICE = torch.device('cpu')
DEVICE = torch.device('cuda:6')
# 超参数 ####################################################################
TRAIN_BATCH_SIZE = 4 # batch大小
LR = 0.003 # 学习率
LR_MIN = 1e-6 # 最小学习率
WEIGHT_DECAY = 0.0001
EPOCHS = 50 # 训练次数
# 数据集 ####################################################################
DATASETS_ROOT = '/root/private/torch_datasets' # Pytorch数据集根目录
# 数据处理 ####################################################################
IMAGE_BASE = '/root/data/LaneSeg/Image_Data' # image文件的根目录
LABEL_BASE = '/root/data/LaneSeg/Gray_Label' # label文件的根目录
TRAIN_RATE = 0.7 # 数据集划分,训练集占整个数据集的比例
VALID_RATE = 0.2 # 数据集划分,验证集占整个数据集的比例
pass