# coding: utf-8
# @File: train.py
# @Author: HE D.H.
# @Email: [email protected]
# @Time: 2020/10/10 17:14:07
# @Description:
import os
import torch
import torch.nn as nn
from transformers import BertTokenizer, AdamW, BertConfig
from torch.utils.data import DataLoader
from model import BertClassifier
from dataset import CNewsDataset
from tqdm import tqdm
from sklearn import metrics
def main():
    # Hyper-parameters
    model_path = r'D:/Workspace/Python/pretrained-models/bert-base-chinese/'
    data_path = r'D:/Workspace/Python/cnews/'
    batch_size = 4
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    epochs = 10
    learning_rate = 5e-6    # keep the learning rate small when fine-tuning BERT

    tokenizer = BertTokenizer.from_pretrained(model_path)

    # Build the datasets
    train_dataset = CNewsDataset(data_path + 'cnews.train.txt', tokenizer)
    valid_dataset = CNewsDataset(data_path + 'cnews.val.txt', tokenizer)
    # test_dataset = CNewsDataset(data_path + 'cnews.test.txt', tokenizer)

    # Build the dataloaders
    # train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    # test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Load the BERT configuration
    bert_config = BertConfig.from_pretrained(model_path)
    num_labels = len(train_dataset.labels)

    # Initialize the model
    model = BertClassifier(bert_config, num_labels).to(device)

    # Optimizer
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    # Loss function
    criterion = nn.CrossEntropyLoss()

    best_f1 = 0

    for epoch in range(1, epochs + 1):
        losses = 0      # accumulated training loss
        accuracy = 0    # accumulated training accuracy

        # Training
        model.train()
        train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        train_bar = tqdm(train_dataloader, ncols=100)
        for input_ids, token_type_ids, attention_mask, label_id in train_bar:
            # Clear gradients
            model.zero_grad()
            train_bar.set_description('Epoch %i train' % epoch)

            # Forward pass (calls model.forward())
            output = model(
                input_ids=input_ids.to(device),
                attention_mask=attention_mask.to(device),
                token_type_ids=token_type_ids.to(device),
            )

            # Compute the loss
            loss = criterion(output, label_id.to(device))
            losses += loss.item()

            pred_labels = torch.argmax(output, dim=1)   # predicted labels
            acc = torch.sum(pred_labels == label_id.to(device)).item() / len(pred_labels)  # batch accuracy
            accuracy += acc

            loss.backward()
            optimizer.step()
            train_bar.set_postfix(loss=loss.item(), acc=acc)

        average_loss = losses / len(train_dataloader)
        average_acc = accuracy / len(train_dataloader)
        print('\tTrain ACC:', average_acc, '\tLoss:', average_loss)

        # Validation
        model.eval()
        losses = 0      # accumulated validation loss
        pred_labels = []
        true_labels = []
        valid_bar = tqdm(valid_dataloader, ncols=100)
        with torch.no_grad():   # gradients are not needed during validation
            for input_ids, token_type_ids, attention_mask, label_id in valid_bar:
                valid_bar.set_description('Epoch %i valid' % epoch)
                output = model(
                    input_ids=input_ids.to(device),
                    attention_mask=attention_mask.to(device),
                    token_type_ids=token_type_ids.to(device),
                )

                loss = criterion(output, label_id.to(device))
                losses += loss.item()

                pred_label = torch.argmax(output, dim=1)    # predicted labels
                acc = torch.sum(pred_label == label_id.to(device)).item() / len(pred_label)  # batch accuracy
                valid_bar.set_postfix(loss=loss.item(), acc=acc)

                pred_labels.extend(pred_label.cpu().numpy().tolist())
                true_labels.extend(label_id.numpy().tolist())

        average_loss = losses / len(valid_dataloader)
        print('\tValid Loss:', average_loss)

        # Classification report on the validation set
        report = metrics.classification_report(true_labels, pred_labels, labels=valid_dataset.labels_id, target_names=valid_dataset.labels)
        print('* Classification Report:')
        print(report)

        # Micro-averaged F1 is used to select the best model
        f1 = metrics.f1_score(true_labels, pred_labels, labels=valid_dataset.labels_id, average='micro')

        if not os.path.exists('models'):
            os.makedirs('models')

        # Save the checkpoint that performs best on the validation set
        if f1 > best_f1:
            best_f1 = f1
            torch.save(model.state_dict(), 'models/best_model.pkl')


if __name__ == '__main__':
    main()