Skip to content

Commit 5d139bb

Browse files
authored
Sampling Blog
Sampling Blog
2 parents 4fc8358 + 0e98a95 commit 5d139bb

7 files changed

+183
-0
lines changed

_posts/2025-03-10-sampling.md

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
---
2+
layout: post
3+
title: "Sorting-Free GPU Kernels for LLM Sampling"
4+
date: 2025-03-10
5+
comments: true
6+
author: Shanli Xing (UW), Zihao Ye (UW), Bohan Hou (CMU), Luis Ceze (UW), Tianqi Chen (CMU)
7+
---
8+
9+
## Background
10+
11+
As vocabulary sizes grow larger in Large Language Models (LLMs), categorical sampling (token selection) has emerged as a significant performance bottleneck in LLM inference serving. The [sampling operators](https://docs.flashinfer.ai/api/sampling.html) in FlashInfer were first introduced in [v0.0.5](https://github.com/flashinfer-ai/flashinfer/releases/tag/v0.0.5), and since then, the FlashInfer team has continuously improved their robustness and performance. In this blog post, we'll explore the algorithms and implementation details behind FlashInfer's sampling operators.
12+
13+
## LLM Sampling
14+
15+
Categorical Sampling is the process that picks a specific next token from model output probabilities (over the vocabulary).
16+
In practice, filtering is applied before sampling to pass tokens with negligible probability, control generation behaviors, and enforce minimum probabilities, such as Top-P, Top-K, or Min-P thresholds:
17+
18+
<p align="center">
19+
<img src="/assets/imgs/sampling_blog/Sampling_Portion.png" alt="The compute time break down highlighting the sampling process. In the vLLM 1xH100 configuration, our kernels reduce the overall sampling time by more than 50% across all three models." width="800"/>
20+
<br>
21+
<span style="color: gray; font-style: italic;">Figure 1: The compute time break down highlighting the sampling process. In the vLLM 1xH100 configuration, our kernels reduce the overall sampling time by more than 50% across all three models.</span>
22+
</p>
23+
24+
25+
1. **Top-K**
26+
27+
Top-K sampling keeps only the $K$ tokens with the highest probabilities at each generation step. For example, if $K=50$, the model will ignore all tokens outside the top 50 likely candidates.
28+
29+
2. [**Top-P (Nucleus Sampling)**](https://arxiv.org/pdf/1904.09751)
30+
31+
Top-P rather keeps the smallest set of tokens whose cumulative probability just exceeds a threshold $P$. For example, if $P=0.9$, you accumulate token probabilities in descending order until their sum is at least 0.9.
32+
33+
3. [**Min-P**](https://arxiv.org/pdf/2407.01082)
34+
35+
Min-p filters out all tokens below a minimum threashold $p_\text{base} \times p_\text{max}$, where $p_\text{base}$
36+
is parameter and $p_\text{max}$ is the largest probability in the inputs. This helps eliminate extremely unlikely tokens while still respecting relative differences among the top candidates.
37+
38+
39+
In practice, the combination of Top-K and Top-P filtering is popular and used as the standard setting for LLM sampling. This allows for finer-grained control over the generation process. For example if we use the Top-K first filtering, we first limit the token set to the Top-K highest probabilities, and then apply a Top-P cutoff to filter the tail portion within those $K$ tokens. [^1]
40+
41+
A PyTorch implementation of these samplers might look like this:
42+
43+
```python
44+
# vllm/vllm/model_executor/layers/sampler.py
45+
def _apply_top_k_top_p(
46+
logits: torch.Tensor,
47+
p: torch.Tensor,
48+
k: torch.Tensor,
49+
) -> torch.Tensor:
50+
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
51+
52+
# Apply top-k.
53+
top_k_mask = logits_sort.size(1) - k.to(torch.long)
54+
# Get all the top_k values.
55+
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
56+
top_k_mask = logits_sort < top_k_mask
57+
logits_sort.masked_fill_(top_k_mask, -float("inf"))
58+
59+
# Apply top-p.
60+
probs_sort = logits_sort.softmax(dim=-1)
61+
probs_sum = probs_sort.cumsum(dim=-1)
62+
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
63+
# at least one
64+
top_p_mask[:, -1] = False
65+
logits_sort.masked_fill_(top_p_mask, -float("inf"))
66+
67+
# Re-sort the probabilities.
68+
src = torch.arange(logits_idx.shape[-1],
69+
device=logits_idx.device).expand_as(logits_idx)
70+
logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
71+
index=logits_idx,
72+
src=src)
73+
logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
74+
return logits
75+
```
76+
77+
This code uses a combination of sorting, cumulative sums, and masking.
78+
While it is straightforward to follow, it induces performance bottleneck especially for large vocab size, because of the huge overhead of sorting.
79+
80+
In FlashInfer, we show that sampling under filtering can be done in sorting-free manner, and we introduce the **Dual Pivot Rejection Sampling** algorithm and design fused sampling kernel templates to fully leverage GPUs' parallel computing capabilities, ultimately achieving logarithmic (in worst case) time complexity. In this blog, we'll walk you through how we developed this algorithm integrating ideas from Inverse Sampling, Rejection Sampling, and final version of the algorithm with theorerical guarantee of convergence.
81+
82+
## Algorithm
83+
84+
### Inverse Transform Sampling
85+
86+
<p align="center">
87+
<img src="/assets/imgs/sampling_blog/Inverse_Sampling.gif" alt="Inverse Transform Sampling. This animation illustrates the per-block process, and in practice the workload gets executed by blocks." width="800"/>
88+
<br>
89+
<span style="color: gray; font-style: italic;">Figure 2: Inverse Transform Sampling. This animation illustrates the per-block process, and in practice the workload gets executed by blocks.</span>
90+
</p>
91+
92+
We begin with implementing a basic sampling kernel that selects tokens purely based on their probabilities, particularly in the GPU parallel computing context.
93+
94+
The method is **inverse transform sampling**, which draws samples from a probability distribution given its cumulative distribution function (CDF). As for the token samling process, the CDF would be the prefix sum of token probabilities. The algorithm proceeds like this:
95+
96+
1. **Draw a random** $u$ from $U\sim \text{Unif}(0,1)$.
97+
2. **Compute the prefix sums** (CDF) for each sampled token $j$ with probability $p_j$: $F_j=\sum^{j}_{i=1}p_i$.
98+
3. **Locate the token** $k$ such that $F_{k-1} \leq u < F_k$ as the result.
99+
100+
NVIDIA's [CUB](https://docs.nvidia.com/cuda/cub/index.html) library (now part of [CCCL](https://github.com/NVIDIA/cccl)) provides efficient primitives for parallel computing, and we leverage the reduce and scan primitives to compute the prefix sums. We use one threadblock for each probability distribution, for batch sampling, we launch multiple threadblocks in parallel. Block-level reduce/scan primitives can be applied to a block of elements (`BLOCK_SIZE = NUM_THREADS * NUM_ELEMENTS_PER_THREADS`, e.g. 1024 * 4 for float input), for vocabulary size greater than `BLOCK_SIZE`, we split the vocabulary into multiple blocks and sequentially apply the same procedure on each block:
101+
102+
1. Initialize a running total $\texttt{a}=0$. Compute the probability sum $\texttt{a\_local}$ for each block. If $\texttt{a} + \texttt{a\_local}> u$, the sampled token lies in this block.
103+
2. If not, we add $\texttt{a\_local}$ to $\texttt{a}$ and move on to the next block.
104+
3. Once we know the correct block, we perform a prefix sum over its tokens to pinpoint the exact token index.
105+
106+
We use [BlockReduce](https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockReduce.html#_CPPv4I0_i_20BlockReduceAlgorithm_i_iEN3cub11BlockReduceE) and [BlockScan](https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockScan.html#_CPPv4I0_i_18BlockScanAlgorithm_i_iEN3cub9BlockScanE) for the per-block partial sum and prefix sums, and [AdjacentDifference](https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockAdjacentDifference.html?highlight=adjacentdifference#_CPPv4I0_i_i_iEN3cub23BlockAdjacentDifferenceE) to locate the token index.
107+
In practice, we use early-stopping to terminate the inverse transform sampling process when the cumulative probability exceeds the random number $u$, so we don't need to go through the whole vocabulary for each round.
108+
109+
### Rejection Sampling
110+
111+
<p align="center">
112+
<img src="/assets/imgs/sampling_blog/Rejection_Sampling.gif" alt="Top-P Rejection Sampling. This animation illustrates the per-block process, and in practice the workload gets executed by blocks." width="800"/>
113+
<br>
114+
<span style="color: gray; font-style: italic;">Figure 3: Top-P Rejection Sampling. This animation illustrates the per-block process, and in practice the workload gets executed by blocks. </span>
115+
</p>
116+
117+
For more advanced strategies such as **Top-P sampling**, we use **rejection sampling** to restrict which tokens can be selected. Rejection sampling draws from a target distribution by comparing random samples against a threshold and discarding those that do not meet it.
118+
119+
Taking the sampling kernel under Top-P filtering as an example, here is a simplified look of what happens:
120+
121+
1. **Initialize the pivot** to $0$, so initially all tokens are considered.
122+
2. **Perform an inverse transform sampling pass** but ignoring tokens with probabilities below the current pivot. After sampling a token, **update the pivot** to that token’s probability.
123+
3. **Compute the remaining probability** $\texttt{q}$ among tokens that still exceed this pivot:
124+
1. If $\texttt{q}$ remains greater than or equal to $\texttt{top\_p}$, another round is needed to raise the pivot further and reject more tokens.
125+
2. Otherwise, if it is below $\texttt{top\_p}$, we finalize the sampled token and mark success.
126+
4. **Repeat** until successful.
127+
128+
The whole algorithm can be implemented in a single fused kernel, and it works similar for Top-K and other filtering strategies, other than we’ll be checking the number of tokens exceeding the pivot against $\texttt{top\_k}$ or $\texttt{min\_p}$.
129+
130+
In practice, we find that the number of rounds for returning a sampled token is usually small. It provides a substantial speedup compared to the naive PyTorch implementation because we avoid the sorting and multiple passes over the vocabulary, as well as multiple kernel launch overheads.
131+
132+
### Dual Pivot Rejection Sampling
133+
134+
While this rejection sampling approach is simple and efficient in most cases, it has some limitations. There is no theoretical guarantee on the number of rounds needed to obtain a sampled token. This can lead to varying sampling times across different probability distributions, which in turn causes inconsistent inter-token latency during LLM inference serving. Such variability may impact the predictability and reliability of the serving system.
135+
136+
To address this issue, in FlashInfer [v0.2.3](https://github.com/flashinfer-ai/flashinfer/releases/tag/v0.2.3), we introduce the a new algorithm called **Dual Pivot Rejection Sampling**, which uses two pivots for faster convergence in rejection sampling. The algorithm is as follows:
137+
138+
1. Let $f$ be a function that checks if a probability value is valid: $f(x)=1$ if valid, $0$ if not.
139+
2. Initialize $\textrm{low} \leftarrow 0$ and $\textrm{high} \leftarrow \max_i(p_i)$ as the initial range, it's guaranteed that $f(\textrm{low})=0$ and $f(\textrm{high})=1$.
140+
3. Sample over probability values in the range $(\textrm{low}, \infty)$ using inverse transform sampling.
141+
4. Suppose $j$ is the sampled token, let $\textrm{pivot}_1\leftarrow p_j$, and $\textrm{pivot}_2\leftarrow \frac{\textrm{pivot}_1+\textrm{high}}{2}$.
142+
1. If $f(\textrm{pivot}_1)=1$, we accept the sampled token and return $j$.
143+
2. If $f(\textrm{pivot}_1)=0$, $f(\textrm{pivot}_2)=1$, we set $\textrm{pivot}_1$ as new $\textrm{low}$ and $\textrm{pivot}_2$ as new $\textrm{high}$.
144+
3. If $f(\textrm{pivot}_1)=0$, $f(\textrm{pivot}_2)=0$, we set $\textrm{pivot}_2$ as new $\textrm{low}$.
145+
5. Repeat step 3 and 4 until success.
146+
147+
<p align="center">
148+
<img src="/assets/imgs/sampling_blog/dual-pivot-sampling.png" alt="Dual Pivot Rejection Sampling" width="800"/>
149+
<br>
150+
<span style="color: gray; font-style: italic;">Figure 4: Transition from round(i) to round(i+1) in Dual Pivot Rejection Sampling, we either accept the sampled token (case 1) or shrinking the range by at least half (case 2 and 3).</span>
151+
</p>
152+
153+
Figure 4 shows the transition from round(i) to round(i+1) in Dual Pivot Rejection Sampling, in each round, if the sampled token is accepted, we return the token, otherwise, the new range's extent is $\frac{\text{high}-\text{pivot}_1}{2} < \frac{\text{high}-\text{low}}{2}$, which is at least half of the previous range. Thus it's guaranteed that the number of rounds is $O(\log(1/\epsilon))$ where $\epsilon$ is the minimal possible value in floating point representation.
154+
155+
## Evaluation
156+
157+
Our evaluation demonstrates that FlashInfer's sampling kernel delivers substantial improvements in both kernel-level latency and end-to-end throughput compared to traditional sorting-based implementations.
158+
159+
<p align="center">
160+
<img src="/assets/imgs/sampling_blog/Throughput_Comparison_of_Different_Engine_Kernel.svg" alt="Throughput Comparison of Different Engine Kernel" width="800"/>
161+
<br>
162+
<span style="color: gray; font-style: italic;">Figure 5: Throughput Comparison of Different Engine Kernel.</span>
163+
</p>
164+
165+
<p align="center">
166+
<img src="/assets/imgs/sampling_blog/Sampling_Latency_Growth_with_Batch_Size.svg" alt="Sampling Latency Growth with Batch Size." width="800"/>
167+
<br>
168+
<span style="color: gray; font-style: italic;">Figure 6: Sampling Latency Growth with Batch Size.</span>
169+
</p>
170+
171+
## Community Adoption and Other Applications
172+
The FlashInfer sampling kernel has been widely adopted by several prominent frameworks, including [sglang](https://github.com/sgl-project/sglang) and [vLLM](https://github.com/vllm-project/vllm/pull/7137). We are grateful for the community's valuable feedback and bug reports that have helped improve the implementation. Beyond sampling, the core ideas behind our approach have broader applications, particularly in speculative decoding verification. This includes techniques like [chain speculative sampling](https://arxiv.org/pdf/2302.01318) and [tree speculative verification](https://arxiv.org/pdf/2305.09781).
173+
174+
## Implementation Details
175+
176+
While the algorithm is elegant in theory, implementing it efficiently in a GPU kernel requires careful attention to detail, particularly in the token selection logic in inverse transform sampling. One key challenge lies in the parallel prefix-sum operation used to locate sampled tokens. Due to the non-associative and non-commutative nature of floating-point arithmetic, parallel prefix-sum **cannot guarantee monotonic outputs** even with non-negative inputs. This can lead to invalid token generation if not handled properly. Special care must be taken to ensure numerical stability and correctness in the sampling implementation (and we made a lot of mistakes before got it right)
177+
178+
For the complete implementation details, including how we address these challenges, please refer to the [source code](https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/sampling.cuh).
179+
180+
## Footnotes
181+
[^1]: FlashInfer provides both "Top-K First" and "Joint" filtering options, with the latter applying Top-K and Top-P simultaneously at each round. More on the [doc](https://docs.flashinfer.ai/generated/flashinfer.sampling.top_k_top_p_sampling_from_probs.html).
332 KB
Loading
976 KB
Loading

assets/imgs/sampling_blog/Sampling_Latency_Growth_with_Batch_Size.svg

Lines changed: 1 addition & 0 deletions
Loading
142 KB
Loading

assets/imgs/sampling_blog/Throughput_Comparison_of_Different_Engine_Kernel.svg

Lines changed: 1 addition & 0 deletions
Loading
19.9 KB
Loading

0 commit comments

Comments
 (0)