Skip to content

Commit a5597e9

Browse files
committed
upd
1 parent 2728022 commit a5597e9

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

_posts/2025-03-10-sampling.md

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
---
22
layout: post
3-
title: "Sorting-Free GPU-Kernels for Categorical Sampling Under Filtering in LLM Inference"
3+
title: "Sorting-Free GPU Kernels for LLM Sampling"
44
date: 2025-03-10
55
comments: true
66
author: Shanli Xing (UW), Zihao Ye (UW), Bohan Hou (CMU), Luis Ceze (UW), Tianqi Chen (CMU)
@@ -134,8 +134,22 @@ While this rejection sampling approach is simple and efficient in most cases, it
134134

135135
To address this issue, in FlashInfer [v0.2.3](https://github.com/flashinfer-ai/flashinfer/releases/tag/v0.2.3), we introduce the a new algorithm called **Dual Pivot Rejection Sampling**, which uses two pivots for faster convergence in rejection sampling. The algorithm is as follows:
136136

137+
1. Let $f$ be a function that checks if a probability value is valid: $f(x)=1$ if valid, $0$ if not.
138+
2. Initialize $\textrm{low} \leftarrow 0$ and $\textrm{high} \leftarrow \max_i(p_i)$ as the initial range, it's guaranteed that $f(\textrm{low})=0$ and $f(\textrm{high})=1$.
139+
3. Sample over probability values in the range $(\textrm{low}, \infty)$ using inverse transform sampling.
140+
4. Suppose $j$ is the sampled token, let $\textrm{pivot}_1\leftarrow p_j$, and $\textrm{pivot}_2\leftarrow \frac{\textrm{pivot}_1+\textrm{high}}{2}$.
141+
1. If $f(\textrm{pivot}_1)=1$, we accept the sampled token and return $j$.
142+
2. If $f(\textrm{pivot}_1)=0$, $f(\textrm{pivot}_2)=1$, we set $\textrm{pivot}_1$ as new $\textrm{low}$ and $\textrm{pivot}_2$ as new $\textrm{high}$.
143+
3. If $f(\textrm{pivot}_1)=0$, $f(\textrm{pivot}_2)=0$, we set $\textrm{pivot}_2$ as new $\textrm{low}$.
144+
5. Repeat step 3 and 4 until success.
137145

146+
<p align="center">
147+
<img src="/assets/imgs/sampling_blog/dual-pivot-sampling.png" alt="Dual Pivot Rejection Sampling" width="800"/>
148+
<br>
149+
<span style="color: gray; font-style: italic;">Figure 4: Transition from round(i) to round(i+1) in Dual Pivot Rejection Sampling, we either accept the sampled token (case 1) or shrinking the range by at least half (case 2 and 3).</span>
150+
</p>
138151

152+
Figure 4 shows the transition from round(i) to round(i+1) in Dual Pivot Rejection Sampling, in each round, if the sampled token is accepted, we return the token, otherwise, the new range's extent is $\frac{\text{high}-\text{pivot}_1}{2} < \frac{\text{high}-\text{low}}{2}$, which is at least half of the previous range. Thus it's guaranteed that the number of rounds is $O(\log(1/\epsilon))$ where $\epsilon$ is the minimal possible value in floating point representation.
139153

140154
## Evaluation
141155

19.9 KB
Loading

0 commit comments

Comments
 (0)