run.py
"""
Main script to run the experiment.
"""
import json
import os
from datetime import datetime
from functools import partial
from typing import Annotated, Any, Callable, Collection, Literal
from pydantic import AfterValidator, BaseModel, ConfigDict, Field
from tap import tapify
import torch
from transformers import AutoModelForCausalLM, BertForMaskedLM, GPT2LMHeadModel
import pretrain_on_test
import cloud
from cloud import do_nothing

try:
    from IPython.display import clear_output
except ModuleNotFoundError:
    clear_output = do_nothing


LMType = Literal[
    "bert",  # section 7 of the paper
    "gpt2",  # sections 7 and 8
    "mistral-qlora-zero-shot",  # section 9
    "mistral-qlora-zero-shot-packing",  # section 9.1
    "mistral-qlora-sft",  # causes OOMs. Maybe it's b/c merge_and_unload dequantizes?
    # For quick CPU tests
    "bert-tiny",
    "gpt2-tiny",
    "mistral-lora-zero-shot-tiny",
    "mistral-lora-zero-shot-packing-tiny",
    "mistral-lora-sft-tiny",
    "mistral-instruct-lora-sft-tiny",
]


lm_type_to_config_creator: dict[LMType, Callable[[Any], pretrain_on_test.Config]] = {
    "bert": lambda **model_independent_kwargs: pretrain_on_test.Config(
        model_id="bert-base-uncased",
        model_class_pretrain=BertForMaskedLM,
        mlm=True,
        mlm_probability=0.15,
        pretrain_method="raw-text",
        lora_pretrain=False,
        classification_method="linear-layer",
        lora_classification=False,
        max_length=256,
        **model_independent_kwargs,
    ),
    "gpt2": lambda **model_independent_kwargs: pretrain_on_test.Config(
        model_id="gpt2",
        model_class_pretrain=GPT2LMHeadModel,
        pretrain_method="raw-text",
        lora_pretrain=False,
        classification_method="linear-layer",
        lora_classification=False,
        max_length=256,
        **model_independent_kwargs,
    ),
    "mistral-qlora-zero-shot": lambda **model_independent_kwargs: pretrain_on_test.Config(
        model_id="mistralai/Mistral-7B-v0.3",
        requires_hf_login=True,
        model_class_pretrain=AutoModelForCausalLM,
        pretrain_method="instructions-with-text",
        lora_pretrain=True,
        qlora=True,
        classification_method="zero-shot",
        max_length=512,
        **model_independent_kwargs,
    ),
    "mistral-qlora-zero-shot-packing": lambda **model_independent_kwargs: pretrain_on_test.Config(
        model_id="mistralai/Mistral-7B-v0.3",
        requires_hf_login=True,
        model_class_pretrain=AutoModelForCausalLM,
        pretrain_method="instructions-with-text",
        lora_pretrain=True,
        qlora=True,
        classification_method="zero-shot",
        max_length=8192,
        pack=True,
        **model_independent_kwargs,
    ),
    "mistral-qlora-sft": lambda **model_independent_kwargs: pretrain_on_test.Config(
        model_id="mistralai/Mistral-7B-v0.3",
        requires_hf_login=True,
        model_class_pretrain=AutoModelForCausalLM,
        pretrain_method="instructions-with-text",
        lora_pretrain=True,
        classification_method="sft",
        lora_classification=True,
        qlora=True,
        max_length=512,
        **model_independent_kwargs,
    ),
    # For quick CPU tests. These are useful for prototyping new LM types
    "bert-tiny": lambda **model_independent_kwargs: pretrain_on_test.Config(
        model_id="hf-internal-testing/tiny-random-BertModel",
        model_class_pretrain=BertForMaskedLM,
        mlm=True,
        mlm_probability=0.15,
        pretrain_method="raw-text",
        lora_pretrain=False,
        classification_method="linear-layer",
        lora_classification=False,
        max_length=256,
        **model_independent_kwargs,
    ),
    "gpt2-tiny": lambda **model_independent_kwargs: pretrain_on_test.Config(
        model_id="hf-internal-testing/tiny-random-gpt2",
        model_class_pretrain=GPT2LMHeadModel,
        pretrain_method="raw-text",
        lora_pretrain=False,
        classification_method="linear-layer",
        lora_classification=False,
        max_length=256,
        **model_independent_kwargs,
    ),
    "mistral-lora-zero-shot-tiny": lambda **model_independent_kwargs: pretrain_on_test.Config(
        model_id="hf-internal-testing/tiny-random-MistralForCausalLM",
        model_class_pretrain=AutoModelForCausalLM,
        pretrain_method="instructions-with-text",
        lora_pretrain=True,
        classification_method="zero-shot",
        lora_classification=True,
        max_length=512,
        **model_independent_kwargs,
    ),
    "mistral-lora-zero-shot-packing-tiny": lambda **model_independent_kwargs: pretrain_on_test.Config(
        model_id="hf-internal-testing/tiny-random-MistralForCausalLM",
        model_class_pretrain=AutoModelForCausalLM,
        pretrain_method="instructions-with-text",
        lora_pretrain=True,
        classification_method="zero-shot",
        lora_classification=True,
        max_length=8192,
        pack=True,
        **model_independent_kwargs,
    ),
    "mistral-lora-sft-tiny": lambda **model_independent_kwargs: pretrain_on_test.Config(
        model_id="hf-internal-testing/tiny-random-MistralForCausalLM",
        model_class_pretrain=AutoModelForCausalLM,
        pretrain_method="instructions-with-text",
        lora_pretrain=True,
        classification_method="sft",
        lora_classification=True,
        max_length=512,
        **model_independent_kwargs,
    ),
    "mistral-instruct-lora-sft-tiny": lambda **model_independent_kwargs: pretrain_on_test.Config(
        model_id="ml6team/tiny-random-mistral-instruct",
        model_class_pretrain=AutoModelForCausalLM,
        pretrain_method="instructions-with-text",
        lora_pretrain=True,
        classification_method="sft",
        lora_classification=True,
        max_length=512,
        **model_independent_kwargs,
    ),
}
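
# Illustration (not executed): `run` below creates a config from the chosen LM type and
# the model-independent `Experiment` fields (those marked with is_for_config=True),
# roughly:
#   config = lm_type_to_config_creator["gpt2"](
#       per_device_train_batch_size_pretrain=16, num_train_epochs_pretrain=2, ...
#   )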


def _check_dataset_names(dataset_names: Collection[str] | None) -> list[str]:
    if dataset_names is None:
        dataset_names = list(
            pretrain_on_test.data.hf_dataset_name_to_classification_dataset_info.keys()
        )

    def remove_owner(dataset_name: str) -> str:
        return dataset_name.split("/")[-1]

    dataset_names_without_owners = [
        remove_owner(dataset_name) for dataset_name in dataset_names
    ]
    if len(set(dataset_names_without_owners)) < len(dataset_names_without_owners):
        raise ValueError(
            "Some datasets have the same name (possibly under different owners). "
            "That's not allowed."
        )
    return sorted(dataset_names, key=remove_owner)


DatasetNames = Annotated[list[str] | None, AfterValidator(_check_dataset_names)]
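
# Illustration (not executed) of the validator's behavior: names get sorted by their
# owner-less name, and owner-less duplicates are rejected.
#   _check_dataset_names(["SetFit/enron_spam", "ag_news"])
#   # -> ["ag_news", "SetFit/enron_spam"]
#   _check_dataset_names(["ag_news", "some-owner/ag_news"])  # raises ValueError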
_field_for_config = partial(Field, json_schema_extra={"is_for_config": True})


class Experiment(BaseModel):
    """
    Experiment configuration.
    """

    model_config = ConfigDict(extra="forbid", frozen=True)
    # Pydantic stuff: extra attributes are not allowed, and the object is immutable

    lm_type: LMType = Field(
        description=(
            "Type of language model. *-tiny models have random weights and should only "
            "be used for testing"
        )
    )
    run_name: str = Field(
        default="",
        description=(
            "Name of the run, in case it helps you remember what changed. If supplied, "
            "this name gets appended to the run ID string: run-{timestamp}-{run_name}"
        ),
    )
    dataset_names: DatasetNames = Field(
        default=None,
        description=(
            "Space-separated list of HuggingFace datasets, e.g., "
            "ag_news dair-ai/emotion SetFit/enron_spam. "
            "By default, all datasets from the paper are used"
        ),
    )
    num_subsamples: int = Field(
        default=50, description="Number of subsamples to draw from the dataset"
    )
    num_train: int = Field(
        default=100,
        description=(
            "Number of observations for classification training, i.e., m in the paper"
        ),
    )
    num_test: int = Field(
        default=200,
        description=(
            "Number of observations for pretraining and eval, i.e., n in the paper"
        ),
    )

    # Model-independent arguments which are passed to the config
    per_device_train_batch_size_pretrain: int = _field_for_config(
        default=16, description="Batch size for pretraining"
    )
    per_device_train_batch_size_classification: int = _field_for_config(
        default=16, description="Batch size for classification training"
    )
    per_device_eval_batch_size_classification: int = _field_for_config(
        default=64, description="Batch size for classification evaluation"
    )
    num_train_epochs_classification: int = _field_for_config(
        default=3, description="Number of epochs for classification training"
    )
    num_train_epochs_pretrain: int = _field_for_config(
        default=2, description="Number of epochs for pretraining"
    )
    keep_models: bool = _field_for_config(
        default=False,
        description=(
            "Whether to keep the saved models (base, extra, test) after completing a "
            "subsample. Currently, only the last subsample's models are saved. So this "
            "only makes sense when num_subsamples=1"
        ),
        # TODO: this trade-off isn't needed for LoRA
    )
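
# Programmatic construction, e.g. (hypothetical values; the CLI path instead builds
# this object via tapify in the __main__ block below):
#   experiment = Experiment(lm_type="bert-tiny", dataset_names=["ag_news"], num_subsamples=1)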


def run(
    experiment: Experiment,
    create_logger: cloud.CreateLogger = cloud.create_logger_local,
    upload_directory: cloud.UploadDirectory = do_nothing,
) -> str:
    """
    Run the experiment.

    Parameters
    ----------
    experiment : Experiment
        configuration for the experiment
    create_logger : cloud.CreateLogger, optional
        Callable which takes as input a single argument for the name of the log
        group/label/tag, and outputs a `logging.Logger` object. By default, a logger
        is created which only logs to stdout.
    upload_directory : cloud.UploadDirectory, optional
        Callable which takes as input `directory` and `logger` arguments and uploads
        all local content in `directory` somewhere else, e.g., S3. By default, nothing
        is uploaded.

    Returns
    -------
    str
        run ID
    """
    # Meta info
    current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_id = (
        f"run-{current_time}{'-' + experiment.run_name if experiment.run_name else ''}"
    )

    # Create logger
    logger = create_logger(run_id)
    logger.info(f"ID of the run: {run_id}")
    logger.info(experiment)

    try:
        if torch.cuda.is_available():
            logger.info("GPU detected.")
        else:
            logger.info("No GPU detected.")
            if "cpu-test" not in experiment.run_name:
                raise ValueError(
                    "No GPU was detected. If this is intentional, please include "
                    "'cpu-test' somewhere in the run_name argument."
                )

        # Create results_dir using core settings from the experiment: m, n, and the LM
        results_dir = os.path.join(
            run_id,
            "accuracies",
            f"m{experiment.num_train}",
            f"n{experiment.num_test}",
            experiment.lm_type,
        )

        # Upload experiment settings
        os.makedirs(run_id)
        with open(os.path.join(run_id, "experiment.json"), "w") as json_file:
            experiment_as_dict = experiment.model_dump()
            json.dump(experiment_as_dict, json_file, indent=4)
        upload_directory(directory=run_id, logger=logger)

        # Create config from experiment
        model_independent_attributes = [
            field_name
            for field_name, field_info in Experiment.model_fields.items()
            if (getattr(field_info, "json_schema_extra") or {}).get(
                "is_for_config", False
            )
        ]
        model_independent_kwargs = {
            attr: getattr(experiment, attr) for attr in model_independent_attributes
        }
        config = lm_type_to_config_creator[experiment.lm_type](
            **model_independent_kwargs
        )

        # Run experiment on each dataset
        _ = torch.manual_seed(123)
        torch.cuda.manual_seed_all(123)
        for dataset_name in experiment.dataset_names:
            classification_dataset_info = (
                pretrain_on_test.data.hf_dataset_name_to_classification_dataset_info[
                    dataset_name
                ]
            )
            df = pretrain_on_test.data.load_classification_data(
                classification_dataset_info
            )
            clear_output(wait=True)
            dataset_dir = pretrain_on_test.experiment.replicate(
                df,
                classification_dataset_info,
                dataset_name,
                results_dir,
                config,
                logger,
                num_subsamples=experiment.num_subsamples,
                num_train=experiment.num_train,
                num_test=experiment.num_test,
            )
            # Sync w/ cloud
            upload_directory(directory=dataset_dir, logger=logger)
    except Exception as exception:
        try:
            msg = f"Encountered an error with dataset {dataset_name}: "
        except UnboundLocalError:
            msg = ""
        logger.error(f"{msg}{exception}", exc_info=True)
        raise
    return run_id
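
# A minimal sketch of calling `run` directly; by default it logs to stdout and uploads
# nothing. A run_name containing "cpu-test" avoids the no-GPU error above:
#   run_id = run(Experiment(lm_type="bert-tiny", run_name="cpu-test"))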


if __name__ == "__main__":
    experiment = tapify(Experiment)
    cloud_provider = os.environ.get("PRETRAIN_ON_TEST_CLOUD_PROVIDER")
    # Env var b/c it's reasonable to run this script many times in one session. So just
    # need to specify the env var once
    create_data_handlers = cloud.cloud_provider_to_create_data_handlers[cloud_provider]
    data_handlers = create_data_handlers()
    run(
        experiment,
        create_logger=data_handlers.create_logger,
        upload_directory=data_handlers.upload_directory,
    )
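
# Example shell usage (hypothetical values; valid provider names are whatever keys
# cloud.cloud_provider_to_create_data_handlers defines):
#   PRETRAIN_ON_TEST_CLOUD_PROVIDER=<provider> python run.py --lm_type gpt2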