
Commit 03c9e74

clean up multi gpu logic

1 parent aded2b0

File tree

2 files changed (+9, -6 lines)

model_pytorch.py

Lines changed: 2 additions & 1 deletion
@@ -177,7 +177,8 @@ class LMHead(nn.Module):
     def __init__(self, model, cfg):
         super(LMHead, self).__init__()
         self.n_embd = cfg.n_embd
-        self.decoder = lambda x: F.linear(x, model.embed.weight)  # Tied weights
+        self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
+        self.decoder.weight = model.embed.weight  # Tied weights
 
     def forward(self, h):
         # Truncated Language modeling logits (we remove the last token)
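
The hunk above swaps the lambda-based decoder for a regular nn.Linear whose weight parameter is assigned from the embedding matrix, so the input embedding and output projection stay tied. A minimal standalone sketch of that pattern, assuming embed_shape comes from model.embed.weight.shape (that assignment is not visible in this hunk) and using made-up sizes:

import torch
import torch.nn as nn

# Stand-in for model.embed; sizes are illustrative, not from the commit.
vocab_size, n_embd = 40478, 768
embed = nn.Embedding(vocab_size, n_embd)

# Presumably embed_shape = model.embed.weight.shape in the surrounding code.
embed_shape = embed.weight.shape
decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
decoder.weight = embed.weight  # tied weights: decoder reuses the embedding matrix

h = torch.randn(2, 5, n_embd)  # dummy hidden states [batch, seq, n_embd]
lm_logits = decoder(h)         # -> [batch, seq, vocab_size]
assert lm_logits.shape == (2, 5, vocab_size)

Unlike the old lambda, an nn.Linear is a proper submodule whose tied weight shows up in the module tree, which tends to play better with .to(device) and multi-GPU replication, plausibly why it was touched in this multi-GPU cleanup.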

train.py

Lines changed: 7 additions & 5 deletions
@@ -151,7 +151,7 @@ def run_epoch():
         compute_loss_fct(XMB, YMB, MMB, clf_logits, lm_logits)
         n_updates += 1
         if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0:
-            log()
+            log(save_dir, desc)
 
 
 argmax = lambda x: np.argmax(x, 1)
@@ -194,7 +194,6 @@ def run_epoch():
     parser.add_argument('--clf_pdrop', type=float, default=0.1)
     parser.add_argument('--l2', type=float, default=0.01)
     parser.add_argument('--vector_l2', action='store_true')
-    parser.add_argument('--n_gpu', type=int, default=1)
     parser.add_argument('--opt', type=str, default='adam')
     parser.add_argument('--afn', type=str, default='gelu')
     parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
@@ -223,9 +222,11 @@ def run_epoch():
     desc = args.desc
     data_dir = args.data_dir
     log_dir = args.log_dir
+    submission_dir = args.submission_dir
 
-    # torch.device object used throughout this script TODO add gpu setting
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    n_gpu = torch.cuda.device_count()
+    print("device", device, "n_gpu", n_gpu)
 
     logger = ResultLogger(path=os.path.join(log_dir, '{}.jsonl'.format(desc)), **args.__dict__)
     text_encoder = TextEncoder(args.encoder_path, args.bpe_path)
@@ -259,10 +260,11 @@ def run_epoch():
 
     n_train = len(trY)
     n_valid = len(vaY)
-    n_batch_train = args.n_batch * args.n_gpu
+    n_batch_train = args.n_batch * n_gpu
     n_updates_total = (n_train // n_batch_train) * args.n_iter
 
-    dh_model = DoubleHeadModel(args, clf_token, vocab, n_ctx)
+    if n_gpu > 1:
+        dh_model = DoubleHeadModel(args, clf_token, vocab, n_ctx)
 
     criterion = nn.CrossEntropyLoss(reduce=False)
     model_opt = OpenAIAdam(dh_model.parameters(),
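
Taken together, the train.py changes drop the manual --n_gpu flag, detect the available GPUs at runtime with torch.cuda.device_count(), and scale the effective training batch size by that count. A minimal sketch of that setup, with a stand-in model instead of DoubleHeadModel and assuming the multi-GPU branch ends up wrapping the model in torch.nn.DataParallel (anything beyond the n_gpu > 1 guard shown in the diff is an assumption here):

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
print("device", device, "n_gpu", n_gpu)

n_batch = 8                              # per-GPU batch size (stand-in for args.n_batch)
n_batch_train = n_batch * max(n_gpu, 1)  # the commit uses args.n_batch * n_gpu; guarded here so the sketch also runs on CPU

model = nn.Linear(768, 2)                # stand-in for DoubleHeadModel(args, clf_token, vocab, n_ctx)
model.to(device)
if n_gpu > 1:
    # nn.DataParallel splits each batch of n_batch_train examples across the visible GPUs
    model = nn.DataParallel(model)

With this arrangement each GPU still processes roughly args.n_batch examples per step, and single-GPU or CPU runs need no extra command-line flag.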
