|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "metadata": {}, |
| 6 | + "source": [ |
| 7 | + "<div align=center><img src=\"./assets/rotary_embedding.png\"></div>" |
| 8 | + ] |
| 9 | + }, |
| 10 | + { |
| 11 | + "cell_type": "code", |
| 12 | + "execution_count": null, |
| 13 | + "metadata": {}, |
| 14 | + "outputs": [], |
| 15 | + "source": [ |
| 16 | + "from enum import Enum\n", |
| 17 | + "import numpy as np\n", |
| 18 | + "\n", |
| 19 | + "from mindspore.common.tensor import Tensor\n", |
| 20 | + "from mindspore.common.parameter import Parameter\n", |
| 21 | + "from mindspore import nn\n", |
| 22 | + "import mindspore.common.dtype as mstype\n", |
| 23 | + "from mindspore.ops import operations as P\n", |
| 24 | + "from mindspore.ops import functional as F\n", |
| 25 | + "from mindspore.nn.cell import Cell\n", |
| 26 | + "\n", |
| 27 | + "\n", |
| 28 | + "def precompute_freqs_cis(\n", |
| 29 | + " dim: int,\n", |
| 30 | + " end: int,\n", |
| 31 | + " theta: float = 10000.0,\n", |
| 32 | + " dtype=mstype.float32,\n", |
| 33 | + " pretrain_seqlen=2048,\n", |
| 34 | + " extend_method=SeqExtendMethod.NONE.value):\n", |
| 35 | + " \"\"\"\n", |
| 36 | + " Precompute of freqs and mask for rotary embedding.\n", |
| 37 | + " \"\"\"\n", |
| 38 | + " ratio = 1.\n", |
| 39 | + " if extend_method != SeqExtendMethod.NONE.value and end > pretrain_seqlen:\n", |
| 40 | + " ratio = end / pretrain_seqlen\n", |
| 41 | + " if extend_method == SeqExtendMethod.NTK.value:\n", |
| 42 | + " theta *= ratio\n", |
| 43 | + "\n", |
| 44 | + " # 2i/d\n", |
| 45 | + " # dim = 64\n", |
| 46 | + " # 2i: np.arange(0, dim, 2) ==> [0, 2, 4, ..., 64], tot_num = 32\n", |
| 47 | + " # 2i/d = np.arange(0, dim, 2)[: (dim // 2)], dim // 2 = tot_num = 32\n", |
| 48 | + " freqs_base = np.arange(0, dim, 2)[: (dim // 2)].astype(np.float32) # (head_dim // 2, )\n", |
| 49 | + "\n", |
| 50 | + " # theta**(-2i/d) = 1/theta**(2i/d)\n", |
| 51 | + " # (dim//2, ) => (32,)\n", |
| 52 | + " freqs = 1.0 / (theta ** (freqs_base / dim)) # (head_dim // 2, )\n", |
| 53 | + "\n", |
| 54 | + " # t ==> m\n", |
| 55 | + " # t = [0, 1, 2, 3, ..., 1024]\n", |
| 56 | + " if extend_method == SeqExtendMethod.PI.value:\n", |
| 57 | + " t = np.arange(0, end / ratio, 1 / ratio).astype(np.float32)\n", |
| 58 | + " else:\n", |
| 59 | + " t = np.arange(0, end, 1).astype(np.float32) # type: ignore # (seq_len,)\n", |
| 60 | + " # (1024, )(32, ) ==> (1024, 32) m*theta_i\n", |
| 61 | + " freqs = np.outer(t, freqs) # type: ignore (seq_len, head_dim // 2)\n", |
| 62 | + " emb = np.concatenate((freqs, freqs), axis=-1)\n", |
| 63 | + "\n", |
| 64 | + " freqs_cos = np.cos(emb) # (seq_len, head_dim)\n", |
| 65 | + " freqs_sin = np.sin(emb) # (seq_len, head_dim)\n", |
| 66 | + " freqs_cos = Tensor(freqs_cos, dtype=dtype)\n", |
| 67 | + " freqs_sin = Tensor(freqs_sin, dtype=dtype)\n", |
| 68 | + "\n", |
| 69 | + " swap_mask = get_swap_mask(dim)\n", |
| 70 | + " swap_mask = Tensor(swap_mask, dtype=dtype)\n", |
| 71 | + "\n", |
| 72 | + " # sin(m * theta_i)\n", |
| 73 | + " # cos(m * theta_i)\n", |
| 74 | + " return freqs_cos, freqs_sin, swap_mask" |
| 75 | + ] |
| 76 | + }, |
| 77 | + { |
| 78 | + "cell_type": "markdown", |
| 79 | + "metadata": {}, |
| 80 | + "source": [ |
| 81 | + "在2D vector情况下,旋转的矩阵表达应为:\n", |
| 82 | + "\n", |
| 83 | + "<div align=center><img src=\"./assets/rotation-2d.png\"></div>\n", |
| 84 | + "\n", |
| 85 | + "拓展到general form,即当模型的hidden size长度大于2时:\n", |
| 86 | + "\n", |
| 87 | + "<div align=center><img src=\"./assets/rotation-general.png\"></div>\n", |
| 88 | + "\n", |
| 89 | + "将旋转变化作用在q、k之上,然后q、k再进行点积的结果就变为如下公式,即该点积只和两个向量之间的相对位置有关:\n", |
| 90 | + "\n", |
| 91 | + "<div align=center><img src=\"./assets/formula.png\"></div>\n", |
| 92 | + "\n", |
| 93 | + "但因为矩阵的稀疏性,直接用矩阵乘法来实现会很浪费算力,实际情况下一般会通过下述方式来实现RoPE:\n", |
| 94 | + "\n", |
| 95 | + "<div align=center><img src=\"./assets/rope-calculation.png\"></div>" |
| 96 | + ] |
| 97 | + }, |
| 98 | + { |
| 99 | + "cell_type": "code", |
| 100 | + "execution_count": null, |
| 101 | + "metadata": {}, |
| 102 | + "outputs": [], |
| 103 | + "source": [ |
| 104 | + "class LlamaRotaryEmbedding(Cell):\n", |
| 105 | + " r\"\"\"\n", |
| 106 | + " Rotary Position Embedding.\n", |
| 107 | + "\n", |
| 108 | + " Args:\n", |
| 109 | + " - **head_dim** (int): The dim of multi head attention.\n", |
| 110 | + " - **compute_dtype** (mstype): The compute type, default mstype.float16.\n", |
| 111 | + " - **parallel_config** (dict): - Parallel Config.\n", |
| 112 | + " Inputs:\n", |
| 113 | + " - **x** (Tensor) - Tensor of shape :math:`(batch, seq\\_length, hidden\\_size)`.\n", |
| 114 | + "\n", |
| 115 | + " Outputs:\n", |
| 116 | + " Tensor of shape :math:`(batch, seq_length, hidden_size)`.\n", |
| 117 | + " \"\"\"\n", |
| 118 | + "\n", |
| 119 | + " def __init__(self, head_dim=128, compute_dtype=mstype.float32):\n", |
| 120 | + " super().__init__(auto_prefix=False)\n", |
| 121 | + " self.head_dim = head_dim\n", |
| 122 | + " self.dtype = compute_dtype\n", |
| 123 | + "\n", |
| 124 | + " self.add = P.Add()\n", |
| 125 | + " self.bmm_swap = P.BatchMatMul()\n", |
| 126 | + " self.mul = P.Mul()\n", |
| 127 | + "\n", |
| 128 | + " self.cast = P.Cast()\n", |
| 129 | + "\n", |
| 130 | + " def rotate_half(self, x, swap_mask):\n", |
| 131 | + " # [bs, n_head/n_kv_head, seq/1, head_dim], [head_dim, head_dim]\n", |
| 132 | + " x = self.bmm_swap(x, swap_mask)\n", |
| 133 | + " return x\n", |
| 134 | + "\n", |
| 135 | + " def construct(self, xq: Tensor, xk: Tensor, freqs_cis):\n", |
| 136 | + " \"\"\"Forward of rotary position embedding.\"\"\"\n", |
| 137 | + " original_type = xq.dtype\n", |
| 138 | + " xq = self.cast(xq, self.dtype)\n", |
| 139 | + " xk = self.cast(xk, self.dtype)\n", |
| 140 | + " # xq, xk: [bs, n_head/n_kv_head, seq/1, head_dim]\n", |
| 141 | + " freqs_cos, freqs_sin, swap_mask = freqs_cis\n", |
| 142 | + " xq_out = self.add(self.mul(xq, freqs_cos),\n", |
| 143 | + " self.mul(self.rotate_half(xq, swap_mask), freqs_sin))\n", |
| 144 | + " xk_out = self.add(self.mul(xk, freqs_cos),\n", |
| 145 | + " self.mul(self.rotate_half(xk, swap_mask), freqs_sin))\n", |
| 146 | + "\n", |
| 147 | + " xq_out = self.cast(xq_out, original_type)\n", |
| 148 | + " xk_out = self.cast(xk_out, original_type)\n", |
| 149 | + " return xq_out, xk_out\n", |
| 150 | + "\n", |
| 151 | + " def shard(self, strategy_in):\n", |
| 152 | + " self.add.shard((strategy_in, strategy_in))\n", |
| 153 | + " self.bmm_swap.shard((strategy_in, (1, 1)))\n", |
| 154 | + " self.mul.shard((strategy_in, (strategy_in[0], 1, 1, 1)))\n" |
| 155 | + ] |
| 156 | + } |
| 157 | + ], |
| 158 | + "metadata": { |
| 159 | + "kernelspec": { |
| 160 | + "display_name": "mindspore_2.2", |
| 161 | + "language": "python", |
| 162 | + "name": "python3" |
| 163 | + }, |
| 164 | + "language_info": { |
| 165 | + "name": "python", |
| 166 | + "version": "3.7.16" |
| 167 | + } |
| 168 | + }, |
| 169 | + "nbformat": 4, |
| 170 | + "nbformat_minor": 2 |
| 171 | +} |
0 commit comments