Skip to content

Commit 901f860

Browse files
authored
Fix race condiction (#12)
Fix #4
1 parent bfd0ffd commit 901f860

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

basic_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def load_model_emb(args, tokenizer):
7272
### random emb or pre-defined embedding like glove embedding. You can custome your own init here.
7373
model = torch.nn.Embedding(tokenizer.vocab_size, args.hidden_dim)
7474
path_save = '{}/random_emb.torch'.format(args.checkpoint_path)
75+
path_save_ind = path_save + ".done"
7576
if int(os.environ['LOCAL_RANK']) == 0:
7677
if os.path.exists(path_save):
7778
print('reload the random embeddings', model)
@@ -80,8 +81,11 @@ def load_model_emb(args, tokenizer):
8081
print('initializing the random embeddings', model)
8182
torch.nn.init.normal_(model.weight)
8283
torch.save(model.state_dict(), path_save)
84+
os.sync()
85+
with open(path_save_ind, "x") as _:
86+
pass
8387
else:
84-
while not os.path.exists(path_save):
88+
while not os.path.exists(path_save_ind):
8589
time.sleep(1)
8690
print('reload the random embeddings', model)
8791
model.load_state_dict(torch.load(path_save))

0 commit comments

Comments
 (0)