-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtraining.py
executable file
·456 lines (402 loc) · 17.1 KB
/
training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
#!/usr/bin/env python3
'''The training loop and loss function. Also implements some auxiliary
functions such as automatic logging, etc.'''
from __future__ import annotations
import csv
import logging
import sys
from abc import abstractmethod
from dataclasses import dataclass
from datetime import datetime
from functools import partial
from itertools import count
from pathlib import Path
from typing import (Any, Dict, Iterator, List, Optional, Protocol, Tuple, Type,
TypeVar)
import click
import haiku as hk
import jax
import jax.numpy as jnp
import jmp
import optax
from chex import Array, ArrayTree, PRNGKey
from einops import rearrange
from torch.utils.data import DataLoader
if __name__ == '__main__':
    # When this file is executed directly instead of being imported from the package,
    # append the current working directory to the import path so that the `minivae`
    # import that follows can be resolved. This assumes the script is launched from
    # the directory containing the `minivae` package — TODO confirm.
    sys.path.append('.')
from minivae import common, data, nn
# Generic type variable used by the pytree helpers and the config protocol below.
T = TypeVar('T')
# Module-wide logger, named after the project (``common.NAME``).
logger = logging.getLogger(common.NAME)
class TrainingConfig(nn.ModelConfig, Protocol):
    '''Configuration for training.

    Structural (``typing.Protocol``) description of the settings the training code in
    this module reads, on top of the model settings declared by ``nn.ModelConfig``.
    '''
    # Batch size. ``loss_fn`` checks sample shapes against it; presumably the global
    # batch produced by the dataloader, which ``train`` splits across devices — confirm.
    batch_size: int
    # Enables f16 compute policies (``get_policy``) and a dynamic loss scale
    # (``get_loss_scale``).
    use_half_precision: bool
    # Period of the dynamic loss scale; must be set when ``use_half_precision`` is True.
    loss_scale_period: Optional[int]
    # log2 of the initial dynamic loss scale; must be set when ``use_half_precision``
    # is True.
    initial_loss_scale_log2: Optional[int]
    # Learning rate reached at the end of the linear warmup.
    peak_learning_rate: float
    # Final learning rate of the cosine decay; only used when ``total_steps`` is set.
    end_learning_rate: float
    # Number of linear warmup steps, starting from a learning rate of 0.
    warmup_steps: int
    # When set, enables cosine decay after warmup; ``None`` keeps the peak learning
    # rate after warmup (see ``get_learning_rate_schedule``).
    total_steps: Optional[int]
    # Weight decay passed to ``optax.adamw``.
    weight_decay: float
    # Weight of the reconstruction loss; the KL loss is weighted by ``1 - alpha``.
    alpha: float

    @classmethod
    @abstractmethod
    def from_yaml(cls: Type[T], path: Path) -> T:
        '''Presumably constructs a configuration from the YAML file at ``path``;
        implemented by concrete configuration classes.'''
        raise NotImplementedError

    @abstractmethod
    def to_yaml(self: T, path: Path) -> T:
        '''Presumably writes this configuration as YAML to ``path``.

        NOTE(review): the ``-> T`` return annotation is unusual for a writer; confirm
        what implementations actually return.
        '''
        raise NotImplementedError
@dataclass
class TelemetryData:
    '''Data to be logged during training.

    ``train`` yields one instance per optimisation step, after that step's update has
    been applied. Array-valued training state is taken from the first device.
    '''
    # Index of the step this telemetry describes (``train`` increments it after yielding).
    step: int
    # Zero-based epoch counter.
    epoch: int
    # Model parameters after this step's update (first device's slice).
    params: ArrayTree
    # Optimizer state after this step's update (first device's slice).
    opt_state: optax.OptState
    # Loss scale after its adjustment for this step (first device's slice).
    loss_scale: jmp.LossScale
    # The configuration ``train`` was called with.
    config: TrainingConfig
    # The PRNG sequence ``train`` draws one key from per step (the same shared object).
    rngs: hk.PRNGSequence
    # Device-averaged gradients after ``loss_scale.unscale`` (first device's slice).
    gradients: ArrayTree
    # Whether all gradients were finite; when False this step's update was discarded
    # and params/opt_state were left unchanged.
    gradients_finite: bool
    # Total weighted loss, averaged over devices.
    loss: Array
    # Reconstruction loss, averaged over devices.
    rec_loss: Array
    # KL loss, averaged over devices.
    kl_loss: Array
