
Commit 788f67e

Add LoRA
1 parent b2a519b commit 788f67e

7 files changed: +850 -83 lines


README.md

+3
@@ -233,3 +233,6 @@ Instructions to run the image can be found in the [official documentation](https
 - Use Mistral models on [Mistral AI official API](https://console.mistral.ai/) (La Plateforme)
 - Use Mistral models via [cloud providers](https://docs.mistral.ai/deployment/cloud/overview/)
 
+## References
+
+[1]: [LoRA](https://arxiv.org/abs/2106.09685): Low-Rank Adaptation of Large Language Models, Hu et al. 2021

poetry.lock

+646 -71
Some generated files are not rendered by default.

pyproject.toml

+2 -2
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "mistral_inference"
-version = "v1.0.4"
+version = "v1.1.0"
 description = ""
 authors = ["bam4d <[email protected]>"]
 readme = "README.md"
@@ -24,7 +24,7 @@ exclude = ["docs", "tools", "build"]
 [tool.poetry.dependencies]
 python = "^3.9.10"
-xformers = ">=0.0.25"
+xformers = ">=0.0.24"
 simple-parsing = ">=0.1.5"
 fire = ">=0.6.0"
 mistral_common = "^1.0.0"

src/mistral_inference/__init__.py

+1 -1
@@ -1 +1 @@
-__version__ = "1.0.4"
+__version__ = "1.1.0"

src/mistral_inference/lora.py

+166
@@ -0,0 +1,166 @@
+import logging
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Dict, NamedTuple, Union
+
+import safetensors.torch
+import torch
+import torch.nn as nn
+from simple_parsing.helpers import Serializable
+
+
+@dataclass
+class LoraArgs(Serializable):
+    rank: int
+    scaling: float
+
+    def __post_init__(self):
+        assert self.rank > 0
+        assert self.scaling > 0.0
+
+
+class LoRALinear(nn.Module):
+    """
+    Implementation of:
+        - LoRA: https://arxiv.org/abs/2106.09685
+
+    Notes:
+        - Freezing is handled at network level, not layer level.
+        - Scaling factor controls relative importance of LoRA skip
+          connection versus original frozen weight. General guidance is
+          to keep it to 2.0 and sweep over learning rate when changing
+          the rank.
+    """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        rank: int,
+        scaling: float,
+        bias: bool = False,
+    ):
+        super().__init__()
+
+        self.in_features = in_features
+        self.out_features = out_features
+        assert not bias
+        self.bias = bias
+        self.rank = rank
+        self.scaling = scaling
+
+        self.lora_A = nn.Linear(
+            self.in_features,
+            self.rank,
+            bias=self.bias,
+        )
+        self.lora_B = nn.Linear(
+            self.rank,
+            self.out_features,
+            bias=self.bias,
+        )
+
+        self.linear = nn.Linear(self.in_features, self.out_features, bias=self.bias)
+
+        # make sure no LoRA weights are marked as "missing" in load_state_dict
+        def ignore_missing_keys(m: nn.Module, incompatible_keys: NamedTuple):
+            incompatible_keys.missing_keys[:] = []  # type: ignore
+
+        self.register_load_state_dict_post_hook(ignore_missing_keys)
+
+    def forward(self, x: torch.Tensor):
+        lora = self.lora_B(self.lora_A(x))
+        return self.linear(x) + lora * self.scaling
+
+    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+        key_name = prefix + "weight"
+
+        # full checkpoint
+        if key_name in state_dict:
+            w_ref = state_dict[key_name]
+
+            # load frozen weights
+            state_dict = {
+                "linear.weight": w_ref,
+                "lora_A.weight": torch.zeros_like(
+                    self.lora_A.weight, device=w_ref.device, dtype=w_ref.dtype
+                ),
+                "lora_B.weight": torch.zeros_like(
+                    self.lora_B.weight, device=w_ref.device, dtype=w_ref.dtype
+                ),
+            }
+            self.load_state_dict(state_dict, assign=True, strict=True)
+
+
+class LoRALoaderMixin:
+    def load_lora(self, lora_path: Union[Path, str], scaling: float = 2.0):
+        """Loads LoRA checkpoint"""
+
+        lora_path = Path(lora_path)
+        assert lora_path.is_file(), f"{lora_path} does not exist or is not a file"
+
+        state_dict = safetensors.torch.load_file(lora_path)
+
+        self._load_lora_state_dict(state_dict, scaling=scaling)
+
+    def _load_lora_state_dict(
+        self, lora_state_dict: Dict[str, torch.Tensor], scaling: float = 2.0
+    ):
+        """Loads LoRA state_dict"""
+
+        lora_dtypes = set([p.dtype for p in lora_state_dict.values()])
+        assert (
+            len(lora_dtypes) == 1
+        ), f"LoRA weights have multiple different dtypes {lora_dtypes}. All weights need to have the same dtype"
+        lora_dtype = lora_dtypes.pop()
+        assert (
+            lora_dtype == self.dtype
+        ), f"LoRA weights dtype differs from model's dtype {lora_dtype} != {self.dtype}"
+        assert all("lora" in key for key in lora_state_dict.keys())
+
+        # move tensors to device
+        lora_state_dict = {k: v.to(self.device) for k, v in lora_state_dict.items()}
+
+        state_dict = self.state_dict()
+
+        if self.args.lora is None:
+            logging.info("Loading and merging LoRA weights...")
+
+            # replace every nn.Linear with a LoRALinear with 'meta' device except the output layer
+            named_modules = dict(self.named_modules())
+            for name, module in named_modules.items():
+                if isinstance(module, nn.Linear) and name != "output":
+                    layer_id = name.split(".")[1]
+                    if layer_id not in self.layers:
+                        logging.debug(
+                            "Skipping parameter %s at pipeline rank %d",
+                            name,
+                            self.pipeline_rank,
+                        )
+                    else:
+                        weight = (
+                            module.weight
+                            + (
+                                lora_state_dict[name + ".lora_B.weight"]
+                                @ lora_state_dict[name + ".lora_A.weight"]
+                            )
+                            * scaling
+                        )
+
+                        state_dict[name + ".weight"] = weight
+        else:
+            logging.info("Loading LoRA weights...")
+            for k, v in lora_state_dict.items():
+                layer_id = k.split(".")[1]
+                if layer_id in self.layers:
+                    state_dict[k] = v
+                else:
+                    logging.debug(
+                        "Skipping parameter %s at pipeline rank %d",
+                        k,
+                        self.pipeline_rank,
+                    )
+
+        self.load_state_dict(state_dict, strict=True)
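
For reference, a minimal sketch (not part of the commit) checking that LoRALinear's forward pass and the weight merge done by LoRALoaderMixin when args.lora is None compute the same result, i.e. x @ W.T + scaling * x @ (B @ A).T:

import torch

dim, rank, scaling = 8, 2, 2.0
W = torch.randn(dim, dim)   # frozen linear.weight
A = torch.randn(rank, dim)  # lora_A.weight
B = torch.randn(dim, rank)  # lora_B.weight
x = torch.randn(3, dim)

# LoRALinear.forward: linear(x) + lora_B(lora_A(x)) * scaling
out_forward = x @ W.T + (x @ A.T @ B.T) * scaling
# merged path: module.weight + (lora_B.weight @ lora_A.weight) * scaling
out_merged = x @ (W + (B @ A) * scaling).T

assert torch.allclose(out_forward, out_merged, atol=1e-5)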

src/mistral_inference/main.py

+11 -1
@@ -1,7 +1,7 @@
 import logging
 import os
 from pathlib import Path
-from typing import List
+from typing import List, Optional
 
 import fire  # type: ignore
 import torch
@@ -46,6 +46,7 @@ def interactive(
     temperature: float = 0.7,
     num_pipeline_ranks: int = 1,
     instruct: bool = False,
+    lora_path: Optional[str] = None,
 ) -> None:
     if is_torchrun():
         torch.distributed.init_process_group()
@@ -64,6 +65,10 @@ def interactive(
         Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks
     )
 
+    # load LoRA
+    if lora_path is not None:
+        transformer.load_lora(Path(lora_path))
+
     prompt: str = ""
     messages: List[UserMessage | AssistantMessage] = []

@@ -117,6 +122,7 @@ def demo(
     model_path: str,
     max_tokens: int = 35,
     temperature: float = 0,
+    lora_path: Optional[str] = None,
 ) -> None:
     if is_torchrun():
         torch.distributed.init_process_group()
@@ -131,6 +137,10 @@ def demo(
     transformer = Transformer.from_folder(
         Path(model_path), max_batch_size=3, num_pipeline_ranks=num_pipeline_ranks
     )
+    # load LoRA
+    if lora_path is not None:
+        transformer.load_lora(Path(lora_path))
+
     mistral_tokenizer: MistralTokenizer = load_tokenizer(Path(model_path))
     tokenizer: Tokenizer = mistral_tokenizer.instruct_tokenizer.tokenizer
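
A hedged usage sketch, not part of the diff: the new lora_path option simply forwards to Transformer.load_lora after the model is built, so the same effect can be had directly. Paths are placeholders and single-rank defaults are assumed.

from pathlib import Path

from mistral_inference.model import Transformer

# placeholder paths: substitute a real model folder and a LoRA .safetensors file
transformer = Transformer.from_folder(Path("/path/to/model"), max_batch_size=3)
transformer.load_lora(Path("/path/to/lora.safetensors"))  # merges or applies LoRA weights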

src/mistral_inference/model.py

+21 -8
@@ -2,6 +2,7 @@
 import logging
 import math
 from dataclasses import dataclass
+from functools import partial
 from pathlib import Path
 from typing import Any, List, Mapping, Optional, Tuple, Union

@@ -16,6 +17,7 @@
     CacheInputMetadata,
     CacheView,
 )
+from mistral_inference.lora import LoraArgs, LoRALinear, LoRALoaderMixin
 from mistral_inference.moe import MoeArgs, MoeLayer
 from mistral_inference.rope import apply_rotary_emb, precompute_freqs_cis

@@ -37,6 +39,8 @@ class ModelArgs(Serializable):
     rope_theta: Optional[float] = None
     # If this is set, we will use MoE layers instead of dense layers.
     moe: Optional[MoeArgs] = None
+    # If this is set, we will load LoRA linear layers instead of linear layers.
+    lora: Optional[LoraArgs] = None
 
 
 @dataclass
@@ -61,6 +65,13 @@ def repeat_kv(
     return keys, values
 
 
+def maybe_lora(args: ModelArgs) -> Union[nn.Linear, LoRALinear]:
+    if args.lora is None:
+        return nn.Linear
+    else:
+        return partial(LoRALinear, rank=args.lora.rank, scaling=args.lora.scaling)
+
+
 class Attention(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
@@ -74,10 +85,11 @@ def __init__(self, args: ModelArgs):
 
         self.scale = self.args.head_dim**-0.5
 
-        self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
-        self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
-        self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
-        self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
+        MaybeLora = maybe_lora(args)
+        self.wq = MaybeLora(args.dim, args.n_heads * args.head_dim, bias=False)
+        self.wk = MaybeLora(args.dim, args.n_kv_heads * args.head_dim, bias=False)
+        self.wv = MaybeLora(args.dim, args.n_kv_heads * args.head_dim, bias=False)
+        self.wo = MaybeLora(args.n_heads * args.head_dim, args.dim, bias=False)
 
     def forward(
         self,
@@ -127,9 +139,10 @@ class FeedForward(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
 
-        self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
-        self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
-        self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)
+        MaybeLora = maybe_lora(args)
+        self.w1 = MaybeLora(args.dim, args.hidden_dim, bias=False)
+        self.w2 = MaybeLora(args.hidden_dim, args.dim, bias=False)
+        self.w3 = MaybeLora(args.dim, args.hidden_dim, bias=False)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))  # type: ignore
@@ -179,7 +192,7 @@ def forward(
         return out
 
 
-class Transformer(nn.Module):
+class Transformer(nn.Module, LoRALoaderMixin):
     def __init__(
         self,
         args: ModelArgs,
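
A minimal sketch, not from the commit, of how the new lora field drives maybe_lora. DummyArgs is a hypothetical stand-in for ModelArgs; only its lora field is consulted here. Setting lora switches the layer constructor from nn.Linear to a LoRALinear factory with the same call signature.

from dataclasses import dataclass
from functools import partial
from typing import Optional

import torch.nn as nn

from mistral_inference.lora import LoraArgs, LoRALinear


@dataclass
class DummyArgs:  # hypothetical stand-in for ModelArgs; only `lora` is read below
    lora: Optional[LoraArgs] = None


def maybe_lora(args: DummyArgs):  # mirrors maybe_lora in model.py
    if args.lora is None:
        return nn.Linear
    else:
        return partial(LoRALinear, rank=args.lora.rank, scaling=args.lora.scaling)


dense = maybe_lora(DummyArgs())(64, 64, bias=False)                              # nn.Linear
lora = maybe_lora(DummyArgs(LoraArgs(rank=8, scaling=2.0)))(64, 64, bias=False)  # LoRALinear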
