forked from illiterate/BertClassifier
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
113 lines (85 loc) · 3.65 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# coding: utf-8
# @File: train.py
# @Author: HE D.H.
# @Email: [email protected]
# @Time: 2020/10/10 17:14:07
# @Description:
import torch
import torch.nn as nn
from transformers import BertTokenizer, AdamW, BertConfig
from torch.utils.data import DataLoader
from model import BertClassifier
from dataset import CNewsDataset
from tqdm import tqdm
def _run_epoch(model, dataloader, criterion, device, description, optimizer=None):
    """Run one full pass over ``dataloader`` and return ``(mean_loss, mean_acc)``.

    Args:
        model: Callable taking ``input_ids``, ``attention_mask`` and
            ``token_type_ids`` keyword arguments and returning per-class scores.
        dataloader: Iterable yielding
            ``(input_ids, token_type_ids, attention_mask, label_id)`` batches.
        criterion: Loss function called as ``criterion(output, labels)``.
        device: Device the batch tensors are moved to before the forward pass.
        description: Text shown in front of the progress bar.
        optimizer: When given, the pass is a training pass: the model is put
            in training mode and the optimizer steps after every batch. When
            ``None``, the model is put in evaluation mode and autograd is
            disabled for the whole pass.

    Returns:
        Tuple ``(mean_loss, mean_acc)``. Both are the plain mean of the
        per-batch values, so every batch carries the same weight.
    """
    training = optimizer is not None
    if training:
        model.train()
    else:
        model.eval()

    total_loss = 0
    total_acc = 0
    progress = tqdm(dataloader)
    # The description is constant for the whole pass, so set it once.
    progress.set_description(description)

    # Evaluation does not need gradients; disabling autograd avoids building
    # a graph for every validation batch.
    with torch.set_grad_enabled(training):
        for input_ids, token_type_ids, attention_mask, label_id in progress:
            labels = label_id.to(device)  # moved once, reused for loss and accuracy
            if training:
                model.zero_grad()

            output = model(
                input_ids=input_ids.to(device),
                attention_mask=attention_mask.to(device),
                token_type_ids=token_type_ids.to(device),
            )
            loss = criterion(output, labels)
            batch_loss = loss.item()
            total_loss += batch_loss

            pred_labels = torch.argmax(output, dim=1)  # predicted class per sample
            batch_acc = torch.sum(pred_labels == labels).item() / len(pred_labels)
            total_acc += batch_acc

            if training:
                loss.backward()
                optimizer.step()
            progress.set_postfix(loss=batch_loss, acc=batch_acc)

    num_batches = len(dataloader)
    return total_loss / num_batches, total_acc / num_batches


def main():
    """Fine-tune ``BertClassifier`` on the cnews data and keep the best weights.

    Each epoch runs one training pass and one validation pass, printing the
    mean accuracy and loss of each. Whenever the validation accuracy is
    strictly higher than the best seen so far, the model's ``state_dict`` is
    written to ``models/best_model.pkl``.
    """
    import os  # local import: only needed to create the checkpoint directory

    # Hyper-parameters
    batch_size = 4
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    epochs = 10
    learning_rate = 5e-6  # the learning rate should not be too large
    checkpoint_path = 'models/best_model.pkl'

    # Create the output directory up front so a missing folder fails (or is
    # fixed) before any training time is spent.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    # Load the datasets
    # TODO: the test split (cnews.test.txt) is not loaded or evaluated here.
    train_dataset = CNewsDataset('data/cnews/cnews.train.txt')
    valid_dataset = CNewsDataset('data/cnews/cnews.val.txt')

    # Build the batch iterators
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

    # Read the BERT configuration
    bert_config = BertConfig.from_pretrained('bert-base-chinese')
    num_labels = len(train_dataset.labels)

    # Initialise the model, optimizer and loss
    model = BertClassifier(bert_config, num_labels).to(device)
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    best_acc = 0

    for epoch in range(1, epochs+1):
        average_loss, average_acc = _run_epoch(
            model, train_dataloader, criterion, device,
            'Epoch %i train' % epoch, optimizer=optimizer,
        )
        print('\tTrain ACC:', average_acc, '\tLoss:', average_loss)

        # Validation
        average_loss, average_acc = _run_epoch(
            model, valid_dataloader, criterion, device,
            'Epoch %i valid' % epoch,
        )
        print('\tValid ACC:', average_acc, '\tLoss:', average_loss)

        # Keep only the weights with the best validation accuracy so far.
        if average_acc > best_acc:
            best_acc = average_acc
            torch.save(model.state_dict(), checkpoint_path)
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()