@common.consistent_axes
def train(config: TrainingConfig,
          params: ArrayTree,
          opt_state: optax.OptState,
          dataloader: DataLoader,
          rngs: hk.PRNGSequence,
          loss_scale: Optional[jmp.LossScale] = None,
          step: int = 0,
          ) -> Iterator[TelemetryData]:
    '''Train the model, yielding telemetry data at each step.

    The training state is replicated across all devices. Every batch from
    ``dataloader`` is split evenly between them, and ``train_step`` runs under
    ``jax.pmap``.

    Args:
        config: Training and model configuration.
        params: Initial model parameters (unreplicated).
        opt_state: Initial optimizer state (unreplicated).
        dataloader: Iterable of sample batches of shape ``(B, H, W, C)``. ``B`` must be
            divisible by the number of devices.
        rngs: PRNG sequence. One key is drawn per step and split across devices.
        loss_scale: Loss scale to resume from. When ``None``, a fresh one is built
            from ``config``.
        step: Index assigned to the first step performed by this call.

    Yields:
        One ``TelemetryData`` per step, after the update has been applied. The
        generator loops over epochs indefinitely; the consumer decides when to stop.
    '''
    # Preparations
    policy = get_policy(config)
    loss_scale = get_loss_scale(config, step) if loss_scale is None else loss_scale
    # No buffers are donated. ``train_step`` takes five positional arguments
    # (indices 0-4), so the former ``donate_argnums=5`` never matched anything.
    # Donating params/opt_state is also unsafe here: ``get_from_first_device`` (x[0]
    # on a pmapped array) may alias device buffers that have already been yielded.
    train_step_pmap = jax.pmap(partial(train_step, config=config, axis_name='device'),
                               axis_name='device')
    device_count = jax.device_count()
    logger.info(f'Devices found: {device_count}.')
    # Replicate the training state along a new leading device axis
    params = broadcast_to_devices(params)
    opt_state = broadcast_to_devices(opt_state)
    loss_scale = broadcast_to_devices(loss_scale)
    # Training loop
    for epoch in count():
        for samples in dataloader:
            common.assert_shape(samples, 'B H W C')
            samples = policy.cast_to_compute(samples)
            # Split samples and RNG between devices
            samples = rearrange(samples, '(d b) ... -> d b ...', d=device_count)
            rng = jax.random.split(next(rngs), num=device_count)
            params, opt_state, loss_scale, telemetry_dict = train_step_pmap(
                samples, params, opt_state, loss_scale, rng)
            yield TelemetryData(
                step=step,
                epoch=epoch,
                params=get_from_first_device(params),
                opt_state=get_from_first_device(opt_state),
                loss_scale=get_from_first_device(loss_scale),
                config=config,
                rngs=rngs,
                gradients=get_from_first_device(telemetry_dict['gradients']),
                loss=jnp.mean(telemetry_dict['loss']),
                rec_loss=jnp.mean(telemetry_dict['rec_loss']),
                kl_loss=jnp.mean(telemetry_dict['kl_loss']),
                gradients_finite=telemetry_dict['gradients_finite'].all())
            step += 1
        logger.info(f'Epoch {epoch + 1:,} finished')
