
Commit 3847686

Simplify KV cache assignments
1 parent d2e9e45 commit 3847686

File tree: 6 files changed, +105 -140 lines changed

Diff for: litgpt/adapter.py (+5 -16)

@@ -9,7 +9,7 @@
 """

 from dataclasses import dataclass
-from typing import Any, Dict, Optional, Tuple, List
+from typing import Any, Dict, Optional, Tuple

 import torch
 import torch.nn as nn
@@ -30,39 +30,28 @@ class Config(BaseConfig):

 class GPT(BaseModel):
     # Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here.
-    def __init__(
-        self,
-        config: Config,
-        kv_cache: Optional[List[KVCache]] = None
+    def __init__(self, config: Config,
     ) -> None:
         nn.Module.__init__(self)
         assert config.padded_vocab_size is not None
         self.config = config

-        if kv_cache is not None:
-            if len(kv_cache) != config.n_layer:
-                raise ValueError(f"kv_cache length {len(kv_cache)} != {config.n_layer} = config.n_layer")
-            for kvc in kv_cache:
-                self._check_kv_cache(config, kvc)
-            self._default_kv_cache = False
-        else:
-            kv_cache = [None] * config.n_layer
-            self._default_kv_cache = True
         self.lm_head = nn.Linear(
             config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
         )
         self.transformer = nn.ModuleDict(
             dict(
                 wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                 h=nn.ModuleList(
-                    Block(config, block_idx, kv_cache=kvc)
-                    for block_idx, kvc in enumerate(kv_cache)
+                    Block(config, block_idx)
+                    for block_idx in range(config.n_layer)
                 ),
                 ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
             )
         )
         self.mask_cache: Optional[torch.Tensor] = None
         self.max_seq_length = self.config.block_size
+        self._default_kv_cache = False

     @classmethod
     def from_name(cls, name: str, **kwargs: Any) -> Self:
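
Note: the constructor no longer takes a per-layer `kv_cache` list, so there is nothing to validate or zip against the blocks, and `_default_kv_cache` simply starts out `False`. A minimal sketch of the resulting construction pattern, using toy stand-ins rather than the real litgpt `Config`/`Block` classes:

```python
# Toy sketch of the simplified pattern above; Config and Block here are
# illustrative stand-ins, not the litgpt classes.
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class Config:
    n_layer: int = 4
    n_embd: int = 64
    padded_vocab_size: int = 128


class Block(nn.Module):
    def __init__(self, config: Config, block_idx: int) -> None:
        super().__init__()
        self.block_idx = block_idx
        self.attn = nn.Linear(config.n_embd, config.n_embd)
        self.kv_cache = None  # attached later, not threaded through __init__


class GPT(nn.Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                # no kv_cache list to check or enumerate any more
                h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)),
            )
        )
        self._default_kv_cache = False
```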

Diff for: litgpt/adapter_v2.py (+4 -16)

@@ -65,39 +65,27 @@ def reset_parameters(self) -> None:

 class GPT(BaseModel):
     # Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here.
-    def __init__(
-        self,
-        config: Config,
-        kv_cache: Optional[List[KVCache]] = None
-    ) -> None:
+    def __init__(self, config: Config) -> None:
         nn.Module.__init__(self)
         assert config.padded_vocab_size is not None
         self.config = config

-        if kv_cache is not None:
-            if len(kv_cache) != config.n_layer:
-                raise ValueError(f"kv_cache length {len(kv_cache)} != {config.n_layer} = config.n_layer")
-            for kvc in kv_cache:
-                self._check_kv_cache(config, kvc)
-            self._default_kv_cache = False
-        else:
-            kv_cache = [None] * config.n_layer
-            self._default_kv_cache = True
         self.lm_head = AdapterV2Linear(
             config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
         )
         self.transformer = nn.ModuleDict(
             dict(
                 wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                 h=nn.ModuleList(
-                    Block(config, block_idx, kv_cache=kvc)
-                    for block_idx, kvc in enumerate(kv_cache)
+                    Block(config, block_idx)
+                    for block_idx in range(config.n_layer)
                 ),
                 ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
             )
         )
         self.mask_cache: Optional[torch.Tensor] = None
         self.max_seq_length = self.config.block_size
+        self._default_kv_cache = False

     @classmethod
     def from_name(cls, name: str, **kwargs: Any) -> Self:
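
The same simplification applies here; call sites that previously passed a cache list now just construct the model. A sketch of the call-site change (the model name is illustrative, and `caches` refers to the per-layer list the old signature accepted):

```python
from litgpt.adapter_v2 import GPT, Config

config = Config.from_name("pythia-70m")  # illustrative model name

# Before this commit the constructor also accepted a per-layer cache list:
#     model = GPT(config, kv_cache=caches)
# After it, caches are no longer a constructor concern:
model = GPT(config)
```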

Diff for: litgpt/generate/base.py (+1 -1)

@@ -594,7 +594,7 @@ def main(
             temperature=temperature,
             top_k=top_k,
             top_p=top_p,
-            eos_id=int(tokenizer.eos_id),
+            eos_id=tokenizer.eos_id,
         )[0]
         t = time.perf_counter() - t0
         fabric.print(tokenizer.decode(y))
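
The `int(...)` cast around `tokenizer.eos_id` is dropped. One plausible motivation (an assumption, not stated in the commit): `eos_id` is optional downstream, and forcing a missing ID through `int()` would raise rather than simply disabling EOS-based stopping:

```python
# Assumption for illustration: an unset EOS id is represented as None.
eos_id = None

try:
    eos_id = int(eos_id)  # what the old call site effectively did
except TypeError as err:
    print(f"int() on a missing eos_id fails: {err}")

# Passing tokenizer.eos_id through unchanged leaves None handling to generate().
```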

Diff for: litgpt/lora.py (+3 -16)

@@ -482,24 +482,11 @@ def mlp_class(self) -> Type:

 class GPT(BaseModel):
     # Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here.
-    def __init__(
-        self,
-        config: Config,
-        kv_cache: Optional[List[KVCache]] = None
-    ) -> None:
+    def __init__(self, config: Config) -> None:
         nn.Module.__init__(self)
         assert config.padded_vocab_size is not None
         self.config = config

-        if kv_cache is not None:
-            if len(kv_cache) != config.n_layer:
-                raise ValueError(f"kv_cache length {len(kv_cache)} != {config.n_layer} = config.n_layer")
-            for kvc in kv_cache:
-                self._check_kv_cache(config, kvc)
-            self._default_kv_cache = False
-        else:
-            kv_cache = [None] * config.n_layer
-            self._default_kv_cache = True
         self.lm_head = create_lora_linear(
             config,
             config.n_embd,
@@ -511,8 +498,8 @@ def __init__(
             dict(
                 wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                 h=nn.ModuleList(
-                    Block(config, block_idx, kv_cache=kvc)
-                    for block_idx, kvc in enumerate(kv_cache)
+                    Block(config, block_idx)
+                    for block_idx in range(config.n_layer)
                 ),
                 ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
             )
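
With the constructor argument gone in lora.py as well, a caller that wants custom per-layer caches would presumably attach them after construction. A hypothetical helper sketching that idea; the attribute names (`model.transformer.h`, `block.attn.kv_cache`) mirror the layout visible in these diffs but are assumptions about the refactor's API, not part of this commit:

```python
# Hypothetical helper, not part of this commit: attach one cache per block
# after the model has been built. Attribute names are assumptions.
def assign_kv_caches(model, caches) -> None:
    blocks = model.transformer.h
    if len(caches) != len(blocks):
        raise ValueError(f"expected {len(blocks)} caches, got {len(caches)}")
    for block, cache in zip(blocks, caches):
        block.attn.kv_cache = cache
```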
