---
layout: post
title: "Sorting-Free GPU-Kernels for Categorical Sampling Under Filtering in LLM Inference"
date: 2025-03-10
comments: true
author: Shanli Xing (UW), Zihao Ye (UW), Bohan Hou (CMU), Luis Ceze (UW), Tianqi Chen (CMU)
---
## Background
As vocabulary sizes grow larger in Large Language Models (LLMs), categorical sampling (token selection) has emerged as a significant performance bottleneck in LLM inference serving. The [sampling operators](https://docs.flashinfer.ai/api/sampling.html) in FlashInfer were first introduced in [v0.0.5](https://github.com/flashinfer-ai/flashinfer/releases/tag/v0.0.5), and since then, the FlashInfer team has continuously improved their robustness and performance. In this blog post, we'll explore the algorithms and implementation details behind FlashInfer's sampling operators.
## LLM Sampling
Categorical Sampling is the process that picks a specific next token from model output probabilities (over the vocabulary).
In practice, filtering is applied before sampling to drop tokens with negligible probability, control generation behavior, and enforce minimum probabilities, using thresholds such as Top-P, Top-K, or Min-P:
<p align="center">
<img src="/assets/imgs/sampling_blog/Sampling_Portion.png" alt="The compute time breakdown highlighting the sampling process. In the vLLM 1xH100 configuration, our kernels reduce the overall sampling time by more than 50% across all three models." width="800"/>
is a parameter and $p_\text{max}$ is the largest probability in the inputs. This helps eliminate extremely unlikely tokens while still respecting relative differences among the top candidates.
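As a small sketch of this last rule, Min-P filtering keeps only tokens whose probability is at least the given fraction of $p_\text{max}$ (a plain-Python illustration; the function name is ours, not FlashInfer's API):

```python
def min_p_filter(probs, min_p):
    """Zero out tokens whose probability falls below min_p * p_max."""
    p_max = max(probs)           # largest probability in the input
    threshold = min_p * p_max    # minimum probability a token must reach
    return [p if p >= threshold else 0.0 for p in probs]
```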
In practice, the combination of Top-K and Top-P filtering is popular and used as the standard setting for LLM sampling. This allows for finer-grained control over the generation process. For example, with Top-K first filtering, we first limit the token set to the Top-K highest probabilities, and then apply a Top-P cutoff to filter the tail portion within those $K$ tokens. [^1]
A PyTorch implementation of these samplers might look like this:
```python
import torch

def _apply_top_k_top_p(
    logits: torch.Tensor,  # [batch_size, vocab_size]
    k: torch.Tensor,       # [batch_size] per-request Top-K
    p: torch.Tensor,       # [batch_size] per-request Top-P
) -> torch.Tensor:
    # Sort in ascending order so the low-probability tail sits first.
    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    # Top-K: mask out everything below the K-th largest logit.
    top_k_mask = logits_sort.size(1) - k.to(torch.long)
    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
    top_k_mask = logits_sort < top_k_mask
    logits_sort.masked_fill_(top_k_mask, -float("inf"))

    # Top-P: mask out the tail whose cumulative probability stays below 1 - p.
    probs_sort = logits_sort.softmax(dim=-1)
    probs_sum = probs_sort.cumsum(dim=-1)
    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
    top_p_mask[:, -1] = False  # always keep the most likely token
    logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Scatter the filtered logits back to the original token order.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits
```
This code uses a combination of sorting, cumulative sums, and masking. While it is straightforward to follow, it creates a performance bottleneck, especially for large vocabulary sizes, because of the heavy overhead of sorting.
In FlashInfer, we show that sampling under filtering can be done in a sorting-free manner. We introduce the **Dual Pivot Rejection Sampling** algorithm and design fused sampling kernel templates to fully leverage GPUs' parallel computing capabilities, ultimately achieving worst-case logarithmic time complexity. In this blog, we'll walk you through how we developed this algorithm, integrating ideas from inverse transform sampling and rejection sampling into a final version with a theoretical guarantee of convergence.
## Algorithm
We begin by implementing a basic sampling kernel that selects tokens purely based on their probabilities.
The method is **inverse transform sampling**, which draws samples from a probability distribution given its cumulative distribution function (CDF). For the token sampling process, the CDF is the prefix sum of token probabilities. The algorithm proceeds as follows:
1. **Draw a random** $u$ from $U\sim \text{Unif}(0,1)$.
2. **Compute the prefix sums** (CDF) for each token $j$ with probability $p_j$: $F_j=\sum^{j}_{i=1}p_i$.
3. **Locate the token** $k$ such that $F_{k-1} \leq u < F_k$ as the result.
NVIDIA's [CUB](https://docs.nvidia.com/cuda/cub/index.html) library (now part of [CCCL](https://github.com/NVIDIA/cccl)) provides efficient primitives for parallel computing, and we leverage its reduce and scan primitives to compute the prefix sums. We use one threadblock per probability distribution; for batch sampling, we launch multiple threadblocks in parallel. Block-level reduce/scan primitives operate on a fixed number of elements (`BLOCK_SIZE = NUM_THREADS * NUM_ELEMENTS_PER_THREAD`, e.g. 1024 * 4 for float input), so for vocabulary sizes greater than `BLOCK_SIZE`, we split the vocabulary into multiple blocks and sequentially apply the same procedure to each block:
1. Initialize a running total $\texttt{a}=0$. Compute the probability sum $\texttt{a\_local}$ for each block. If $\texttt{a} + \texttt{a\_local}> u$, the sampled token lies in this block.
2. If not, we add $\texttt{a\_local}$ to $\texttt{a}$ and move on to the next block.
3. Once we know the correct block, we perform a prefix sum over its tokens to pinpoint the exact token index.
We use [BlockReduce](https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockReduce.html#_CPPv4I0_i_20BlockReduceAlgorithm_i_iEN3cub11BlockReduceE) and [BlockScan](https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockScan.html#_CPPv4I0_i_18BlockScanAlgorithm_i_iEN3cub9BlockScanE) for the per-block partial sum and prefix sums, and [AdjacentDifference](https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockAdjacentDifference.html?highlight=adjacentdifference#_CPPv4I0_i_i_iEN3cub23BlockAdjacentDifferenceE) to locate the token index.
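The block-by-block procedure above can be sketched sequentially in plain Python (an illustrative model of the kernel logic, not the CUDA code; `block_size` plays the role of `BLOCK_SIZE`, and the function name is ours):

```python
from bisect import bisect_right
from itertools import accumulate

def inverse_transform_sample(probs, u, block_size=4096):
    """Blockwise inverse transform sampling: find token k with F_{k-1} <= u < F_k."""
    a = 0.0  # running total over blocks already scanned
    for start in range(0, len(probs), block_size):
        block = probs[start:start + block_size]
        a_local = sum(block)                          # BlockReduce in the kernel
        if a + a_local > u:                           # sampled token lies in this block
            cdf = [a + c for c in accumulate(block)]  # BlockScan in the kernel
            return start + bisect_right(cdf, u)       # pinpoint the token index
        a += a_local                                  # move on to the next block
    return len(probs) - 1  # guard against floating-point round-off
```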
### Rejection Sampling
<p align="center">
<img src="/assets/imgs/sampling_blog/Rejection_Sampling.gif" alt="Top-P Rejection Sampling. This animation illustrates the per-block process, and in practice the workload gets executed by blocks." width="800"/>
<br>
<span style="color: gray; font-style: italic;">Top-P Rejection Sampling. This animation illustrates the per-block process, and in practice the workload gets executed by blocks.</span>
</p>
For more advanced strategies such as **Top-P sampling**, we use **rejection sampling** to restrict which tokens can be selected. Rejection sampling draws from a target distribution by comparing random samples against a threshold and discarding those that do not meet it.
Taking the sampling kernel under Top-P filtering as an example, here is a simplified look at what happens:
1. **Initialize the pivot** to $0$, so initially all tokens are considered.
2. **Perform an inverse transform sampling pass**, ignoring tokens with probabilities below the current pivot. After sampling a token, **update the pivot** to that token’s probability.
3. **Compute the remaining probability** $\texttt{q}$ among tokens that still exceed this pivot:
    1. If $\texttt{q}$ remains greater than or equal to $\texttt{top\_p}$, another round is needed to raise the pivot further and reject more tokens.
    2. Otherwise, if it is below $\texttt{top\_p}$, we finalize the sampled token and mark success.
4. **Repeat** until successful.
The whole algorithm can be implemented in a single fused kernel, and it works similarly for Top-K and other filtering strategies, except that we check the number of tokens exceeding the pivot against $\texttt{top\_k}$ or $\texttt{min\_p}$.
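The four steps above can be modeled in plain Python (a sequential illustration of the fused kernel's logic, not the CUDA code; the function name is ours, and `rng` is injectable here only to keep the sketch deterministic):

```python
import random
from bisect import bisect_right
from itertools import accumulate

def top_p_rejection_sample(probs, top_p, rng=random.random):
    """Pivot-based rejection sampling under Top-P filtering (illustrative)."""
    pivot = 0.0  # step 1: all tokens are considered initially
    while True:
        # Step 2: inverse transform sampling over tokens above the pivot.
        masked = [p if p > pivot else 0.0 for p in probs]
        total = sum(masked)
        cdf = list(accumulate(masked))
        token = min(bisect_right(cdf, rng() * total), len(probs) - 1)
        pivot = probs[token]  # raise the pivot to the sampled probability
        # Step 3: remaining mass among tokens still above the pivot.
        q = sum(p for p in probs if p > pivot)
        if q < top_p:  # step 4: accept; otherwise run another round
            return token
```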
In practice, we find that the number of rounds needed to return a sampled token is usually small. This provides a substantial speedup compared to the naive PyTorch implementation, because we avoid sorting and multiple passes over the vocabulary, as well as repeated kernel launch overheads.
### Dual Pivot Rejection Sampling
While this rejection sampling approach is simple and efficient in most cases, it has some limitations. There is no theoretical guarantee on the number of rounds needed to obtain a sampled token. This can lead to varying sampling times across different probability distributions, which in turn causes inconsistent inter-token latency during LLM inference serving. Such variability may impact the predictability and reliability of the serving system.
To address this issue, in FlashInfer [v0.2.3](https://github.com/flashinfer-ai/flashinfer/releases/tag/v0.2.3), we introduce a new algorithm called **Dual Pivot Rejection Sampling**, which uses two pivots for faster convergence in rejection sampling. The algorithm is as follows:
## Evaluation
<p align="center">
<img src="/assets/imgs/sampling_blog/Sampling_Latency_Growth_with_Batch_Size.svg" alt="Sampling Latency Growth with Batch Size." width="800"/>
<br>
<span style="color: gray; font-style: italic;">Throughput Comparison of Different Engine Kernels.</span>
</p>
[^1]: FlashInfer provides both "Top-K First" and "Joint" filtering options, with the latter applying Top-K and Top-P simultaneously at each round. More details in the [docs](https://docs.flashinfer.ai/generated/flashinfer.sampling.top_k_top_p_sampling_from_probs.html).