def train_step(samples: Array,
               params: ArrayTree,
               opt_state: optax.OptState,
               loss_scale: jmp.LossScale,
               rng: PRNGKey,
               *,
               config: TrainingConfig,
               axis_name: str,
               ) -> Tuple[ArrayTree,
                          optax.OptState,
                          jmp.LossScale,
                          Dict[str, Any]]:
    '''Perform one optimisation step on this device's shard of the batch.

    Designed to run under ``jax.pmap``, with ``axis_name`` naming the device axis.

    The loss is multiplied by ``loss_scale`` before differentiation. The
    device-averaged gradients are divided by it again afterwards, so small
    half-precision gradients do not underflow. The loss scale is then adjusted, and
    the parameters/optimizer state are only replaced if every gradient is finite.

    Returns:
        ``(params, opt_state, loss_scale, telemetry)``. ``telemetry`` contains the
        entries produced by ``loss_fn`` plus ``gradients`` and ``gradients_finite``.
    '''
    # Preparations
    common.assert_shape(samples, 'B H W C')
    loss_hk = hk.transform(partial(loss_fn, config=config))
    optimizer = get_optimizer(config)

    def scaled_loss(p: ArrayTree,
                    key: PRNGKey,
                    x: Array,
                    ) -> Tuple[Array, Dict[str, Any]]:
        # Scale only the differentiated output. The auxiliary telemetry keeps the
        # true loss values. Without this, ``unscale`` below would shrink every
        # gradient by the loss scale whenever half precision is enabled.
        loss, aux = loss_hk.apply(p, key, x)
        return loss_scale.scale(loss), aux

    # Execution
    gradients, telemetry_dict = jax.grad(scaled_loss, has_aux=True)(params, rng, samples)
    # Average across devices, then undo the loss scaling
    gradients = jax.lax.pmean(gradients, axis_name=axis_name)
    gradients = loss_scale.unscale(gradients)
    gradients_finite = jmp.all_finite(gradients)
    new_loss_scale = loss_scale.adjust(gradients_finite)
    updates, new_opt_state = optimizer.update(gradients, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    # Only actually update the params and opt_state if all gradients were finite
    opt_state, params = jmp.select_tree(
        gradients_finite,
        (new_opt_state, new_params),
        (opt_state, params))
    return (params,
            opt_state,
            new_loss_scale,
            dict(telemetry_dict,
                 gradients=gradients,
                 gradients_finite=gradients_finite))
def loss_fn(samples: Array,
            *,
            config: TrainingConfig,
            ) -> Tuple[Array, Dict[str, Any]]:
    '''Compute the weighted VAE loss for one device's shard of a batch.

    Meant to be wrapped with ``hk.transform`` (see ``train_step``).

    Args:
        samples: Images of shape ``(B, H, W, C)``, where ``(H, W, C) == config.shape``.
        config: Configuration. ``alpha`` weights the reconstruction loss and
            ``1 - alpha`` weights the KL loss.

    Returns:
        ``(loss, telemetry)``, where telemetry holds ``rec_loss``, ``kl_loss`` and
        ``loss``.
    '''
    model = nn.VAE.from_config(config)
    # ``train`` splits each batch evenly across devices before the pmapped step, so
    # this function only sees ``batch_size / device_count`` samples. Pinning B to the
    # full ``config.batch_size`` fails on multi-device hosts. Assumes the dataloader
    # yields ``config.batch_size`` samples per batch — confirm. NOTE(review): if
    # ``common.consistent_axes`` ties 'B' across calls, multi-device runs may also
    # need distinct axis names.
    per_device_batch = config.batch_size // jax.device_count()
    common.assert_shape(samples, 'B H W C',
                        B=per_device_batch,
                        H=config.shape[0],
                        W=config.shape[1],
                        C=config.shape[2])
    output: nn.VAEOutput = model(samples, is_training=True)
    alpha = config.alpha
    loss = alpha * output.reconstruction_loss + (1 - alpha) * output.kl_loss
    return loss, dict(rec_loss=output.reconstruction_loss,
                      kl_loss=output.kl_loss,
                      loss=loss)
def get_optimizer(config: TrainingConfig) -> optax.GradientTransformation:
    '''Build the AdamW optimizer used for training.

    The learning rate follows ``get_learning_rate_schedule(config)``: a linear warmup,
    then cosine decay when ``config.total_steps`` is set, otherwise constant. Weight
    decay is ``config.weight_decay``.
    '''
    schedule = get_learning_rate_schedule(config)
    return optax.adamw(learning_rate=schedule, weight_decay=config.weight_decay)
def get_policy(config: TrainingConfig) -> jmp.Policy:
    '''Register Haiku mixed-precision policies and return the top-level VAE policy.

    The VAE itself always runs in f32. With ``config.use_half_precision``, the encoder
    and decoder compute in f16 but return f32. LayerNorm computes in f32 and may emit
    f16. Without half precision, every module uses f32 throughout.
    '''
    full_precision = 'params=f32,compute=f32,output=f32'
    half = config.use_half_precision

    vae_policy = jmp.get_policy(full_precision)
    hk.mixed_precision.set_policy(nn.VAE, vae_policy)

    coder_policy = jmp.get_policy('params=f32,compute=f16,output=f32' if half
                                  else full_precision)
    for module in (nn.VAEEncoder, nn.VAEDecoder):
        hk.mixed_precision.set_policy(module, coder_policy)

    # Normalisation statistics are kept in f32 for numerical stability
    norm_policy = jmp.get_policy('params=f32,compute=f32,output=f16' if half
                                 else full_precision)
    hk.mixed_precision.set_policy(hk.LayerNorm, norm_policy)
    return vae_policy
def get_loss_scale(config: TrainingConfig,
                   step: int,
                   ) -> jmp.LossScale:
    '''Create the loss scale for a (possibly resumed) training run.

    Full-precision runs get a no-op scale. Half-precision runs get a dynamic scale
    starting at ``2 ** config.initial_loss_scale_log2``. Its counter is positioned at
    ``step % config.loss_scale_period`` so the adjustment period continues from
    ``step``.
    '''
    if not config.use_half_precision:
        return jmp.NoOpLossScale()
    assert config.initial_loss_scale_log2 is not None, \
        'initial_loss_scale_log2 must be set for mixed precision training.'
    assert config.loss_scale_period is not None, \
        'loss_scale_period must be set for mixed precision training.'
    period = config.loss_scale_period
    initial_scale = 2. ** jnp.asarray(config.initial_loss_scale_log2)
    return jmp.DynamicLossScale(initial_scale,
                                counter=jnp.asarray(step % period),
                                period=period)
def get_learning_rate_schedule(config: TrainingConfig) -> optax.Schedule:
    '''Get the learning rate schedule with linear warmup and optional cosine decay.

    The rate rises linearly from 0 to ``config.peak_learning_rate`` over
    ``config.warmup_steps`` steps. If ``config.total_steps`` is set, a cosine decay
    follows and reaches ``config.end_learning_rate`` at step ``total_steps``.
    Otherwise the rate stays at the peak value after warmup.
    '''
    if config.total_steps is not None:
        # optax's ``decay_steps`` is the total schedule length *including* warmup; it
        # subtracts ``warmup_steps`` internally. Passing ``total_steps`` directly makes
        # the decay end exactly at ``total_steps``.
        return optax.warmup_cosine_decay_schedule(
            init_value=0.,
            peak_value=config.peak_learning_rate,
            warmup_steps=config.warmup_steps,
            decay_steps=config.total_steps,
            end_value=config.end_learning_rate,
        )
    warmup = optax.linear_schedule(
        init_value=0.,
        end_value=config.peak_learning_rate,
        transition_steps=config.warmup_steps)
    constant = optax.constant_schedule(config.peak_learning_rate)
    return optax.join_schedules([warmup, constant], [config.warmup_steps])
def get_optimizer_state(config: TrainingConfig,
                        params: ArrayTree,
                        ) -> optax.OptState:
    '''Initialise the optimizer state for ``params`` and log its element count and size.'''
    state = get_optimizer(config).init(params)
    n_elements = hk.data_structures.tree_size(state)
    size_mb = round(hk.data_structures.tree_bytes(state) / 1e6, 2)
    logger.info(f'Optimizer state: {n_elements:,} ({size_mb:.2f} MB)')
    return state
def broadcast_to_devices(obj: T) -> T:
    '''Replicate every array leaf of ``obj`` along a new leading device axis.

    Each array of shape ``S`` becomes ``(jax.device_count(), *S)``. Non-array leaves
    are returned unchanged.
    '''
    n_devices = jax.device_count()

    def _replicate(leaf):
        if not isinstance(leaf, Array):
            return leaf
        return jnp.broadcast_to(leaf, (n_devices, *leaf.shape))

    return jax.tree_util.tree_map(_replicate, obj)
def get_from_first_device(obj: T) -> T:
    '''Take the first entry along the leading (device) axis of every array leaf.

    Non-array leaves are returned unchanged.
    '''
    def _first(leaf):
        if isinstance(leaf, Array):
            return leaf[0]
        return leaf

    return jax.tree_util.tree_map(_first, obj)
def concat_from_devices(obj: T) -> T:
    '''Merge the leading device axis of every array leaf into the following axis.

    Arrays of shape ``(d, b, ...)`` become ``(d * b, ...)``. Non-array leaves are
    returned unchanged.
    '''
    def _merge(leaf):
        if not isinstance(leaf, Array):
            return leaf
        return rearrange(leaf, 'd b ... -> (d b) ...')

    return jax.tree_util.tree_map(_merge, obj)
def autosave(telemetry_iter: Iterator[TelemetryData],
             frequency: int,
             path: Path,
             ) -> Iterator[TelemetryData]:
    '''Pass telemetry through unchanged, checkpointing every ``frequency`` steps.

    Whenever ``telemetry.step`` is a multiple of ``frequency``, the config, params,
    optimizer state, RNG sequence, loss scale and step are handed to
    ``common.save_checkpoint`` with ``path``. This happens before the telemetry is
    yielded on.

    Raises:
        ValueError: if a telemetry item's config is not a ``common.YamlConfig``.
    '''
    for item in telemetry_iter:
        cfg = item.config
        if not isinstance(cfg, common.YamlConfig):
            raise ValueError('The config must be a YamlConfig to be saved.')
        due = item.step % frequency == 0
        if due:
            common.save_checkpoint(path,
                                   config=cfg,
                                   params=item.params,
                                   opt_state=item.opt_state,
                                   rngs=item.rngs,
                                   loss_scale=item.loss_scale,
                                   step=item.step)
        yield item
def autolog(telemetry_iter: Iterator[TelemetryData],
            frequency: int,
            ) -> Iterator[TelemetryData]:
    '''Pass telemetry through unchanged, periodically logging averaged losses.

    Losses from every step are buffered. When ``telemetry.step`` is a multiple of
    ``frequency``, the means of everything buffered since the previous log line
    (including the current step) are logged, and the buffers are reset.
    '''
    histories = {'loss': [], 'rec_loss': [], 'kl_loss': []}
    for telemetry in telemetry_iter:
        for name, values in histories.items():
            values.append(getattr(telemetry, name))
        if telemetry.step % frequency != 0 or not histories['loss']:
            yield telemetry
            continue
        means = {name: jnp.mean(jnp.asarray(values))
                 for name, values in histories.items()}
        logger.info(f'Step: {telemetry.step:,}'
                    f' | loss: {means["loss"]:.4f}'
                    f' | rec: {means["rec_loss"]:.4f}'
                    f' | kl: {means["kl_loss"]:.4f}')
        for values in histories.values():
            values.clear()
        yield telemetry
def log_to_csv(telemetry_iter: Iterator[TelemetryData],
               path: Path,
               ) -> Iterator[TelemetryData]:
    '''Log the telemetry data to a CSV file.

    Appends one row per step to ``path``, creating parent directories as needed. Each
    row holds wall-clock time, step, epoch, the three losses, and the scheduled
    learning rate for that step. A header row is written when the file is new or
    empty. Every row is flushed immediately, so the file can be followed while
    training runs and rows already written survive the process being killed.
    Telemetry is passed through unchanged.
    '''
    path.parent.mkdir(parents=True, exist_ok=True)
    # An existing but empty file (e.g. created before a killed run's buffered header
    # reached disk) also needs a header; 'a' mode creates missing files itself.
    needs_header = not path.exists() or path.stat().st_size == 0
    lr_sched = None
    # newline='' as required by the csv module, to avoid doubled line endings on Windows
    with path.open('a', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=[
            'time', 'step', 'epoch', 'loss', 'rec_loss', 'kl_loss', 'learning_rate'])
        if needs_header:
            writer.writeheader()
            f.flush()
        for telemetry in telemetry_iter:
            if lr_sched is None:
                # Built lazily: the config is only known once telemetry arrives
                lr_sched = get_learning_rate_schedule(telemetry.config)
            writer.writerow(dict(time=datetime.now().isoformat(),
                                 step=telemetry.step,
                                 epoch=telemetry.epoch,
                                 loss=telemetry.loss,
                                 rec_loss=telemetry.rec_loss,
                                 kl_loss=telemetry.kl_loss,
                                 learning_rate=lr_sched(telemetry.step)))
            f.flush()
            yield telemetry
class Config(common.YamlConfig):
    '''Concrete configuration used by the ``train`` CLI command.

    Loaded with ``Config.from_yaml`` or restored from a checkpoint. Combines the
    training settings read in this module with model, data and dataloader settings
    that are consumed by the ``nn`` and ``data`` modules (not visible here).
    '''
    # Training config
    batch_size: int  # presumably the global batch size split across devices — confirm
    use_half_precision: bool  # f16 compute policies + dynamic loss scaling
    loss_scale_period: Optional[int]  # required when use_half_precision is True
    initial_loss_scale_log2: Optional[int]  # required when use_half_precision is True
    peak_learning_rate: float  # learning rate at the end of warmup
    end_learning_rate: float  # final rate of the cosine decay (needs total_steps)
    warmup_steps: int  # linear warmup length, starting from 0
    total_steps: Optional[int]  # None: constant learning rate after warmup
    weight_decay: float  # AdamW weight decay
    alpha: float  # reconstruction-loss weight; KL loss weighted by 1 - alpha
    # Model config (consumed by nn.VAE; semantics not visible in this file)
    encoder_sizes: List[int]
    encoder_strides: List[int]
    decoder_sizes: List[int]
    decoder_strides: List[int]
    latent_size: int
    dropout: float
    # Data config
    dataset_path: Path
    shape: Tuple[int, int, int]  # (H, W, C) of each sample, checked in loss_fn
    # DataLoader config
    num_workers: int  # presumably DataLoader worker processes — confirm in data module
def get_cli() -> click.Group:
    '''Get the command line interface for this module.

    Returns the ``training`` click group with a ``train`` command.
    '''
    cli = common.get_cli_group('training')

    @cli.command('train')
    @click.option('--config-path', '-c', type=Path, default=None,
                  help='Path to the configuration file')
    @click.option('--load-from', '-l', type=Path, default=None,
                  help='Path to a checkpoint to resume training')
    @click.option('--save-path', '-o', type=Path, default=None,
                  help='Path to save checkpoints automatically')
    @click.option('--save-frequency', '-f', type=int, default=1000,
                  help='Frequency at which to save checkpoints automatically')
    @click.option('--log-frequency', type=int, default=10,
                  help='Frequency at which to log metrics automatically')
    @click.option('--csv-path', type=Path, default=None,
                  help='Path to save metrics in a CSV file')
    @click.option('--stop-at', type=int, default=None,
                  help='Stop training after this many steps')
    @click.option('--seed', type=int, default=None, help='Random seed')
    def cli_train(config_path: Optional[Path],
                  load_from: Optional[Path],
                  save_path: Optional[Path],
                  save_frequency: int,
                  log_frequency: int,
                  csv_path: Optional[Path],
                  stop_at: Optional[int],
                  seed: Optional[int],
                  ) -> None:
        '''Train a VAE.'''
        # Exactly one source of initial state: a fresh config or a checkpoint
        if config_path is None and load_from is None:
            raise ValueError('Either a configuration file or a checkpoint must be provided')
        if config_path is not None and load_from is not None:
            raise ValueError('Only one of configuration file or checkpoint must be provided')
        if config_path is not None:
            config = Config.from_yaml(config_path)
            rngs = common.get_rngs(seed)
            params = nn.VAE.get_params(config, next(rngs))
            opt_state = get_optimizer_state(config, params)
            step = 0
            loss_scale = None
        else:
            assert load_from is not None
            checkpoint = common.load_checkpoint(load_from, config_class=Config)
            config = checkpoint['config']
            rngs = checkpoint['rngs']
            params = checkpoint['params']
            opt_state = checkpoint['opt_state']
            # ``autosave`` records the step whose update has already been applied
            # (``train`` yields a step's telemetry after updating). Resume at the next
            # step, otherwise the index repeats, the logged learning rate lags the
            # optimizer's own count, and ``autosave`` re-saves under the same step.
            step = checkpoint['step'] + 1
            loss_scale = checkpoint['loss_scale']
        dataloader = data.LMDBDataset.from_config(config).get_dataloader_from_config(config)
        telemetry_iter = train(config=config,
                               params=params,
                               opt_state=opt_state,
                               dataloader=dataloader,
                               rngs=rngs,
                               loss_scale=loss_scale,
                               step=step)
        # Stack the pass-through consumers: train -> autosave -> CSV -> console log
        if save_path is not None:
            telemetry_iter = autosave(telemetry_iter, save_frequency, save_path)
        if csv_path is not None:
            telemetry_iter = log_to_csv(telemetry_iter, csv_path)
        telemetry_iter = autolog(telemetry_iter, log_frequency)
        # A negative countdown never reaches zero, so without --stop-at training runs
        # until interrupted.
        i = stop_at if stop_at is not None else -1
        while i != 0:
            next(telemetry_iter)
            i -= 1

    return cli
if __name__ == '__main__':
    # Build the click group and hand control to it (parses sys.argv)
    cli = get_cli()
    cli()