
Commit ff817a9
Support for KV caching and batched inference
1 parent f6031e3 · commit ff817a9
30 files changed: +2381 -788 lines changed

Diff for: litgpt/adapter.py (+50 -15)

@@ -9,7 +9,7 @@
 """
 
 from dataclasses import dataclass
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, List
 
 import torch
 import torch.nn as nn
@@ -19,6 +19,7 @@
 from litgpt.model import GPT as BaseModel
 from litgpt.model import Block as BaseBlock
 from litgpt.model import CausalSelfAttention as BaseCausalSelfAttention
+from litgpt.kvcache.base import KVCache, KeysAndValues, DefaultKeysAndValues
 
 
 @dataclass
@@ -29,20 +30,33 @@ class Config(BaseConfig):
 
 class GPT(BaseModel):
     # Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here.
-    def __init__(self, config: Config) -> None:
+    def __init__(
+        self,
+        config: Config,
+        kv_cache: Optional[List[KVCache]] = None
+    ) -> None:
         nn.Module.__init__(self)
         assert config.padded_vocab_size is not None
         self.config = config
 
+        if kv_cache is not None:
+            if len(kv_cache) != config.n_layer:
+                raise ValueError(f"kv_cache length {len(kv_cache)} != {config.n_layer} = config.n_layer")
+            for kvc in kv_cache:
+                self._check_kv_cache(config, kvc)
+            self._default_kv_cache = False
+        else:
+            kv_cache = [None] * config.n_layer
+            self._default_kv_cache = True
         self.lm_head = nn.Linear(
             config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
         )
         self.transformer = nn.ModuleDict(
             dict(
                 wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                 h=nn.ModuleList(
-                    Block(config, block_idx)
-                    for block_idx in range(config.n_layer)
+                    Block(config, block_idx, kv_cache=kvc)
+                    for block_idx, kvc in enumerate(kv_cache)
                 ),
                 ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
             )
@@ -62,17 +76,27 @@ def _init_weights(self, module: nn.Module) -> None:
 
 
 class Block(BaseBlock):
-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
-        self.attn = CausalSelfAttention(config, block_idx)
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(config, block_idx, kv_cache)
+        self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache)
 
 
 class CausalSelfAttention(BaseCausalSelfAttention):
     """A modification of `litgpt.model.CausalSelfAttention` that adds the attention
     over the adaption prompt."""
 
-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(config, block_idx, kv_cache)
         if block_idx >= config.adapter_start_layer:
             # adapter embedding layer
             self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd)
@@ -82,11 +106,16 @@ def __init__(self, config: Config, block_idx: int) -> None:
         self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
 
     def scaled_dot_product_attention(
-        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None
-    ) -> torch.Tensor:
-        y = super().scaled_dot_product_attention(q, k, v, mask)
+        self,
+        q: torch.Tensor,
+        k_and_v: KeysAndValues,
+        mask: Optional[torch.Tensor] = None,
+        is_causal: bool = True,
+        return_scores: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        y, scores = super().scaled_dot_product_attention(q, k_and_v, mask, is_causal, return_scores)
         if self.block_idx < self.config.adapter_start_layer:
-            return y
+            return y, scores
 
         aT = self.config.adapter_prompt_length
         if self.adapter_kv_cache is not None:
@@ -110,8 +139,14 @@ def scaled_dot_product_attention(
 
         T = q.size(2)
         amask = torch.ones(T, aT, dtype=torch.bool, device=q.device)
-        ay = super().scaled_dot_product_attention(q, ak, av, amask)
-        return y + self.gating_factor * ay
+        a_k_and_v = DefaultKeysAndValues(keys=ak, values=av)
+        ay, _ = super().scaled_dot_product_attention(
+            q=q,
+            k_and_v=a_k_and_v,
+            mask=amask,
+            is_causal=False,
+        )
+        return y + self.gating_factor * ay, scores
 
     def reset_parameters(self) -> None:
         if hasattr(self, "gating_factor"):
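
Note (not part of the commit): a minimal sketch of how the new kv_cache argument of the adapter GPT constructor could be used, based only on the signature and checks shown in the diff above. The concrete KVCache subclass and its constructor arguments are assumptions and not shown in this diff.

# Sketch only: illustrates the new constructor contract of litgpt.adapter.GPT.
# The concrete KVCache subclass and its arguments are assumptions; only the
# `kv_cache: Optional[List[KVCache]]` parameter and its length check come from
# the diff above.
from litgpt.adapter import Config, GPT

config = Config.from_name("pythia-160m")  # any supported model name

# Default path: no caches are passed, each Block receives kv_cache=None and
# the model sets _default_kv_cache = True.
model = GPT(config)

# Explicit path: exactly config.n_layer caches must be provided, otherwise
# __init__ raises ValueError; each cache is validated via _check_kv_cache.
# caches = [SomeKVCache(...) for _ in range(config.n_layer)]  # assumed subclass
# model = GPT(config, kv_cache=caches)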

Diff for: litgpt/adapter_v2.py (+33 -9)

@@ -9,7 +9,7 @@
 """
 
 from dataclasses import dataclass
-from typing import Any, Dict, Type, Optional
+from typing import Any, Dict, Type, Optional, List
 
 import torch
 import torch.nn as nn
@@ -22,6 +22,7 @@
 from litgpt.adapter import Config as BaseConfig
 from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble
 from litgpt.utils import map_old_state_dict_weights
+from litgpt.kvcache.base import KVCache
 
 
 @dataclass
@@ -64,20 +65,33 @@ def reset_parameters(self) -> None:
 
 class GPT(BaseModel):
     # Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here.
-    def __init__(self, config: Config) -> None:
+    def __init__(
+        self,
+        config: Config,
+        kv_cache: Optional[List[KVCache]] = None
+    ) -> None:
         nn.Module.__init__(self)
         assert config.padded_vocab_size is not None
        self.config = config
 
+        if kv_cache is not None:
+            if len(kv_cache) != config.n_layer:
+                raise ValueError(f"kv_cache length {len(kv_cache)} != {config.n_layer} = config.n_layer")
+            for kvc in kv_cache:
+                self._check_kv_cache(config, kvc)
+            self._default_kv_cache = False
+        else:
+            kv_cache = [None] * config.n_layer
+            self._default_kv_cache = True
         self.lm_head = AdapterV2Linear(
             config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
         )
         self.transformer = nn.ModuleDict(
             dict(
                 wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                 h=nn.ModuleList(
-                    Block(config, block_idx)
-                    for block_idx in range(config.n_layer)
+                    Block(config, block_idx, kv_cache=kvc)
+                    for block_idx, kvc in enumerate(kv_cache)
                 ),
                 ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
             )
@@ -103,18 +117,28 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa
 
 
 class Block(BaseBlock):
-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
-        self.attn = CausalSelfAttention(config, block_idx)
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(config, block_idx, kv_cache)
+        self.attn = CausalSelfAttention(config, block_idx, kv_cache=kv_cache)
         self.mlp = config.mlp_class(config)
 
 
 class CausalSelfAttention(BaseCausalSelfAttention):
     """A modification of `litgpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class"""
 
     # Copy&paste from :class:`model.CausalSelfAttention`
-    def __init__(self, config: Config, block_idx: int) -> None:
-        super().__init__(config, block_idx)
+    def __init__(
+        self,
+        config: Config,
+        block_idx: int,
+        kv_cache: Optional[KVCache] = None,
+    ) -> None:
+        super().__init__(config, block_idx, kv_cache)
         # key, query, value projections for all heads, but in a batch
         shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
         self.qkv = AdapterV2Linear(

Diff for: litgpt/api.py (+9 -4)

@@ -448,6 +448,7 @@ def generate(
         self,
         prompt: str,
         max_new_tokens: int = 50,
+        prompt_chunksize: int = 1,
         temperature: float = 1.0,
         top_k: Optional[int] = None,
         top_p: float = 1.0,
@@ -461,6 +462,11 @@
             model: The model to use.
             prompt: The prompt string to use for generating the samples.
             max_new_tokens: The maximum number of new tokens to return.
+            prompt_chunksize: If even the shortest prompt is longer than the KV
+                cache, prompts are processed in chunks of this size in the
+                prefill phase. Once the shortest has been processed to the
+                end, we proceed with chunk size 1.
+                Defaults to 1, but larger values are recommended for long prompts.
             temperature: Scales the predicted logits by 1 / temperature.
             top_k: If specified, only sample among the tokens with the k highest probabilities.
             top_p: If specified, it represents the cumulative probability threshold to consider in the sampling process.
@@ -500,15 +506,12 @@
             self.kv_cache_initialized = True
 
         # Dynamically grow the kv cache size if necessary
+        self.model.clear_kv_cache()
         if not self.fixed_kv_cache_size and self.prev_generated_seq_length < max_returned_tokens:
             tmp_device = self.model.mask_cache.device
             self.model.clear_kv_cache()
             self.model.set_kv_cache(batch_size=1, max_seq_length=max_returned_tokens, device=tmp_device)
 
-        else:
-            for block in self.model.transformer.h:
-                block.attn.kv_cache.reset_parameters()
-
         self.prev_generated_seq_length = max_returned_tokens
         self.model.eval()
 
@@ -517,6 +520,7 @@ def iterator():
                 model=self.model,
                 prompt=input_ids,
                 max_returned_tokens=max_returned_tokens,
+                prompt_chunksize=prompt_chunksize,
                 temperature=temperature,
                 top_k=top_k,
                 top_p=top_p,
@@ -536,6 +540,7 @@
                model=self.model,
                prompt=input_ids,
                max_returned_tokens=max_returned_tokens,
+               prompt_chunksize=prompt_chunksize,
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
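
Note (not part of the commit): a minimal usage sketch of the new prompt_chunksize argument through the existing high-level LLM API; the checkpoint name is illustrative, not taken from this commit.

# Sketch, assuming the existing litgpt.LLM high-level API; only the
# prompt_chunksize argument is new in this commit.
from litgpt import LLM

llm = LLM.load("microsoft/phi-2")  # illustrative checkpoint

# Prefill a long prompt in chunks of 16 tokens rather than one token at a time;
# the docstring above recommends larger values for long prompts.
text = llm.generate(
    "Summarize the idea of KV caching in two sentences.",
    max_new_tokens=100,
    prompt_chunksize=16,
)
print(text)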
