# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.datasets.tokenizer.tiktoken import build_tiktoken_tokenizer
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec

from .model import Transformer, TransformerModelArgs
from .parallelize_llama import parallelize_llama
from .pipeline_llama import pipeline_llama

__all__ = [
"parallelize_llama",
"pipeline_llama",
"TransformerModelArgs",
"Transformer",
"llama3_configs",
]
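
# Pre-defined Llama 3 model configurations, keyed by flavor name
# (e.g. selected via the `model.flavor` field of the training config).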
llama3_configs = {
"debugmodel": TransformerModelArgs(
dim=256, n_layers=6, n_heads=16, rope_theta=500000
),
"debugmodel_flex_attn": TransformerModelArgs(
dim=256,
n_layers=6,
n_heads=16,
rope_theta=500000,
use_flex_attn=True,
attn_mask_type="block_causal",
),
"8B": TransformerModelArgs(
dim=4096,
n_layers=32,
n_heads=32,
n_kv_heads=8,
ffn_dim_multiplier=1.3,
multiple_of=1024,
rope_theta=500000,
),
"70B": TransformerModelArgs(
dim=8192,
n_layers=80,
n_heads=64,
n_kv_heads=8,
ffn_dim_multiplier=1.3,
multiple_of=4096,
rope_theta=500000,
),
"405B": TransformerModelArgs(
dim=16384,
n_layers=126,
n_heads=128,
n_kv_heads=8,
ffn_dim_multiplier=1.2,
multiple_of=4096,
rope_theta=500000,
),
}
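
# Register the Llama 3 train spec so the trainer can look it up by name,
# wiring together the model class, its configs, the parallelization and
# pipelining functions, and the builders for the training components.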
register_train_spec(
TrainSpec(
name="llama3",
cls=Transformer,
config=llama3_configs,
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
build_tokenizer_fn=build_tiktoken_tokenizer,
build_loss_fn=build_cross_entropy_loss,
)
)
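
# A minimal usage sketch (assuming a companion `get_train_spec` lookup helper
# in torchtitan.protocols.train_spec; names outside this file are illustrative):
#
#   from torchtitan.protocols.train_spec import get_train_spec
#
#   spec = get_train_spec("llama3")
#   model = spec.cls(spec.config["debugmodel"])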