
Commit 222ecf0

Upload LLaMA slides and code (#46)
* upload slides
* upload code and slides
1 parent c89d492 commit 222ecf0

File tree

7 files changed: +171 -0 lines changed
@@ -0,0 +1,171 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div align=center><img src=\"./assets/rotary_embedding.png\"></div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from enum import Enum\n",
    "import numpy as np\n",
    "\n",
    "from mindspore.common.tensor import Tensor\n",
    "import mindspore.common.dtype as mstype\n",
    "from mindspore.ops import operations as P\n",
    "from mindspore.nn.cell import Cell\n",
    "\n",
    "\n",
    "class SeqExtendMethod(Enum):\n",
    "    \"\"\"Methods for extending the sequence length beyond the pretrained one.\"\"\"\n",
    "    PI = \"PI\"\n",
    "    NTK = \"NTK\"\n",
    "    NONE = \"None\"\n",
    "\n",
    "\n",
    "def get_swap_mask(head_dim):\n",
    "    \"\"\"Build the [head_dim, head_dim] matrix that maps [x1, x2] to [-x2, x1].\"\"\"\n",
    "    zero_block = np.zeros((head_dim // 2, head_dim // 2), dtype=np.float32)\n",
    "    id_block = np.identity(head_dim // 2, dtype=np.float32)\n",
    "    return np.block([[zero_block, id_block], [-id_block, zero_block]])\n",
    "\n",
    "\n",
    "def precompute_freqs_cis(\n",
    "        dim: int,\n",
    "        end: int,\n",
    "        theta: float = 10000.0,\n",
    "        dtype=mstype.float32,\n",
    "        pretrain_seqlen=2048,\n",
    "        extend_method=SeqExtendMethod.NONE.value):\n",
    "    \"\"\"\n",
    "    Precompute the freqs and swap mask for rotary embedding.\n",
    "    \"\"\"\n",
    "    ratio = 1.\n",
    "    if extend_method != SeqExtendMethod.NONE.value and end > pretrain_seqlen:\n",
    "        ratio = end / pretrain_seqlen\n",
    "    if extend_method == SeqExtendMethod.NTK.value:\n",
    "        theta *= ratio\n",
    "\n",
    "    # 2i/d\n",
    "    # e.g. dim = 64\n",
    "    # 2i: np.arange(0, dim, 2) ==> [0, 2, 4, ..., 62], tot_num = 32\n",
    "    # 2i/d: np.arange(0, dim, 2)[: (dim // 2)] / dim, dim // 2 = tot_num = 32\n",
    "    freqs_base = np.arange(0, dim, 2)[: (dim // 2)].astype(np.float32)  # (head_dim // 2, )\n",
    "\n",
    "    # theta**(-2i/d) = 1 / theta**(2i/d)\n",
    "    # (dim // 2, ) => (32, )\n",
    "    freqs = 1.0 / (theta ** (freqs_base / dim))  # (head_dim // 2, )\n",
    "\n",
    "    # t ==> m, the position indices\n",
    "    # e.g. end = 1024: t = [0, 1, 2, ..., 1023]\n",
    "    if extend_method == SeqExtendMethod.PI.value:\n",
    "        t = np.arange(0, end / ratio, 1 / ratio).astype(np.float32)\n",
    "    else:\n",
    "        t = np.arange(0, end, 1).astype(np.float32)  # (seq_len, )\n",
    "    # (1024, ) outer (32, ) ==> (1024, 32), i.e. m * theta_i\n",
    "    freqs = np.outer(t, freqs)  # (seq_len, head_dim // 2)\n",
    "    emb = np.concatenate((freqs, freqs), axis=-1)\n",
    "\n",
    "    freqs_cos = np.cos(emb)  # (seq_len, head_dim)\n",
    "    freqs_sin = np.sin(emb)  # (seq_len, head_dim)\n",
    "    freqs_cos = Tensor(freqs_cos, dtype=dtype)\n",
    "    freqs_sin = Tensor(freqs_sin, dtype=dtype)\n",
    "\n",
    "    swap_mask = get_swap_mask(dim)\n",
    "    swap_mask = Tensor(swap_mask, dtype=dtype)\n",
    "\n",
    "    # cos(m * theta_i), sin(m * theta_i), and the rotate-half mask\n",
    "    return freqs_cos, freqs_sin, swap_mask"
   ]
  },
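  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check of the precomputed tables (an added illustration; `head_dim = 64` and `seq_len = 16` are arbitrary example values):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative shapes only; head_dim and seq_len are arbitrary choices.\n",
    "head_dim, seq_len = 64, 16\n",
    "freqs_cos, freqs_sin, swap_mask = precompute_freqs_cis(head_dim, seq_len)\n",
    "print(freqs_cos.shape, freqs_sin.shape, swap_mask.shape)  # (16, 64) (16, 64) (64, 64)\n",
    "# At position m = 0 every angle m * theta_i is 0, so cos = 1 and sin = 0.\n",
    "print(freqs_cos.asnumpy()[0].min(), freqs_sin.asnumpy()[0].max())  # 1.0 0.0"
   ]
  },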
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a 2D vector, the rotation can be written in matrix form as:\n",
    "\n",
    "<div align=center><img src=\"./assets/rotation-2d.png\"></div>\n",
    "\n",
    "Extending to the general form, i.e. when the model's hidden size is greater than 2:\n",
    "\n",
    "<div align=center><img src=\"./assets/rotation-general.png\"></div>\n",
    "\n",
    "Applying this rotation to q and k and then taking their dot product yields the formula below; the dot product depends only on the relative position between the two vectors:\n",
    "\n",
    "<div align=center><img src=\"./assets/formula.png\"></div>\n",
    "\n",
    "However, because the rotation matrix is sparse, implementing it as a direct matrix multiplication wastes compute. In practice, RoPE is usually computed as follows:\n",
    "\n",
    "<div align=center><img src=\"./assets/rope-calculation.png\"></div>"
   ]
  },
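  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The relative-position property is easy to verify numerically. The NumPy sketch below is an added illustration, not part of the model code; it uses the interleaved pairing from the original RoPE paper, whereas the implementation in this notebook pairs the two halves of the head dimension. Both conventions satisfy the same property:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NumPy sketch: <R(m)q, R(n)k> should depend only on the offset m - n.\n",
    "def rope_ref(x, pos, theta=10000.0):\n",
    "    # Rotate consecutive pairs (x[2i], x[2i+1]) by the angle pos * theta_i.\n",
    "    d = x.shape[-1]\n",
    "    ang = pos * (1.0 / theta ** (np.arange(0, d, 2) / d))  # (d // 2, )\n",
    "    cos, sin = np.cos(ang), np.sin(ang)\n",
    "    out = np.empty_like(x)\n",
    "    out[0::2] = x[0::2] * cos - x[1::2] * sin\n",
    "    out[1::2] = x[0::2] * sin + x[1::2] * cos\n",
    "    return out\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "q, k = rng.standard_normal(8), rng.standard_normal(8)\n",
    "# Same offset m - n = 4 at two different absolute positions.\n",
    "print(np.allclose(rope_ref(q, 7) @ rope_ref(k, 3),\n",
    "                  rope_ref(q, 5) @ rope_ref(k, 1)))  # True"
   ]
  },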
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LlamaRotaryEmbedding(Cell):\n",
    "    r\"\"\"\n",
    "    Rotary Position Embedding.\n",
    "\n",
    "    Args:\n",
    "        head_dim (int): The head dimension of multi-head attention. Default: 128.\n",
    "        compute_dtype (mstype): The compute type. Default: mstype.float32.\n",
    "\n",
    "    Inputs:\n",
    "        - **xq** (Tensor) - Query tensor of shape :math:`(batch, n\\_head, seq\\_length, head\\_dim)`.\n",
    "        - **xk** (Tensor) - Key tensor of shape :math:`(batch, n\\_kv\\_head, seq\\_length, head\\_dim)`.\n",
    "        - **freqs_cis** (tuple) - The (freqs_cos, freqs_sin, swap_mask) tuple from precompute_freqs_cis.\n",
    "\n",
    "    Outputs:\n",
    "        The rotated xq and xk, with the same shapes and dtype as the inputs.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, head_dim=128, compute_dtype=mstype.float32):\n",
    "        super().__init__(auto_prefix=False)\n",
    "        self.head_dim = head_dim\n",
    "        self.dtype = compute_dtype\n",
    "\n",
    "        self.add = P.Add()\n",
    "        self.bmm_swap = P.BatchMatMul()\n",
    "        self.mul = P.Mul()\n",
    "\n",
    "        self.cast = P.Cast()\n",
    "\n",
    "    def rotate_half(self, x, swap_mask):\n",
    "        # x: [bs, n_head/n_kv_head, seq/1, head_dim], swap_mask: [head_dim, head_dim]\n",
    "        # Maps the two halves [x1, x2] of the last dim to [-x2, x1].\n",
    "        x = self.bmm_swap(x, swap_mask)\n",
    "        return x\n",
    "\n",
    "    def construct(self, xq: Tensor, xk: Tensor, freqs_cis):\n",
    "        \"\"\"Forward of rotary position embedding.\"\"\"\n",
    "        original_type = xq.dtype\n",
    "        xq = self.cast(xq, self.dtype)\n",
    "        xk = self.cast(xk, self.dtype)\n",
    "        # xq, xk: [bs, n_head/n_kv_head, seq/1, head_dim]\n",
    "        freqs_cos, freqs_sin, swap_mask = freqs_cis\n",
    "        # x_out = x * cos(m * theta_i) + rotate_half(x) * sin(m * theta_i)\n",
    "        xq_out = self.add(self.mul(xq, freqs_cos),\n",
    "                          self.mul(self.rotate_half(xq, swap_mask), freqs_sin))\n",
    "        xk_out = self.add(self.mul(xk, freqs_cos),\n",
    "                          self.mul(self.rotate_half(xk, swap_mask), freqs_sin))\n",
    "\n",
    "        xq_out = self.cast(xq_out, original_type)\n",
    "        xk_out = self.cast(xk_out, original_type)\n",
    "        return xq_out, xk_out\n",
    "\n",
    "    def shard(self, strategy_in):\n",
    "        self.add.shard((strategy_in, strategy_in))\n",
    "        self.bmm_swap.shard((strategy_in, (1, 1)))\n",
    "        self.mul.shard((strategy_in, (strategy_in[0], 1, 1, 1)))\n"
   ]
  },
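  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Putting the two pieces together (an added illustration; the batch, head, and sequence sizes below are arbitrary). A rotation is orthogonal, so the per-position norms of q and k should be unchanged:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply RoPE to random q/k of shape [bs, n_head, seq, head_dim] (arbitrary sizes).\n",
    "bs, n_head, seq_len, head_dim = 1, 2, 16, 64\n",
    "freqs_cis = precompute_freqs_cis(head_dim, seq_len)\n",
    "rope_emb = LlamaRotaryEmbedding(head_dim=head_dim)\n",
    "xq = Tensor(np.random.randn(bs, n_head, seq_len, head_dim), mstype.float32)\n",
    "xk = Tensor(np.random.randn(bs, n_head, seq_len, head_dim), mstype.float32)\n",
    "xq_out, xk_out = rope_emb(xq, xk, freqs_cis)\n",
    "print(xq_out.shape, xk_out.shape)  # (1, 2, 16, 64) (1, 2, 16, 64)\n",
    "# Per-position norms should match before and after the rotation.\n",
    "print(np.allclose(np.linalg.norm(xq.asnumpy(), axis=-1),\n",
    "                  np.linalg.norm(xq_out.asnumpy(), axis=-1), atol=1e-3))  # True"
   ]
  }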
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "mindspore_2.2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.7.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
