|
1 | 1 | ---
|
2 | 2 | layout: post
|
3 |
| -title: "Sorting-Free GPU-Kernels for Categorical Sampling Under Filtering in LLM Inference" |
| 3 | +title: "Sorting-Free GPU Kernels for LLM Sampling" |
4 | 4 | date: 2025-03-10
|
5 | 5 | comments: true
|
6 | 6 | author: Shanli Xing (UW), Zihao Ye (UW), Bohan Hou (CMU), Luis Ceze (UW), Tianqi Chen (CMU)
|
@@ -134,8 +134,22 @@ While this rejection sampling approach is simple and efficient in most cases, it
|
134 | 134 |
|
135 | 135 | To address this issue, in FlashInfer [v0.2.3](https://github.com/flashinfer-ai/flashinfer/releases/tag/v0.2.3), we introduce the a new algorithm called **Dual Pivot Rejection Sampling**, which uses two pivots for faster convergence in rejection sampling. The algorithm is as follows:
|
136 | 136 |
|
| 137 | +1. Let $f$ be a function that checks if a probability value is valid: $f(x)=1$ if valid, $0$ if not. |
| 138 | +2. Initialize $\textrm{low} \leftarrow 0$ and $\textrm{high} \leftarrow \max_i(p_i)$ as the initial range, it's guaranteed that $f(\textrm{low})=0$ and $f(\textrm{high})=1$. |
| 139 | +3. Sample over probability values in the range $(\textrm{low}, \infty)$ using inverse transform sampling. |
| 140 | +4. Suppose $j$ is the sampled token, let $\textrm{pivot}_1\leftarrow p_j$, and $\textrm{pivot}_2\leftarrow \frac{\textrm{pivot}_1+\textrm{high}}{2}$. |
| 141 | + 1. If $f(\textrm{pivot}_1)=1$, we accept the sampled token and return $j$. |
| 142 | + 2. If $f(\textrm{pivot}_1)=0$, $f(\textrm{pivot}_2)=1$, we set $\textrm{pivot}_1$ as new $\textrm{low}$ and $\textrm{pivot}_2$ as new $\textrm{high}$. |
| 143 | + 3. If $f(\textrm{pivot}_1)=0$, $f(\textrm{pivot}_2)=0$, we set $\textrm{pivot}_2$ as new $\textrm{low}$. |
| 144 | +5. Repeat step 3 and 4 until success. |
137 | 145 |
|
| 146 | +<p align="center"> |
| 147 | +<img src="/assets/imgs/sampling_blog/dual-pivot-sampling.png" alt="Dual Pivot Rejection Sampling" width="800"/> |
| 148 | +<br> |
| 149 | +<span style="color: gray; font-style: italic;">Figure 4: Transition from round(i) to round(i+1) in Dual Pivot Rejection Sampling, we either accept the sampled token (case 1) or shrinking the range by at least half (case 2 and 3).</span> |
| 150 | +</p> |
138 | 151 |
|
| 152 | +Figure 4 shows the transition from round(i) to round(i+1) in Dual Pivot Rejection Sampling, in each round, if the sampled token is accepted, we return the token, otherwise, the new range's extent is $\frac{\text{high}-\text{pivot}_1}{2} < \frac{\text{high}-\text{low}}{2}$, which is at least half of the previous range. Thus it's guaranteed that the number of rounds is $O(\log(1/\epsilon))$ where $\epsilon$ is the minimal possible value in floating point representation. |
139 | 153 |
|
140 | 154 | ## Evaluation
|
141 | 155 |
|
|
0 commit comments