"""

from dataclasses import dataclass
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, List

import torch
import torch.nn as nn

from litgpt.model import GPT as BaseModel
from litgpt.model import Block as BaseBlock
from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention
+from litgpt.kvcache.base import KVCache, KeysAndValues, DefaultKeysAndValues

@dataclass
@@ -29,20 +30,33 @@ class Config(BaseConfig):

class GPT(BaseModel):
    # Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here.
-    def __init__(self, config: Config) -> None:
+    def __init__(
+        self,
+        config: Config,
+        kv_cache: Optional[List[KVCache]] = None
+    ) -> None:
        nn.Module.__init__(self)
        assert config.padded_vocab_size is not None
        self.config = config

+        if kv_cache is not None:
+            if len(kv_cache) != config.n_layer:
+                raise ValueError(f"kv_cache length {len(kv_cache)} != {config.n_layer} = config.n_layer")
+            for kvc in kv_cache:
+                self._check_kv_cache(config, kvc)
+            self._default_kv_cache = False
+        else:
+            kv_cache = [None] * config.n_layer
+            self._default_kv_cache = True
        self.lm_head = nn.Linear(
            config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
        )
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=nn.ModuleList(
-                    Block(config, block_idx)
-                    for block_idx in range(config.n_layer)
+                    Block(config, block_idx, kv_cache=kvc)
+                    for block_idx, kvc in enumerate(kv_cache)
                ),
                ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
            )
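With this hunk, the adapter GPT either allocates its own per-layer caches (the default, kv_cache=None) or accepts one pre-built KVCache per transformer layer. A minimal usage sketch, assuming the file is litgpt/adapter.py and that Config.from_name resolves a model name known to litgpt; names marked hypothetical are illustrations only:

from litgpt.adapter import GPT, Config  # module path assumed from the imports above

config = Config.from_name("pythia-14m")

# Default path: the model sets up one KV cache per layer internally.
model = GPT(config)

# Explicit path: the list must hold exactly config.n_layer entries, one per Block,
# otherwise __init__ raises the ValueError shown in the hunk above.
# caches = [build_cache(config, i) for i in range(config.n_layer)]  # build_cache is hypothetical
# model = GPT(config, kv_cache=caches)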
@@ -62,17 +76,27 @@ def _init_weights(self, module: nn.Module) -> None:


class Block(BaseBlock):
-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
-        self.attn = CausalSelfAttention(config, block_idx)
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(config, block_idx, kv_cache)
+        self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache)


class CausalSelfAttention(BaseCausalSelfAttention):
    """A modification of `litgpt.model.CausalSelfAttention` that adds the attention
    over the adaption prompt."""

-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(config, block_idx, kv_cache)
        if block_idx >= config.adapter_start_layer:
            # adapter embedding layer
            self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
@@ -82,11 +106,16 @@ def __init__(self, config: Config, block_idx: int) -> None:
            self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

    def scaled_dot_product_attention(
-        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
-    ) -> torch.Tensor:
-        y = super().scaled_dot_product_attention(q, k, v, mask)
+        self,
+        q: torch.Tensor,
+        k_and_v: KeysAndValues,
+        mask: Optional[torch.Tensor] = None,
+        is_causal: bool = True,
+        return_scores: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        y, scores = super().scaled_dot_product_attention(q, k_and_v, mask, is_causal, return_scores)
        if self.block_idx < self.config.adapter_start_layer:
-            return y
+            return y, scores

        aT = self.config.adapter_prompt_length
        if self.adapter_kv_cache is not None:
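The hunk above reworks scaled_dot_product_attention to take a KeysAndValues view instead of raw key/value tensors and to return an (output, scores) pair. A hypothetical call site (not part of the diff), assuming q, k, v follow the (B, n_head, T, head_size) layout used elsewhere in litgpt:

# attn is a CausalSelfAttention instance; k and v are dense cache tensors.
k_and_v = DefaultKeysAndValues(keys=k, values=v)
y, scores = attn.scaled_dot_product_attention(
    q, k_and_v, mask=None, is_causal=True, return_scores=False
)
# With return_scores=False, scores is presumably None (the return type is
# Optional[torch.Tensor]); the adapter code below simply discards it via `ay, _`.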
@@ -110,8 +139,14 @@ def scaled_dot_product_attention(

        T = q.size(2)
        amask = torch.ones(T, aT, dtype=torch.bool, device=q.device)
-        ay = super().scaled_dot_product_attention(q, ak, av, amask)
-        return y + self.gating_factor * ay
+        a_k_and_v = DefaultKeysAndValues(keys=ak, values=av)
+        ay, _ = super().scaled_dot_product_attention(
+            q=q,
+            k_and_v=a_k_and_v,
+            mask=amask,
+            is_causal=False,
+        )
+        return y + self.gating_factor * ay, scores

    def reset_parameters(self) -> None:
        if hasattr(self, "gating_factor"):
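The adapter branch wraps its freshly computed adapter keys and values in DefaultKeysAndValues so the parent attention can be reused with is_causal=False over the adaption prompt. That wrapper lives in litgpt/kvcache/base.py and is not shown in this diff; a minimal sketch of the interface it presumably exposes, inferred only from how it is used above:

import torch

class DefaultKeysAndValues:
    """Sketch only: holds dense key/value tensors and exposes them through
    keys() / values(), which is all the attention code above relies on."""

    def __init__(self, keys: torch.Tensor, values: torch.Tensor) -> None:
        self._keys = keys
        self._values = values

    def keys(self) -> torch.Tensor:
        return self._keys

    def values(self) -> torch.Tensor:
        return self._values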