-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlaikaLLM.py
113 lines (79 loc) · 4.07 KB
/
laikaLLM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import argparse
import dataclasses
import os
import yaml
from pygit2 import Repository, GitError
from src.data.main import data_main
from src.evaluate.main import eval_main
from src.model.main import model_main
from src.utils import seed_everything, init_wandb, IndentedDumper
from src.yml_parse import parse_yml_config
def pretty_print_configuration(config: dict):
    """Print the full experiment configuration to stdout as YAML, one section at a time.

    The output has a starred banner followed by four sections. Each section is framed
    by 80-character rules and rendered with ``yaml.dump`` using ``IndentedDumper``.

    Args:
        config: Mapping that must contain the top-level keys ``"PYTHONHASHSEED"``,
            ``"CUBLAS_WORKSPACE_CONFIG"`` and ``"git_branch"``. It must also contain the
            section keys ``"general_params"`` (a mapping, merged after the three
            top-level values), ``"data_params"``, ``"model_params"`` and ``"eval_params"``.

    Raises:
        KeyError: If any of the keys listed above is missing. Section headers printed
            before the failing lookup have already been written.
    """
    rule = "-" * 80

    def print_heading(title: str, prefix: str = "") -> None:
        # A section title framed above and below by a horizontal rule.
        print(prefix + rule)
        print(title)
        print(rule)

    def print_yaml(payload) -> None:
        print(yaml.dump(payload, default_flow_style=False, Dumper=IndentedDumper))

    print(" Experiment configuration ".center(80, "*"))

    # The first section starts with an extra blank line to separate it from the banner.
    print_heading("Environment/General parameters:", prefix="\n")
    # Reproducibility-related values live at the top level of `config`. They are shown
    # together with the general params, and the general params win on duplicate keys.
    reproducibility_keys = ("PYTHONHASHSEED", "CUBLAS_WORKSPACE_CONFIG", "git_branch")
    top_level_values = {name: config[name] for name in reproducibility_keys}
    print_yaml({**top_level_values, **config["general_params"]})

    for title, section_key in (("Data parameters:", "data_params"),
                               ("Model parameters:", "model_params"),
                               ("Eval parameters:", "eval_params")):
        print_heading(title)
        print_yaml(config[section_key])
if __name__ == '__main__':
    # Command-line interface: a single option pointing at the YAML experiment configuration.
    # NOTE: `required=True` was removed. With it, argparse errors out whenever -c/--config is
    # omitted, so `default="params.yml"` could never take effect. The option is now optional
    # and falls back to "params.yml".
    parser = argparse.ArgumentParser(description='Main script to reproduce the experiments')
    parser.add_argument('-c', '--config', default="params.yml",
                        help='The path to the .yml file in which are specified all the experiment parameters')

    # parse yml config into its four parameter groups (dataclass instances, see asdict below)
    args = parser.parse_args()
    general_params, data_params, model_params, eval_params = parse_yml_config(args.config)

    # Fail fast, before any phase runs, if wandb logging is requested but its credentials are missing.
    if general_params.log_wandb:
        if 'WANDB_API_KEY' not in os.environ:
            raise ValueError('Cannot log run to wandb if environment variable "WANDB_API_KEY" is not present\n'
                             'Please set the environment variable and add the api key for wandb\n')
        if 'WANDB_ENTITY' not in os.environ:
            raise ValueError('Cannot log run to wandb if environment variable "WANDB_ENTITY" is not present\n'
                             'Please set the environment variable and add the entity for wandb logs\n')

    # this is the config dict that will be logged to wandb
    # apart from the params read from yml file, log env variables needed for reproducibility and
    # also the current active branch in which experiment is being performed (if the project is in a git directory)
    try:
        git_branch = Repository('.').head.shorthand
    except GitError:
        # Opening the repository or resolving HEAD failed (e.g. not a git checkout): record no branch.
        git_branch = None

    config_args = {
        "general_params": dataclasses.asdict(general_params),
        "data_params": dataclasses.asdict(data_params),
        "model_params": dataclasses.asdict(model_params),
        "eval_params": dataclasses.asdict(eval_params),
        # os.environ.get -> None when the variable is unset, so missing values are logged explicitly
        "PYTHONHASHSEED": os.environ.get("PYTHONHASHSEED"),
        "CUBLAS_WORKSPACE_CONFIG": os.environ.get("CUBLAS_WORKSPACE_CONFIG"),
        "git_branch": git_branch
    }

    pretty_print_configuration(config_args)

    # All phases run inside the wandb context. `should_log` decides whether anything is actually logged.
    with init_wandb(project=general_params.wandb_project, name=general_params.exp_name, config=config_args,
                    should_log=general_params.log_wandb):

        # The data and model phases are skipped when only evaluation is requested.
        if not general_params.eval_only:
            print(" DATA ".center(80, "*"))

            # at start of each main phase, we re-initialize the state
            seed_everything(general_params.random_seed)
            data_main(general_params, data_params)

            print()  # simple newline

            print(" MODEL ".center(80, "*"))

            # at start of each main phase, we re-initialize the state
            seed_everything(general_params.random_seed)
            model_main(general_params, data_params, model_params)

            print()  # simple newline

        # The evaluation phase always runs.
        print(" EVAL ".center(80, "*"))

        # at start of each main phase, we re-initialize the state
        seed_everything(general_params.random_seed)
        eval_main(general_params, data_params, model_params, eval_params)