trainval_toolkit.py
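"""Entry point that trains the experiments defined in exp_configs via Haven's wizard."""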
import sys
import os
import logging
import argparse
import exp_configs
from src import datasets_loader, hf_trainer
from src.training_args import parse_args
from src.constants import RESULTS_FNAME, GFG_DATA_PATH, MAX_VALID_DATA_ROW_COUNT
from haven import haven_wizard as hw
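
# Log to stdout so records appear in the launcher's captured output.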
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
handlers=[logging.StreamHandler(sys.stdout)],
level=logging.INFO,
)
logger = logging.getLogger(__name__)


def train(exp_dict, savedir, args):
    """Run a single training experiment.

    exp_dict: dictionary defining the hyperparameters of the experiment
    savedir: directory where checkpoints and results are saved
    args: arguments parsed from the command line
    """
    # Build the training and validation datasets.
train_data = datasets_loader.get_dataset(
dataset_name=args.train_data_name,
path_to_cache=args.data_path,
split="train",
maximum_raw_length=exp_dict["maximum_raw_length"],
)
    gfg_test_data = datasets_loader.get_dataset(  # Geeks4Geeks test split, used for validation
dataset_name="gfg",
path_to_cache=GFG_DATA_PATH,
split="test",
maximum_raw_length=exp_dict["maximum_raw_length"],
        maximum_row_count=MAX_VALID_DATA_ROW_COUNT,
)
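
    # The collator tokenizes and batches examples, applies MLM masking, and
    # skips the contrastive-loss preparation when alpha == 1.0.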
collate_fn = datasets_loader.Collator(
tokenizer_path=exp_dict["tokenizer_path"],
maximum_length=exp_dict["maximum_input_length"],
mlm_masking_probability=exp_dict["mlm_masking_probability"],
contrastive_masking_probability=exp_dict["contrastive_masking_probability"],
ignore_contrastive_loss_data=exp_dict["alpha"] == 1.0,
)
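
    # Propagate tokenizer-derived values so the model is built with a
    # matching vocabulary size and padding token id.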
exp_dict["vocab_size"] = collate_fn.vocabulary_size
exp_dict["pad_token_id"] = collate_fn.pad_token_id
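
    # Build the trainer, wiring in Weights & Biases logging and optional DeepSpeed.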
trainer = hf_trainer.get_trainer(
exp_dict=exp_dict,
savedir=savedir,
max_steps=args.steps,
train_dataset=train_data,
valid_dataset=gfg_test_data,
collate_fn=collate_fn,
log_every=args.log_every,
wandb_entity_name=args.wandb_entity_name,
wandb_project_name=args.wandb_project_name,
wandb_run_name=args.wandb_run_name,
wandb_log_grads=args.wandb_log_gradients,
local_rank=args.local_rank,
deepspeed_cfg_path=args.deepspeed,
)
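
    # Resume automatically if a previous run left a checkpoint-* directory in savedir.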
    trainer.train(
        resume_from_checkpoint=any(
            entry.startswith("checkpoint") for entry in os.listdir(savedir)
        )
    )
    logger.info("Experiment done\n")


if __name__ == "__main__":
args, others = parse_args()
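    # Distributed launchers (e.g. torchrun or deepspeed) export LOCAL_RANK;
    # fall back to rank 0 for single-process runs.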
try:
args.local_rank = int(os.environ["LOCAL_RANK"])
except KeyError:
args.local_rank = 0

    # Choose the job scheduler; "toolkit" submits jobs using the per-group config.
job_config = None
if args.job_scheduler == "toolkit":
import job_configs
job_config = job_configs.JOB_CONFIG[args.exp_group]

    # Run the experiments and write the results file.
hw.run_wizard(
func=train,
exp_list=exp_configs.EXP_GROUPS[args.exp_group],
savedir_base=args.savedir_base,
reset=args.reset,
job_config=job_config,
results_fname=RESULTS_FNAME,
python_binary_path=args.python_binary,
args=args,
)
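
# Example invocation (illustrative only: the exact flag spellings are defined
# in src/training_args.py, and the group name here is hypothetical):
#   python trainval_toolkit.py --exp_group my_group --savedir_base ./results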