<p align="center">
<img src="/assets/imgs/sampling_blog/Sampling_Portion.png" alt="The compute time breakdown highlighting the sampling process. In the vLLM 1xH100 configuration, our kernels reduce the overall sampling time by more than 50% across all three models." width="800"/>
<br>
<span style="color: gray; font-style: italic;">Figure 1: The compute time breakdown highlighting the sampling process. In the vLLM 1xH100 configuration, our kernels reduce the overall sampling time by more than 50% across all three models.</span>
</p>

In FlashInfer, we show that sampling under filtering can be done in a sorting-free manner.

<p align="center">
<img src="/assets/imgs/sampling_blog/Inverse_Sampling.gif" alt="Inverse Transform Sampling. This animation illustrates the per-block process, and in practice the workload gets executed by blocks." width="800"/>
<br>
<span style="color: gray; font-style: italic;">Figure 2: Inverse Transform Sampling. This animation illustrates the per-block process, and in practice the workload gets executed by blocks.</span>
</p>

We begin by implementing a basic sampling kernel that selects tokens purely based on their probabilities, designed for the GPU parallel computing context.

The method is **inverse transform sampling**, which draws samples from a probability distribution using its cumulative distribution function (CDF):

1. **Draw a random number** $u$ uniformly from $[0, 1)$.
2. **Compute the prefix sums** (CDF) over the token probabilities $p_j$: $F_j=\sum^{j}_{i=1}p_i$.
3. **Locate the token** $k$ such that $F_{k-1} \leq u < F_k$ as the result.
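
To make this concrete, here is a minimal NumPy sketch of the same three steps (the function name and the sequential flow are our illustration; the actual kernel executes these steps in parallel across threads):

```python
import numpy as np

def inverse_transform_sample(probs: np.ndarray, u: float) -> int:
    """Draw a token index from `probs` given a uniform sample u in [0, 1)."""
    cdf = np.cumsum(probs)  # F_j = p_1 + ... + p_j
    # First index k with F_{k-1} <= u < F_k.
    return int(np.searchsorted(cdf, u, side="right"))

rng = np.random.default_rng(0)
probs = np.array([0.1, 0.5, 0.2, 0.2])
print(inverse_transform_sample(probs, rng.random()))
```
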
NVIDIA's [CUB](https://docs.nvidia.com/cuda/cub/index.html) library (now part of [CCCL](https://github.com/NVIDIA/cccl)) provides efficient primitives for parallel computing, and we leverage its reduce and scan primitives to compute the prefix sums. We use one threadblock for each probability distribution; for batch sampling, we launch multiple threadblocks in parallel. Block-level reduce/scan primitives can be applied to a block of elements (`BLOCK_SIZE = NUM_THREADS * NUM_ELEMENTS_PER_THREADS`, e.g. 1024 * 4 for float input), so for vocabulary sizes greater than `BLOCK_SIZE`, we split the vocabulary into multiple blocks and sequentially apply the same procedure to each block (see the sketch after this list):

1. Initialize a running total $\texttt{a}=0$. Compute the probability sum $\texttt{a\_local}$ for each block. If $\texttt{a} + \texttt{a\_local} > u$, the sampled token lies in this block.
2. If not, we add $\texttt{a\_local}$ to $\texttt{a}$ and move on to the next block.
3. Once we know the correct block, we perform a prefix sum over its tokens to pinpoint the exact token index.

We use [BlockReduce](https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockReduce.html#_CPPv4I0_i_20BlockReduceAlgorithm_i_iEN3cub11BlockReduceE) and [BlockScan](https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockScan.html#_CPPv4I0_i_18BlockScanAlgorithm_i_iEN3cub9BlockScanE) for the per-block partial sums and prefix sums, and [AdjacentDifference](https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockAdjacentDifference.html?highlight=adjacentdifference#_CPPv4I0_i_i_iEN3cub23BlockAdjacentDifferenceE) to locate the token index. In practice, we use early stopping to terminate the inverse transform sampling process once the cumulative probability exceeds the random number $u$, so we don't need to traverse the whole vocabulary in each round.
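
Putting these pieces together, a sequential NumPy stand-in for the block-wise traversal with early stopping might look like this (`BLOCK_SIZE` and the function name are placeholders of ours, not kernel identifiers):

```python
import numpy as np

BLOCK_SIZE = 4096  # stands in for NUM_THREADS * NUM_ELEMENTS_PER_THREADS

def blockwise_sample(probs: np.ndarray, u: float) -> int:
    a = 0.0  # running total over the blocks visited so far
    for start in range(0, len(probs), BLOCK_SIZE):
        block = probs[start:start + BLOCK_SIZE]
        a_local = block.sum()  # BlockReduce in the kernel
        if a + a_local > u:
            # The token lies in this block: prefix-sum it (BlockScan in the
            # kernel) and find where the CDF first crosses u. Returning here
            # is the early stop; later blocks are never touched.
            cdf = a + np.cumsum(block)
            k = int(np.searchsorted(cdf, u, side="right"))
            # Clamp guards against floating-point rounding at the block edge.
            return start + min(k, len(block) - 1)
        a += a_local
    return len(probs) - 1  # rounding made the total mass fall below u
```
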
### Rejection Sampling

<p align="center">
<img src="/assets/imgs/sampling_blog/Rejection_Sampling.gif" alt="Top-P Rejection Sampling. This animation illustrates the per-block process, and in practice the workload gets executed by blocks." width="800"/>
<br>
<span style="color: gray; font-style: italic;">Figure 3: Top-P Rejection Sampling. This animation illustrates the per-block process, and in practice the workload gets executed by blocks.</span>
</p>

For more advanced strategies such as **Top-P sampling**, we use **rejection sampling** to restrict which tokens can be selected. Rejection sampling draws from a target distribution by comparing random samples against a threshold and discarding those that do not meet it.
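
As a rough sketch of how rejection sampling can realize Top-P filtering (a simplified sequential formulation of ours that ignores ties between equal probabilities; the kernel's exact accept/reject logic is in the source linked below):

```python
import numpy as np

def top_p_rejection_sample(probs: np.ndarray, top_p: float,
                           rng: np.random.Generator) -> int:
    """Sample from the top-p (nucleus) set without sorting the vocabulary."""
    pivot = 0.0
    while True:
        # Restrict sampling to tokens above the current pivot.
        masked = np.where(probs > pivot, probs, 0.0)
        u = rng.random() * masked.sum()
        k = int(np.searchsorted(np.cumsum(masked), u, side="right"))
        k = min(k, len(probs) - 1)  # guard against rounding at the boundary
        # Accept iff k is in the nucleus: the mass of strictly more likely
        # tokens has not yet reached top_p, so k is needed to reach it.
        if probs[probs > probs[k]].sum() < top_p:
            return k
        pivot = probs[k]  # reject: raise the pivot and resample
```
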

Figure 4 shows the transition from round(i) to round(i+1) in Dual Pivot Rejection Sampling.

## Evaluation
Our evaluation demonstrates that FlashInfer's sampling kernel delivers substantial improvements in both kernel-level latency and end-to-end throughput compared to traditional sorting-based implementations.

<p align="center">
<img src="/assets/imgs/sampling_blog/Throughput_Comparison_of_Different_Engine_Kernel.svg" alt="Throughput Comparison of Different Engine Kernels" width="800"/>
<br>
<span style="color: gray; font-style: italic;">Figure 5: Throughput Comparison of Different Engine Kernels.</span>
</p>

<p align="center">
<img src="/assets/imgs/sampling_blog/Sampling_Latency_Growth_with_Batch_Size.svg" alt="Sampling Latency Growth with Batch Size." width="800"/>
<br>
<span style="color: gray; font-style: italic;">Figure 6: Sampling Latency Growth with Batch Size.</span>
</p>

The FlashInfer sampling kernel has been widely adopted by several prominent frameworks, including [sglang](https://github.com/sgl-project/sglang) and [vLLM](https://github.com/vllm-project/vllm/pull/7137). We are grateful for the community's valuable feedback and bug reports that have helped improve the implementation. Beyond sampling, the core ideas behind our approach have broader applications, particularly in speculative decoding verification. This includes techniques like [chain speculative sampling](https://arxiv.org/pdf/2302.01318) and [tree speculative verification](https://arxiv.org/pdf/2305.09781).
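
As one example, chain speculative sampling verifies a draft token $t$ drawn from a draft distribution $q$ by accepting it with probability $\min(1, p_t/q_t)$ under the target distribution $p$, and otherwise resampling from the renormalized residual $\max(p-q, 0)$. A brief sketch of that rule (function name ours):

```python
import numpy as np

def verify_draft_token(p: np.ndarray, q: np.ndarray, draft_token: int,
                       rng: np.random.Generator) -> int:
    """Accept or replace one draft token (chain speculative sampling)."""
    if rng.random() < min(1.0, p[draft_token] / q[draft_token]):
        return draft_token  # accepted: keep the draft model's token
    residual = np.maximum(p - q, 0.0)
    residual /= residual.sum()
    # Resampling from the residual can reuse the same sorting-free
    # inverse-transform sampling shown earlier.
    return int(np.searchsorted(np.cumsum(residual), rng.random(),
                               side="right"))
```
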
## Implementation Details
While the algorithm is elegant in theory, implementing it efficiently in a GPU kernel requires careful attention to detail, particularly in the token selection logic of inverse transform sampling. One key challenge lies in the parallel prefix-sum operation used to locate sampled tokens. Due to the non-associative nature of floating-point arithmetic, a parallel prefix sum **cannot guarantee monotonic outputs** even with non-negative inputs. This can lead to invalid token generation if not handled properly. Special care must be taken to ensure numerical stability and correctness in the sampling implementation (and we made a lot of mistakes before getting it right).
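
A two-line illustration of the hazard (plain Python; the values follow from IEEE-754 double precision):

```python
a, b, c = 0.1, 0.2, 0.3
print((a + b) + c)  # 0.6000000000000001
print(a + (b + c))  # 0.6
# A tree-structured parallel scan associates additions differently for
# different prefixes, so adjacent prefix sums can come out non-monotonic
# even though every input is non-negative.
```
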
For the complete implementation details, including how we address these challenges, please refer to the [source code](https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/sampling.cuh).
## Footnotes
[^1]: FlashInfer provides both "Top-K First" and "Joint" filtering options, with the latter applying Top-K and Top-P simultaneously at each round. More details in the [documentation](https://docs.flashinfer.ai/generated/flashinfer.sampling.top_k_top_p_sampling_from_probs.html).