
Commit 0e98a95

upd
1 parent 0070ef5 commit 0e98a95


_posts/2025-03-10-sampling.md

Lines changed: 19 additions & 6 deletions
@@ -18,7 +18,7 @@ In practice, filtering is applied before sampling to pass tokens with negligible
<p align="center">
<img src="/assets/imgs/sampling_blog/Sampling_Portion.png" alt="The compute time breakdown highlighting the sampling process. In the vLLM 1xH100 configuration, our kernels reduce the overall sampling time by more than 50% across all three models." width="800"/>
<br>
-<span style="color: gray; font-style: italic;">The compute time breakdown highlighting the sampling process. In the vLLM 1xH100 configuration, our kernels reduce the overall sampling time by more than 50% across all three models.</span>
+<span style="color: gray; font-style: italic;">Figure 1: The compute time breakdown highlighting the sampling process. In the vLLM 1xH100 configuration, our kernels reduce the overall sampling time by more than 50% across all three models.</span>
</p>

@@ -86,7 +86,7 @@ In FlashInfer, we show that sampling under filtering can be done in sorting-free
<p align="center">
<img src="/assets/imgs/sampling_blog/Inverse_Sampling.gif" alt="Inverse Transform Sampling. This animation illustrates the per-block process, and in practice the workload gets executed by blocks." width="800"/>
<br>
-<span style="color: gray; font-style: italic;">Inverse Transform Sampling. This animation illustrates the per-block process, and in practice the workload gets executed by blocks.</span>
+<span style="color: gray; font-style: italic;">Figure 2: Inverse Transform Sampling. This animation illustrates the per-block process, and in practice the workload gets executed by blocks.</span>
</p>

We begin by implementing a basic sampling kernel that selects tokens purely based on their probabilities, with the GPU's parallel execution model in mind.
@@ -97,20 +97,21 @@ The method is **inverse transform sampling**, which draws samples from a probabi
2. **Compute the prefix sums** (CDF): for each token $j$ with probability $p_j$, $F_j=\sum^{j}_{i=1}p_i$.
3. **Locate the token** $k$ such that $F_{k-1} \leq u < F_k$ and return it as the result.
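
To make the procedure concrete, here is a minimal sequential sketch in Python/NumPy (a reference illustration only, not the FlashInfer CUDA kernel; the function name and interface are our own):

```python
import numpy as np

def inverse_transform_sample(probs: np.ndarray, u: float) -> int:
    """Reference (sequential) inverse transform sampling.

    probs: 1-D array of token probabilities summing to 1.
    u: uniform random number in [0, 1).
    Returns the index k such that F_{k-1} <= u < F_k.
    """
    cdf = np.cumsum(probs)                         # prefix sums F_j
    k = int(np.searchsorted(cdf, u, side="right"))
    return min(k, len(probs) - 1)                  # guard against round-off at the tail

rng = np.random.default_rng(0)
probs = np.array([0.1, 0.5, 0.2, 0.2])
token = inverse_transform_sample(probs, rng.random())
```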

-NVIDIA's [CUB](https://docs.nvidia.com/cuda/cub/index.html) library (now part of [CCCL](https://github.com/NVIDIA/cccl)) provides efficient primitives for parallel computing, and we leverage the reduce and scan primitives to compute the prefix sums. We use one threadblock for each probability distribution, for batch sampling, we launch multiple threadblocks in parallel. Block-level reduce/scan primitives can be applied to fixed number of elements (`BLOCK_SIZE = NUM_THREADS * NUM_ELEMENTS_PER_THREADS`, e.g. 1024 * 4 for float input), for vocabulary size greater than `BLOCK_SIZE`, we split the vocabulary into multiple blocks and sequentially apply the same procedure on each block:
+NVIDIA's [CUB](https://docs.nvidia.com/cuda/cub/index.html) library (now part of [CCCL](https://github.com/NVIDIA/cccl)) provides efficient primitives for parallel computing, and we leverage its reduce and scan primitives to compute the prefix sums. We use one threadblock for each probability distribution; for batch sampling, we launch multiple threadblocks in parallel. Block-level reduce/scan primitives operate on a block of elements (`BLOCK_SIZE = NUM_THREADS * NUM_ELEMENTS_PER_THREADS`, e.g. 1024 * 4 for float input); for a vocabulary size greater than `BLOCK_SIZE`, we split the vocabulary into multiple blocks and sequentially apply the same procedure on each block:

1. Initialize a running total $\texttt{a}=0$. Compute the probability sum $\texttt{a\_local}$ for each block. If $\texttt{a} + \texttt{a\_local} > u$, the sampled token lies in this block.
2. If not, we add $\texttt{a\_local}$ to $\texttt{a}$ and move on to the next block.
3. Once we know the correct block, we perform a prefix sum over its tokens to pinpoint the exact token index.

We use [BlockReduce](https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockReduce.html#_CPPv4I0_i_20BlockReduceAlgorithm_i_iEN3cub11BlockReduceE) and [BlockScan](https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockScan.html#_CPPv4I0_i_18BlockScanAlgorithm_i_iEN3cub9BlockScanE) for the per-block partial sums and prefix sums, and [AdjacentDifference](https://nvidia.github.io/cccl/cub/api/classcub_1_1BlockAdjacentDifference.html?highlight=adjacentdifference#_CPPv4I0_i_i_iEN3cub23BlockAdjacentDifferenceE) to locate the token index.
+In practice, we use early stopping to terminate the inverse transform sampling process once the cumulative probability exceeds the random number $u$, so we don't need to traverse the whole vocabulary in each round.
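
The block-by-block traversal with early stopping looks roughly like the following Python sketch (a host-side stand-in for the CUB-based kernel, where the block sum plays the role of BlockReduce and the in-block prefix sum plays the role of BlockScan; names here are our own):

```python
import numpy as np

def blockwise_inverse_sample(probs: np.ndarray, u: float, block_size: int = 4096) -> int:
    """Sketch of block-wise inverse transform sampling with early stopping."""
    a = 0.0                                        # running total of probability mass
    for start in range(0, len(probs), block_size):
        block = probs[start:start + block_size]
        a_local = block.sum()                      # block-level reduce (BlockReduce)
        if a + a_local > u:                        # the sampled token lies in this block
            cdf = a + np.cumsum(block)             # block-level scan (BlockScan), offset by a
            k = int(np.searchsorted(cdf, u, side="right"))
            return start + min(k, len(block) - 1)  # early stop: later blocks are never touched
        a += a_local                               # otherwise accumulate and move on
    return len(probs) - 1                          # guard against round-off at the tail
```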

### Rejection Sampling

<p align="center">
<img src="/assets/imgs/sampling_blog/Rejection_Sampling.gif" alt="Top-P Rejection Sampling. This animation illustrates the per-block process, and in practice the workload gets executed by blocks." width="800"/>
<br>
-<span style="color: gray; font-style: italic;">Inverse Transform Sampling. This animation illustrates the per-block process, and in practice the workload gets executed by blocks.</span>
+<span style="color: gray; font-style: italic;">Figure 3: Top-P Rejection Sampling. This animation illustrates the per-block process, and in practice the workload gets executed by blocks.</span>
</p>

For more advanced strategies such as **Top-P sampling**, we use **rejection sampling** to restrict which tokens can be selected. Rejection sampling draws from a target distribution by comparing random samples against a threshold and discarding those that do not meet it.
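
The full multi-round, pivot-based algorithm (including the Dual Pivot variant referenced in the next hunk) is described later in the post; the gist is captured by this simplified Python reference of Top-P rejection sampling (our own formulation, ignoring ties, not the kernel's exact per-block logic):

```python
import numpy as np

def top_p_rejection_sample(probs: np.ndarray, top_p: float, rng: np.random.Generator) -> int:
    """Simplified reference of Top-P sampling via rejection (ignoring ties)."""
    p = probs.astype(np.float64)
    while True:
        cand = int(rng.choice(len(p), p=p / p.sum()))  # sample from the current distribution
        if probs[probs > probs[cand]].sum() < top_p:   # candidate lies inside the nucleus
            return cand                                # accept
        p[p <= probs[cand]] = 0.0                      # reject: raise the pivot and resample
```

Each rejection discards the candidate and every token no more likely than it, so the surviving probability mass shrinks every round and the loop terminates; the vocabulary is never sorted, which is the point of the sorting-free approach.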
@@ -153,16 +154,28 @@ Figure 4 shows the transition from round(i) to round(i+1) in Dual Pivot Rejectio

## Evaluation

+Our evaluation demonstrates that FlashInfer's sampling kernel delivers substantial improvements in both kernel-level latency and end-to-end throughput compared to traditional sorting-based implementations.

<p align="center">
<img src="/assets/imgs/sampling_blog/Throughput_Comparison_of_Different_Engine_Kernel.svg" alt="Throughput Comparison of Different Engine Kernel" width="800"/>
<br>
-<span style="color: gray; font-style: italic;">Throughput Comparison of Different Engine Kernel.</span>
+<span style="color: gray; font-style: italic;">Figure 5: Throughput Comparison of Different Engine Kernels.</span>
</p>

<p align="center">
<img src="/assets/imgs/sampling_blog/Sampling_Latency_Growth_with_Batch_Size.svg" alt="Sampling Latency Growth with Batch Size." width="800"/>
<br>
-<span style="color: gray; font-style: italic;">Throughput Comparison of Different Engine Kernel.</span>
+<span style="color: gray; font-style: italic;">Figure 6: Sampling Latency Growth with Batch Size.</span>
</p>

+## Community Adoption and Other Applications
+
+The FlashInfer sampling kernel has been widely adopted by several prominent frameworks, including [sglang](https://github.com/sgl-project/sglang) and [vLLM](https://github.com/vllm-project/vllm/pull/7137). We are grateful for the community's valuable feedback and bug reports, which have helped improve the implementation. Beyond sampling, the core ideas behind our approach have broader applications, particularly in speculative decoding verification, including techniques such as [chain speculative sampling](https://arxiv.org/pdf/2302.01318) and [tree speculative verification](https://arxiv.org/pdf/2305.09781).
+
+## Implementation Details
+
+While the algorithm is elegant in theory, implementing it efficiently in a GPU kernel requires careful attention to detail, particularly in the token selection logic of inverse transform sampling. One key challenge lies in the parallel prefix-sum operation used to locate sampled tokens. Because floating-point addition is not associative, a parallel prefix sum **cannot guarantee monotonic outputs** even with non-negative inputs, which can lead to invalid token generation if not handled properly. Special care must be taken to ensure numerical stability and correctness in the sampling implementation (and we made plenty of mistakes before getting it right).
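
A tiny illustration of the underlying issue (plain Python, independent of the kernel): with IEEE-754 floats, the order of additions changes the result, so two CDF entries produced along different reduction trees need not be ordered the way exact arithmetic would order them.

```python
# Floating-point addition is not associative: a parallel scan that sums the
# same non-negative values in a different order can produce slightly
# different prefix sums, breaking monotonicity of the computed CDF.
left = (0.1 + 0.2) + 0.3    # 0.6000000000000001
right = 0.1 + (0.2 + 0.3)   # 0.6
print(left == right)        # False
```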
+For the complete implementation details, including how we address these challenges, please refer to the [source code](https://github.com/flashinfer-ai/flashinfer/blob/main/include/flashinfer/sampling.cuh).
+
+## Footnotes
[^1]: FlashInfer provides both "Top-K First" and "Joint" filtering options, with the latter applying Top-K and Top-P simultaneously at each round. More on the [doc](https://docs.flashinfer.ai/generated/flashinfer.sampling.top_k_top_p_sampling_from_probs.html).
