
Commit 0704c84

added openAIAdam optimizer
1 parent 0b73057 commit 0704c84

4 files changed: +112 -29 lines changed

model_py.py

-1

@@ -6,7 +6,6 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from torch.nn.parameter import Parameter
 
 def gelu(x):

opt.py

+90 -9

@@ -1,23 +1,104 @@
 import math
-import numpy as np
+import torch
+from torch.optim import Optimizer
+from torch.nn.utils import clip_grad_norm
 
 def warmup_cosine(x, warmup=0.002):
-    pass
+    s = 1 if x <= warmup else 0
+    return s*(x/warmup) + (1-s)*(0.5 * (1 + torch.cos(math.pi * x)))
 
 def warmup_constant(x, warmup=0.002):
-    pass
+    s = 1 if x <= warmup else 0
+    return s*(x/warmup) + (1-s)*1
 
 def warmup_linear(x, warmup=0.002):
-    pass
+    s = 1 if x <= warmup else 0
+    return (s*(x/warmup) + (1-s))*(1-x)
 
-schedules = {
+SCHEDULES = {
     'warmup_cosine':warmup_cosine,
     'warmup_constant':warmup_constant,
     'warmup_linear':warmup_linear,
 }
 
-def adam(params, grads, lr, schedule, t_total, b1=0.9, b2=0.999, e=1e-8, l2=0, vector_l2=False, max_grad_norm=-1, **kwargs):
-    """
-    adam with weight decay fix
+
+class OpenAIAdam(Optimizer):
+    """Implements Open AI version of Adam algorithm with weight decay fix.
     """
-    pass
+    def __init__(self, params, lr, schedule, warmup, t_total,
+                 b1=0.9, b2=0.999, e=1e-8, l2=0,
+                 vector_l2=False, max_grad_norm=-1, **kwargs):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if schedule not in SCHEDULES:
+            raise ValueError("Invalid schedule parameter: {}".format(schedule))
+        if not 0 <= warmup:
+            raise ValueError("Invalid warmup: {}".format(warmup))
+        if not 0.0 <= b1 < 1.0:
+            raise ValueError("Invalid b1 parameter: {}".format(b1))
+        if not 0.0 <= b2 < 1.0:
+            raise ValueError("Invalid b2 parameter: {}".format(b2))
+        if not 0.0 <= e:
+            raise ValueError("Invalid epsilon value: {}".format(e))
+        defaults = dict(lr=lr, schedule=schedule, warmup=warmup, t_total=t_total,
+                        b1=b1, b2=b2, e=e, l2=l2, vector_l2=vector_l2,
+                        max_grad_norm=max_grad_norm)
+        super(OpenAIAdam, self).__init__(params, defaults)
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['b1'], group['b2']
+
+                state['step'] += 1
+
+                # Add grad clipping
+                if group['max_grad_norm'] > 0:
+                    clip_grad_norm(p, group['max_grad_norm'])
+
+                # Decay the first and second moment running average coefficient
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                denom = exp_avg_sq.sqrt().add_(group['e'])
+
+                bias_correction1 = 1 - beta1 ** state['step']
+                bias_correction2 = 1 - beta2 ** state['step']
+
+                schedule_fct = SCHEDULES[group['schedule']]
+                lr_scheduled = group['lr'] * schedule_fct(state['step']/group['t_total'], group['warmup'])
+                step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
+
+                p.data.addcdiv_(-step_size, exp_avg, denom)
+
+                # Add weight decay at the end (fixed version)
+                if (len(p.size()) > 1 or group['vector_l2']) and group['l2'] > 0:
+                    p.data.add_(-lr_scheduled * group['l2'], p.data)
+
+        return loss

text_utils.py

+3 -3

@@ -27,9 +27,9 @@ def text_standardize(text):
     text = text.replace('―', '-')
     text = text.replace('…', '...')
     text = text.replace('´', "'")
-    text = re.sub('''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
-    text = re.sub('\s*\n\s*', ' \n ', text)
-    text = re.sub('[^\S\n]+', ' ', text)
+    text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
+    text = re.sub(r'\s*\n\s*', ' \n ', text)
+    text = re.sub(r'[^\S\n]+', ' ', text)
     return text.strip()
 
 class TextEncoder(object):

train.py

+19 -16

@@ -18,29 +18,20 @@
 from sklearn.metrics import accuracy_score
 
 from model_py import Model, LMHead, ClfHead, load_openai_pretrained_model
-from opt import adam, warmup_cosine, warmup_linear, warmup_constant
+from opt import OpenAIAdam
 from datasets import rocstories
 from analysis import rocstories as rocstories_analysis
 from text_utils import TextEncoder
 from utils import (encode_dataset, flatten, iter_data,
                    ResultLogger, make_path)
 
-OPT_FNS = {
-    'adam':adam,
-}
-
-LR_SCHEDULES = {
-    'warmup_cosine':warmup_cosine,
-    'warmup_linear':warmup_linear,
-    'warmup_constant':warmup_constant,
-}
-
 class LossCompute:
     "A Loss compute and train function."
-    def __init__(self, lm_criterion, clf_criterion, lm_coef):
+    def __init__(self, lm_criterion, clf_criterion, lm_coef, opt=None):
         self.lm_criterion = lm_criterion
         self.clf_criterion = clf_criterion
         self.lm_coef = lm_coef
+        self.opt = opt
 
     def __call__(self, X, Y, M, lm_logits, clf_logits):
         # Language modeling loss
@@ -53,11 +44,18 @@ def __call__(self, X, Y, M, lm_logits, clf_logits):
 
         # Classification loss
         clf_losses = self.clf_criterion(clf_logits, Y)
+
         if self.lm_coef > 0:
             train_loss = clf_losses.sum() + self.lm_coef * lm_losses.sum()
         else:
             train_loss = clf_losses.sum()
-        return train_loss
+
+        train_loss.backward()
+        if self.opt is not None:
+            self.opt.step()
+            self.opt.zero_grad()
+        return train_loss.item()
+
 
 def transform_roc(X1, X2, X3):
     n_batch = len(X1)
@@ -229,7 +227,14 @@ def transform_roc(X1, X2, X3):
     model = Model(vocab, args)
     lm_head = LMHead(model, args)
     clf_head = ClfHead(clf_token, args)
-    compute_loss = LossCompute(nn.CrossEntropyLoss(reduce=False), nn.CrossEntropyLoss(reduce=False), lm_coef) # TODO check loss functions
+
+    criterion = nn.CrossEntropyLoss(reduce=False) # TODO check loss functions
+    model_opt = OpenAIAdam(model.parameters(), lr=lr, schedule=lr_schedule,
+                           warmup=lr_warmup, t_total=n_updates_total, b1=b1,
+                           b2=b2, e=e, l2=l2, vector_l2=vector_l2,
+                           max_grad_norm=max_grad_norm)
+
+    compute_loss = LossCompute(criterion, criterion, lm_coef, model_opt)
     # TODO Initialize model (?)
     # TODO add train() and eval()
     load_openai_pretrained_model(model, n_ctx, n_special, args)
@@ -258,8 +263,6 @@ def transform_roc(X1, X2, X3):
            lm_logits = lm_head(h)
            clf_logits = clf_head(h, XMB)
            loss = compute_loss(XMB, YMB, MMB, lm_logits, clf_logits)
-           loss.backward()
-
            n_updates += 1
            #if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0:
            #    log()
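Taken together, opt.py and train.py now follow a pattern where the loss helper owns the optimizer: it runs backward(), step() and zero_grad() itself and hands the caller a plain float, so the training loop no longer calls loss.backward(). A condensed sketch of that pattern under the same assumptions; SimpleLossCompute, the linear "model", and the random batches here are hypothetical stand-ins, not code from the repository:

import torch
import torch.nn as nn
from opt import OpenAIAdam

class SimpleLossCompute:
    # Minimal illustration of the LossCompute pattern adopted in this commit.
    def __init__(self, criterion, opt=None):
        self.criterion = criterion
        self.opt = opt

    def __call__(self, logits, targets):
        loss = self.criterion(logits, targets).sum()
        loss.backward()                 # backward pass happens inside the helper
        if self.opt is not None:
            self.opt.step()             # OpenAIAdam update with scheduled LR
            self.opt.zero_grad()
        return loss.item()              # caller gets a float; nothing left to do

model = nn.Linear(16, 4)
opt = OpenAIAdam(model.parameters(), lr=6.25e-5, schedule='warmup_linear',
                 warmup=0.002, t_total=100)
compute_loss = SimpleLossCompute(nn.CrossEntropyLoss(reduce=False), opt)

for _ in range(3):                      # toy loop: no backward()/step() at this level
    x, y = torch.randn(8, 16), torch.randint(0, 4, (8,))
    loss = compute_loss(model(x), y)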
