Skip to content

Add LoRA linear definition #11044

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: gh/lucylq/82/base
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/models/llama/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ runtime.python_library(
name = "llama_transformer",
srcs = [
"llama_transformer.py",
"lora.py",
"rope.py",
"attention.py",
"model_args.py",
Expand Down
54 changes: 45 additions & 9 deletions examples/models/llama/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,28 @@ def update(

@register_attention("mha")
class AttentionMHA(Attention):
def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
def __init__(
self,
args: ModelArgs,
layer_id: int,
rope: Rope,
wq: Optional[nn.Module] = None,
wk: Optional[nn.Module] = None,
wv: Optional[nn.Module] = None,
wo: Optional[nn.Module] = None,
):
"""
Multi-head attention layer.

Args:
args (ModelArgs): Model configuration parameters.
layer_id (int): Layer index.
rope (Rope): Rotary position embedding module.
wq (Optional[nn.Module]): Query projection module. If None, use regular nn.Linear.
wk (Optional[nn.Module]): Key projection module. If None, use regular nn.Linear.
wv (Optional[nn.Module]): Value projection module. If None, use regular nn.Linear.
wo (Optional[nn.Module]): Output projection module. If None, use regular nn.Linear.
"""
super().__init__()
self.use_kv_cache = args.use_kv_cache
self.n_heads = args.n_heads
Expand All @@ -349,19 +370,34 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)

self.wq = nn.Linear(
self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
self.wq = (
wq
if wq is not None
else nn.Linear(
self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
)
)
self.wk = (
wk
if wk is not None
else nn.Linear(
self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
)
)
self.wk = nn.Linear(
self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
self.wv = (
wv
if wv is not None
else nn.Linear(
self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
)
)
self.wv = nn.Linear(
self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
self.wo = (
wo
if wo is not None
else nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
)
self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

self.layer_id = layer_id

self.rope = rope

causal_mask = torch.tril(
Expand Down
79 changes: 78 additions & 1 deletion examples/models/llama/llama_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
ForwardOptions,
)

from executorch.examples.models.llama.lora import LoRALinear
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.norm import RMSNorm
from executorch.examples.models.llama.rope import Rope
Expand Down Expand Up @@ -254,7 +255,83 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
layers = torch.nn.ModuleList()
cls = ATTENTION_REGISTRY[model_args.attention_type]
for layer_id in range(model_args.n_layers):
attention = cls(model_args, layer_id, rope)
wq = (
LoRALinear(
in_dim=model_args.dim,
out_dim=model_args.n_heads * model_args.head_dim,
rank=model_args.r,
alpha=model_args.lora_alpha,
dropout=0.0,
use_bias=model_args.attention_qkv_bias,
)
if model_args.target_modules is not None
and "q_proj" in model_args.target_modules
else (
torch.nn.Linear(
model_args.dim,
model_args.n_heads * model_args.head_dim,
bias=model_args.attention_qkv_bias,
)
)
)

wk = (
LoRALinear(
in_dim=model_args.dim,
out_dim=model_args.n_kv_heads * model_args.head_dim,
rank=model_args.r,
alpha=model_args.lora_alpha,
dropout=0.0,
use_bias=model_args.attention_qkv_bias,
)
if model_args.target_modules is not None
and "k_proj" in model_args.target_modules
else (
torch.nn.Linear(
model_args.dim,
model_args.n_kv_heads * model_args.head_dim,
bias=model_args.attention_qkv_bias,
)
)
)
wv = (
LoRALinear(
in_dim=model_args.dim,
out_dim=model_args.n_kv_heads * model_args.head_dim,
rank=model_args.r,
alpha=model_args.lora_alpha,
dropout=0.0,
use_bias=model_args.attention_qkv_bias,
)
if model_args.target_modules is not None
and "v_proj" in model_args.target_modules
else (
torch.nn.Linear(
model_args.dim,
model_args.n_kv_heads * model_args.head_dim,
bias=model_args.attention_qkv_bias,
)
)
)

wo = (
LoRALinear(
in_dim=model_args.n_kv_heads * model_args.head_dim,
out_dim=model_args.dim,
rank=model_args.r,
alpha=model_args.lora_alpha,
dropout=0.0,
use_bias=model_args.attention_qkv_bias,
)
if model_args.target_modules is not None
and "output_proj" in model_args.target_modules
else (
torch.nn.Linear(
model_args.n_heads * model_args.head_dim, model_args.dim, bias=False
)
)
)
attention = cls(model_args, layer_id, rope, wq, wk, wv, wo)
transformer_block = TransformerBlock(model_args, attention)
layers.append(transformer_block)

Expand Down
48 changes: 48 additions & 0 deletions examples/models/llama/lora.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn


class LoRALinear(nn.Module):
"""LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`."""

def __init__(
self,
in_dim: int,
out_dim: int,
rank: int,
alpha: float,
dropout: float = 0.0,
use_bias: bool = False,
):
super().__init__()
self.in_dim = in_dim
self.out_dim = out_dim
self.rank = rank
self.alpha = alpha
self.use_bias = use_bias
self.dropout = dropout

linear = nn.Linear(in_dim, out_dim, bias=use_bias)
weight = linear.weight
bias = linear.bias if self.use_bias else None
self.register_parameter("weight", nn.Parameter(weight))
self.register_parameter(
"bias", nn.Parameter(bias) if bias is not None else None
)

self.dropout = nn.Dropout(p=dropout) if dropout > 0.0 else nn.Identity()
self.lora_a = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
self.lora_b = nn.Linear(in_features=rank, out_features=out_dim, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
out = torch.nn.functional.linear(x, self.weight, self.bias)
lora_out = self.lora_a(self.dropout(x))
lora_out = (self.alpha / self.rank) * self.lora_b(lora_out)

return out + lora_out
10 changes: 10 additions & 0 deletions examples/models/llama/model_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,18 @@ class ModelArgs:
eos_count: int = 2

quantization_args: Optional[dict] = None
# LoRA for QAT.
lora_args: Optional[dict] = None

# LoRA arguments to set up a LoRA inference model.
# These arguments come directly from a torchtune LoRA config.
r: Optional[int] = None # Rank.
lora_alpha: Optional[int] = None # Alpha.
# Eg. q_proj, k_proj, v_proj, output_proj
target_modules: Optional[list] = None
peft_type: Optional[str] = None # PEFT type.
base_model_name_or_path: Optional[str] = None # Base model name or path.

def __post_init__(self):
if self.n_kv_heads is None:
self.n_kv_heads = self.n_heads
Expand Down
Loading