
Fix minor docs issues and fix metric requests #21040

Draft: wants to merge 5 commits into base: main
102 changes: 102 additions & 0 deletions benchmarks/kernels/bench_nvfp4_gemm.py
@@ -3,12 +3,16 @@
import argparse
import copy
import itertools

Check failure on line 6 in benchmarks/kernels/bench_nvfp4_gemm.py (GitHub Actions / pre-commit, Ruff UP035): benchmarks/kernels/bench_nvfp4_gemm.py:6:1: `typing.List` is deprecated, use `list` instead

Check failure on line 6 in benchmarks/kernels/bench_nvfp4_gemm.py (GitHub Actions / pre-commit, Ruff UP035): benchmarks/kernels/bench_nvfp4_gemm.py:6:1: `typing.Dict` is deprecated, use `dict` instead
import torch
import triton
import triton.language as tl
from typing import Optional, List, Dict
import pandas as pd
from weight_shapes import WEIGHT_SHAPES

from vllm import _custom_ops as ops
from vllm.platforms import current_platform

Check failure on line 15 in benchmarks/kernels/bench_nvfp4_gemm.py (GitHub Actions / pre-commit, Ruff F811): benchmarks/kernels/bench_nvfp4_gemm.py:15:31: Redefinition of unused `triton` from line 9
from vllm.scalar_type import scalar_types
from vllm.triton_utils import triton

@@ -28,6 +32,104 @@
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]


def benchmark_nvfp4_gemm_enhanced(
m: int,
n: int,
k: int,
use_cutlass: bool = True,
use_triton: bool = True,

Check failure on line 40 in benchmarks/kernels/bench_nvfp4_gemm.py (GitHub Actions / pre-commit, Ruff UP006): benchmarks/kernels/bench_nvfp4_gemm.py:40:11: Use `dict` instead of `Dict` for type annotation

Check failure on line 40 in benchmarks/kernels/bench_nvfp4_gemm.py (GitHub Actions / pre-commit, Ruff UP006): benchmarks/kernels/bench_nvfp4_gemm.py:40:6: Use `list` instead of `List` for type annotation
dtype: torch.dtype = torch.float16
) -> List[Dict]:
results = []

device = torch.device('cuda')
compute_capability = torch.cuda.get_device_capability(device)
sm_version = compute_capability[0] * 10 + compute_capability[1]

print(f"Device SM version: {sm_version}")

a = torch.randn(m, k, device=device, dtype=dtype)
b = torch.randn(k, n, device=device, dtype=dtype)

if use_cutlass and sm_version == 100:
try:
from vllm.experimental.kernels import cutlass_fp8_gemm

for _ in range(10):
out_cutlass = cutlass_fp8_gemm(a, b)

torch.cuda.synchronize()

import time
num_iterations = 100
start = time.time()

Check failure on line 66 in benchmarks/kernels/bench_nvfp4_gemm.py (GitHub Actions / pre-commit, Ruff F841): benchmarks/kernels/bench_nvfp4_gemm.py:66:17: Local variable `out_cutlass` is assigned to but never used
for _ in range(num_iterations):
out_cutlass = cutlass_fp8_gemm(a, b)

torch.cuda.synchronize()
end = time.time()

elapsed_ms = (end - start) * 1000 / num_iterations
flops = 2 * m * n * k
tflops = flops / (elapsed_ms / 1000) / 1e12

results.append({
'implementation': 'CUTLASS',
'elapsed_ms': elapsed_ms,
'tflops': tflops,
'm': m,
'n': n,
'k': k,
'sm_version': sm_version
})

except ImportError:
print("CUTLASS implementation not available")

Check failure on line 89 in benchmarks/kernels/bench_nvfp4_gemm.py (GitHub Actions / pre-commit, Ruff F821): benchmarks/kernels/bench_nvfp4_gemm.py:89:25: Undefined name `benchmark_triton_nvfp4_gemm`
if use_triton:
triton_result = benchmark_triton_nvfp4_gemm(m, n, k, dtype)
triton_result.update({
'implementation': 'Triton',
'm': m,
'n': n,
'k': k,
'sm_version': sm_version
})
results.append(triton_result)

for _ in range(10):
out_ref = torch.matmul(a, b)

torch.cuda.synchronize()

import time
num_iterations = 100
start = time.time()

Check failure on line 109 in benchmarks/kernels/bench_nvfp4_gemm.py (GitHub Actions / pre-commit, Ruff F841): benchmarks/kernels/bench_nvfp4_gemm.py:109:9: Local variable `out_ref` is assigned to but never used
for _ in range(num_iterations):
out_ref = torch.matmul(a, b)

torch.cuda.synchronize()
end = time.time()

elapsed_ms = (end - start) * 1000 / num_iterations
flops = 2 * m * n * k
tflops = flops / (elapsed_ms / 1000) / 1e12

results.append({
'implementation': 'PyTorch (FP16)',
'elapsed_ms': elapsed_ms,
'tflops': tflops,
'm': m,
'n': n,
'k': k,
'sm_version': sm_version
})

return results


def _quant_weight_nvfp4(b: torch.Tensor, device: str):
# Compute global scale for weight
b_amax = torch.abs(b).max().to(torch.float32)
4 changes: 2 additions & 2 deletions docs/community/meetups.md
@@ -1,6 +1,6 @@
# Meetups

We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:
We host regular meetups in the San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below:

- [NYC vLLM Meetup](https://lu.ma/c1rqyf1f), May 7th, 2025. [[Slides]](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing)
- [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day), April 3rd 2025. [[Slides]](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing).
@@ -17,4 +17,4 @@ We host regular meetups in San Francisco Bay Area every 2 months. We will share
- [The second vLLM meetup](https://lu.ma/ygxbpzhl), with IBM Research, January 31st 2024. [[Slides]](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing) [[Video (vLLM Update)]](https://youtu.be/Y0C-DUvEnZQ) [[Video (IBM Research & torch.compile)]](https://youtu.be/m0dMtFLI-dg)
- [The first vLLM meetup](https://lu.ma/first-vllm-meetup), with a16z, October 5th 2023. [[Slides]](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing)

We are always looking for speakers and sponsors at San Francisco Bay Area and potentially other locations. If you are interested in speaking or sponsoring, please contact us at [[email protected]](mailto:[email protected]).
We are always looking for speakers and sponsors in the San Francisco Bay Area and potentially other locations. If you are interested in speaking or sponsoring, please contact us at [[email protected]](mailto:[email protected]).
2 changes: 1 addition & 1 deletion docs/models/extensions/fastsafetensor.md
@@ -2,4 +2,4 @@ Loading Model weights with fastsafetensors
===================================================================

Using fastsafetensors library enables loading model weights to GPU memory by leveraging GPU direct storage. See [their GitHub repository](https://github.com/foundation-model-stack/fastsafetensors) for more details.
For enabling this feature, set the environment variable ``USE_FASTSAFETENSOR`` to ``true``
To enable this feature, set the environment variable ``USE_FASTSAFETENSOR`` to ``true``
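
As a minimal sketch of the documented behavior above (assuming the variable is read when the engine initializes; the model name and prompt below are placeholders for illustration, not part of the doc):

```python
import os

# Per the doc above, fastsafetensors-based weight loading is toggled by an
# environment variable; set it before the engine is constructed.
os.environ["USE_FASTSAFETENSOR"] = "true"

from vllm import LLM  # imported after setting the variable on purpose

# Placeholder model and prompt, for illustration only.
llm = LLM(model="facebook/opt-125m")
print(llm.generate("Hello, world!")[0].outputs[0].text)
```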
10 changes: 10 additions & 0 deletions vllm/v1/engine/output_processor.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
import asyncio
from collections.abc import Iterable
from dataclasses import dataclass
@@ -10,6 +11,7 @@

from vllm.outputs import (CompletionOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput)
from vllm.sequence import RequestMetrics
from vllm.sampling_params import RequestOutputKind
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@@ -410,6 +412,14 @@
if request_output := req_state.make_request_output(
new_token_ids, pooling_output, finish_reason, stop_reason,
kv_transfer_params, num_cached_tokens):
request_output.metrics = RequestMetrics(
arrival_time=req_state.stats.arrival_time,
last_token_time=req_state.stats.last_token_ts,
first_scheduled_time=req_state.stats.scheduled_ts,

Check failure on line 418 in vllm/v1/engine/output_processor.py (GitHub Actions / pre-commit): Item "None" of "Optional[RequestStateStats]" has no attribute "arrival_time" [union-attr]
first_token_time=req_state.stats.first_token_ts,

Check failure on line 419 in vllm/v1/engine/output_processor.py (GitHub Actions / pre-commit): Item "None" of "Optional[RequestStateStats]" has no attribute "last_token_ts" [union-attr]
time_in_queue=req_state.stats.scheduled_ts - req_state.stats.arrival_time,
finished_time=time.monotonic()
)
Comment on lines +415 to +422 (Contributor, severity: critical):

The new code to populate request_output.metrics assumes that req_state.stats is always available. However, req_state.stats is initialized to None if log_stats is False (see RequestState.__init__).

This will lead to an AttributeError when trying to access req_state.stats.arrival_time, causing a crash when log_stats is disabled.

To prevent this, you should add a check to ensure req_state.stats is not None before attempting to access its attributes.

Suggested change
request_output.metrics = RequestMetrics(
arrival_time=req_state.stats.arrival_time,
last_token_time=req_state.stats.last_token_ts,
first_scheduled_time=req_state.stats.scheduled_ts,
first_token_time=req_state.stats.first_token_ts,
time_in_queue=req_state.stats.scheduled_ts - req_state.stats.arrival_time,
finished_time=time.monotonic()
)
if req_state.stats:
request_output.metrics = RequestMetrics(
arrival_time=req_state.stats.arrival_time,
last_token_time=req_state.stats.last_token_ts,
first_scheduled_time=req_state.stats.scheduled_ts,
first_token_time=req_state.stats.first_token_ts,
time_in_queue=req_state.stats.scheduled_ts - req_state.stats.arrival_time,
finished_time=time.monotonic()
)

if req_state.queue is not None:
# AsyncLLM: put into queue for handling by generate().
req_state.queue.put(request_output)
2 changes: 1 addition & 1 deletion vllm/v1/engine/processor.py
@@ -247,7 +247,7 @@ def process_inputs(
f"is out of range [0, {data_parallel_size}).")

if arrival_time is None:
arrival_time = time.time()
arrival_time = time.monotonic()

# Process inputs, which includes:
# 1. Tokenize text prompt, with LoRA request if one exists.