Commit 2728022

_posts/2025-03-10-sampling.md

---
layout: post
title: "Sorting-Free GPU Kernels for Categorical Sampling Under Filtering in LLM Inference"
date: 2025-03-10
comments: true
author: Shanli Xing (UW), Zihao Ye (UW), Bohan Hou (CMU), Luis Ceze (UW), Tianqi Chen (CMU)
---

## Background

As vocabulary sizes grow larger in Large Language Models (LLMs), categorical sampling (token selection) has emerged as a significant performance bottleneck in LLM inference serving. The [sampling operators](https://docs.flashinfer.ai/api/sampling.html) in FlashInfer were first introduced in [v0.0.5](https://github.com/flashinfer-ai/flashinfer/releases/tag/v0.0.5), and since then, the FlashInfer team has continuously improved their robustness and performance. In this blog post, we'll explore the algorithms and implementation details behind FlashInfer's sampling operators.

## LLM Sampling

Categorical sampling is the process of picking a specific next token from the model's output probabilities (a distribution over the vocabulary).
In practice, filtering such as Top-P, Top-K, or Min-P thresholding is applied before sampling to drop tokens with negligible probability, control generation behavior, and enforce minimum probabilities:

<p align="center">
<img src="/assets/imgs/sampling_blog/Sampling_Portion.png" alt="The compute time break down highlighting the sampling process. In the vLLM 1xH100 configuration, our kernels reduce the overall sampling time by more than 50% across all three models." width="800"/>
</p>

is a parameter and $p_\text{max}$ is the largest probability in the inputs. This helps eliminate extremely unlikely tokens while still respecting the relative differences among the top candidates.
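
To make the rule concrete, here is a minimal sketch of Min-P filtering, assuming the usual rule of keeping token $i$ iff $p_i \geq \text{min\_p} \cdot p_\text{max}$ (`min_p_mask` is a hypothetical helper for illustration, not a FlashInfer API):

```python
def min_p_mask(probs: list[float], min_p: float) -> list[bool]:
    """Keep token i iff p_i >= min_p * p_max (illustrative helper)."""
    threshold = min_p * max(probs)  # scale the cutoff by the top probability
    return [p >= threshold for p in probs]

# With min_p = 0.1, tokens below 10% of the top probability are dropped.
mask = min_p_mask([0.5, 0.2, 0.04], 0.1)
```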

In practice, the combination of Top-K and Top-P filtering is popular and used as the standard setting for LLM sampling, allowing finer-grained control over the generation process. For example, with Top-K-first filtering, we first limit the token set to the $K$ highest-probability tokens, and then apply a Top-P cutoff to filter the tail portion within those $K$ tokens. [^1]

A PyTorch implementation of these samplers might look like this:

```python
def _apply_top_k_top_p(
    logits: torch.Tensor,
    k: torch.Tensor,
    p: torch.Tensor,
) -> torch.Tensor:
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
    # Apply Top-K: mask out everything below the K-th largest logit.
    top_k_mask = logits_sort.size(1) - k.to(torch.long)
    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
    top_k_mask = logits_sort < top_k_mask
    logits_sort.masked_fill_(top_k_mask, -float("inf"))
    # Apply Top-P: mask out the low-probability tail of the sorted CDF.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
    top_p_mask[:, -1] = False  # always keep at least one token
    logits_sort.masked_fill_(top_p_mask, -float("inf"))
    # Scatter the masked logits back to the original token order.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits
```
8076

81-
This code uses a combination of sorting, cumulative sums, and masking. While it is straightforward to follow, it induces performance bottleneck especially for large vocab size.
77+
This code uses a combination of sorting, cumulative sums, and masking. While it is straightforward to follow, it becomes a performance bottleneck at large vocabulary sizes because of the heavy cost of sorting.

In FlashInfer, we show that sampling under filtering can be done in a sorting-free manner. We introduce the **Dual Pivot Rejection Sampling** algorithm and design fused sampling kernel templates to fully leverage GPUs' parallel computing capabilities, achieving logarithmic worst-case time complexity. In this blog, we'll walk you through how we developed this algorithm, integrating ideas from inverse transform sampling, rejection sampling, and binary search, and arrive at a final version with a theoretical guarantee of convergence.

## Algorithm

We begin with implementing a basic sampling kernel that selects tokens purely based on their probabilities.

The method is **inverse transform sampling**, which draws samples from a probability distribution given its cumulative distribution function (CDF). For the token sampling process, the CDF is the prefix sum of the token probabilities. The algorithm proceeds as follows:

1. **Draw a random** $u$ from $U\sim \text{Unif}(0,1)$.
2. **Compute the prefix sums** (CDF) over the token probabilities $p_i$: $F_j=\sum^{j}_{i=1}p_i$.
3. **Locate the token** $k$ such that $F_{k-1} \leq u < F_k$ as the result.
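
The three steps above can be sketched on the CPU as follows (illustrative Python, not the CUDA kernel; `inverse_transform_sample` is a hypothetical name):

```python
import random

def inverse_transform_sample(probs: list[float], u: float) -> int:
    """Return the smallest k with u < F_k, where F is the prefix-sum CDF."""
    cdf = 0.0
    for k, p in enumerate(probs):
        cdf += p
        if u < cdf:
            return k
    return len(probs) - 1  # guard against floating-point round-off when u ~ 1.0

# Draw u ~ Unif(0, 1) and map it through the CDF to a token index.
token = inverse_transform_sample([0.1, 0.6, 0.3], random.random())
```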

NVIDIA's [CUB](https://docs.nvidia.com/cuda/cub/index.html) library (now part of [CCCL](https://github.com/NVIDIA/cccl)) provides efficient primitives for parallel computing, and we leverage its reduce and scan primitives to compute the prefix sums. We use one threadblock per probability distribution; for batch sampling, we launch multiple threadblocks in parallel. Block-level reduce/scan primitives operate on a fixed number of elements (`BLOCK_SIZE = NUM_THREADS * NUM_ELEMENTS_PER_THREADS`, e.g. 1024 * 4 for float input), so for vocabulary sizes greater than `BLOCK_SIZE`, we split the vocabulary into multiple blocks and sequentially apply the same procedure to each block:

1. Initialize a running total $\texttt{a}=0$. For each block, compute its probability sum $\texttt{a\_local}$. If $\texttt{a} + \texttt{a\_local} > u$, the sampled token lies in this block.
2. If not, we add $\texttt{a\_local}$ to $\texttt{a}$ and move on to the next block.
3. Once we know the correct block, we perform a prefix sum over its tokens to pinpoint the exact token index.
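
A sequential CPU analogue of this block-by-block procedure might look like the following (illustrative only; on the GPU, the per-block sum and prefix sum are computed in parallel with the CUB primitives rather than Python loops, and `blockwise_sample` is a hypothetical name):

```python
def blockwise_sample(probs: list[float], u: float, block_size: int = 4) -> int:
    """Scan fixed-size blocks; descend into the first block whose running
    total exceeds u, then prefix-scan inside it to find the token."""
    a = 0.0                                   # running total over prior blocks
    for start in range(0, len(probs), block_size):
        block = probs[start:start + block_size]
        a_local = sum(block)                  # BlockReduce on the GPU
        if a + a_local > u:                   # the sampled token lies here
            cdf = a
            for j, p in enumerate(block):     # BlockScan on the GPU
                cdf += p
                if u < cdf:
                    return start + j
        a += a_local                          # move on to the next block
    return len(probs) - 1                     # floating-point round-off guard
```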

We use [BlockReduce](https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockReduce.html#_CPPv4I0_i_20BlockReduceAlgorithm_i_iEN3cub11BlockReduceE) and [BlockScan](https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockScan.html#_CPPv4I0_i_18BlockScanAlgorithm_i_iEN3cub9BlockScanE) for the per-block partial sum and prefix sums, and [AdjacentDifference](https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockAdjacentDifference.html?highlight=adjacentdifference#_CPPv4I0_i_i_iEN3cub23BlockAdjacentDifferenceE) to locate the token index.

### Rejection Sampling

<p align="center">
<img src="/assets/imgs/sampling_blog/Rejection_Sampling.gif" alt="Top-P Rejection Sampling. This animation illustrates the per-block process, and in practice the workload gets executed by blocks." width="800"/>
<br>
<span style="color: gray; font-style: italic;">Top-P Rejection Sampling. This animation illustrates the per-block process, and in practice the workload gets executed by blocks. </span>
</p>

For more advanced strategies such as **Top-P sampling**, we use **rejection sampling** to restrict which tokens can be selected. Rejection sampling draws from a target distribution by comparing random samples against a threshold and discarding those that do not meet it.

Taking the sampling kernel under Top-P filtering as an example, here is a simplified look at what happens:

1. **Initialize the pivot** to $0$, so initially all tokens are considered.
2. **Perform an inverse transform sampling pass** but ignoring tokens with probabilities below the current pivot. After sampling a token, **update the pivot** to that token’s probability.
3. **Compute the remaining probability** $\texttt{q}$ among tokens that still exceed this pivot:
1. If $\texttt{q}$ remains greater than or equal to $\texttt{top\_p}$, another round is needed to raise the pivot further and reject more tokens.
2. Otherwise, if it is below $\texttt{top\_p}$, we finalize the sampled token and mark success.
4. **Repeat** until successful.
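
The loop above can be sketched on the CPU as follows (illustrative; `top_p_rejection_sample` is a hypothetical name, and the fused kernel performs each round's pass with the block-wise GPU primitives rather than Python loops):

```python
import random

def top_p_rejection_sample(probs, top_p, rng=random.random):
    """Single-pivot rejection sampling for Top-P filtering (CPU sketch)."""
    pivot = 0.0                               # step 1: admit every token
    while True:
        # Step 2: inverse-transform sample among tokens above the pivot,
        # drawing u ~ Unif(0, remaining mass).
        admitted = [(i, p) for i, p in enumerate(probs) if p > pivot]
        mass = sum(p for _, p in admitted)
        u = rng() * mass
        token, cdf = admitted[-1][0], 0.0
        for i, p in admitted:
            cdf += p
            if u < cdf:
                token = i
                break
        pivot = probs[token]                  # raise pivot to sampled prob
        # Step 3: remaining mass among tokens still above the pivot.
        q = sum(p for _, p in admitted if p > pivot)
        if q < top_p:                         # step 3.2: finalize
            return token                      # step 4: otherwise repeat
```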

The whole algorithm can be implemented in a single fused kernel, and it works similarly for Top-K and Min-P filtering, except that the acceptance check compares the number of tokens exceeding the pivot against $\texttt{top\_k}$, or the pivot against the $\texttt{min\_p}$ threshold.

In practice, we find that the number of rounds needed to return a sampled token is usually small. This provides a substantial speedup over the naive PyTorch implementation because we avoid sorting, multiple passes over the vocabulary, and repeated kernel launch overheads.

### Dual Pivot Rejection Sampling

While this rejection sampling approach is simple and efficient in most cases, it has some limitations. There is no theoretical guarantee on the number of rounds needed to obtain a sampled token. This can lead to varying sampling times across different probability distributions, which in turn causes inconsistent inter-token latency during LLM inference serving. Such variability may impact the predictability and reliability of the serving system.
To address this issue, in FlashInfer [v0.2.3](https://github.com/flashinfer-ai/flashinfer/releases/tag/v0.2.3), we introduce a new algorithm called **Dual Pivot Rejection Sampling**, which uses two pivots per round for faster convergence in rejection sampling.
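
As one plausible instantiation of the dual-pivot idea (an illustrative assumption on our part, not necessarily FlashInfer's exact scheme), each rejected round can combine the rejection pivot with a binary-search midpoint so that the bracketing interval at least halves, which is what yields a logarithmic worst-case bound on the number of rounds:

```python
import random

def dual_pivot_top_p_sample(probs, top_p, rng=random.random):
    """Illustrative dual-pivot sketch: on rejection, raise the lower bound
    to the larger of the sampled probability and the midpoint of
    [low, high], so the interval at least halves every round."""
    low, high = 0.0, max(probs)
    while True:
        # Inverse-transform sample among tokens above the lower pivot.
        admitted = [(i, p) for i, p in enumerate(probs) if p > low]
        mass = sum(p for _, p in admitted)
        u = rng() * mass
        token, cdf = admitted[-1][0], 0.0
        for i, p in admitted:
            cdf += p
            if u < cdf:
                token = i
                break
        # Same acceptance check as single-pivot rejection sampling.
        if sum(p for p in probs if p > probs[token]) < top_p:
            return token
        # Rejected: use a second, binary-search pivot at the midpoint.
        mid = (low + high) / 2.0
        if sum(p for p in probs if p > mid) >= top_p:
            low = max(probs[token], mid)      # cutoff lies above mid
        else:
            low, high = probs[token], mid     # cutoff lies at or below mid
```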
## Evaluation

<p align="center">
<img src="/assets/imgs/sampling_blog/Sampling_Latency_Growth_with_Batch_Size.svg" alt="Sampling Latency Growth with Batch Size." width="800"/>
<br>
<span style="color: gray; font-style: italic;">Sampling Latency Growth with Batch Size.</span>
</p>
[^1]: FlashInfer provides both "Top-K First" and "Joint" filtering options, with the latter applying Top-K and Top-P simultaneously at each round. More on the [doc](https://docs.flashinfer.ai/generated/flashinfer.sampling.top_k_top_p_sampling_from_probs.html).
