Skip to content

Commit bfcb4d2

Browse files
committed
Add helper for converting literals to list of strings
1 parent 427caea commit bfcb4d2

File tree

5 files changed

+177
-8
lines changed

5 files changed

+177
-8
lines changed

src/guidellm/__main__.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import codecs
33
from pathlib import Path
4-
from typing import get_args
4+
from typing import Union
55

66
import click
77

@@ -19,12 +19,10 @@
1919
from guidellm.config import print_config
2020
from guidellm.preprocess.dataset import ShortPromptStrategy, process_dataset
2121
from guidellm.scheduler import StrategyType
22-
from guidellm.utils import DefaultGroupHandler
22+
from guidellm.utils import DefaultGroupHandler, get_literal_vals
2323
from guidellm.utils import cli as cli_tools
2424

25-
STRATEGY_PROFILE_CHOICES = list(
26-
set(list(get_args(ProfileType)) + list(get_args(StrategyType)))
27-
)
25+
STRATEGY_PROFILE_CHOICES = list(get_literal_vals(Union[ProfileType, StrategyType]))
2826

2927

3028
@click.group()
@@ -93,10 +91,10 @@ def benchmark():
9391
"--backend",
9492
"--backend-type", # legacy alias
9593
"backend",
96-
type=click.Choice(list(get_args(BackendType))),
94+
type=click.Choice(list(get_literal_vals(BackendType))),
9795
help=(
9896
"The type of backend to use to run requests against. Defaults to 'openai_http'."
99-
f" Supported types: {', '.join(get_args(BackendType))}"
97+
f" Supported types: {', '.join(get_literal_vals(BackendType))}"
10098
),
10199
default="openai_http",
102100
)

src/guidellm/utils/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
split_text_list_by_length,
6262
)
6363
from .threading import synchronous_to_exitable_async
64+
from .typing import get_literal_vals
6465

6566
__all__ = [
6667
"SUPPORTED_TYPES",
@@ -106,6 +107,7 @@
106107
"clean_text",
107108
"filter_text",
108109
"format_value_display",
110+
"get_literal_vals",
109111
"is_puncutation",
110112
"load_text",
111113
"safe_add",

src/guidellm/utils/typing.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Iterator
4+
from typing import Annotated, Literal, Union, get_args, get_origin
5+
6+
# Backwords compatibility for Python <3.10
7+
try:
8+
from types import UnionType # type: ignore[attr-defined]
9+
except ImportError:
10+
UnionType = Union
11+
12+
# Backwords compatibility for Python <3.12
13+
try:
14+
from typing import TypeAliasType # type: ignore[attr-defined]
15+
except ImportError:
16+
from typing_extensions import TypeAliasType
17+
18+
19+
__all__ = ["get_literal_vals"]
20+
21+
22+
def get_literal_vals(alias) -> frozenset[str]:
23+
"""Extract all literal values from a (possibly nested) type alias."""
24+
25+
def resolve(alias) -> Iterator[str]:
26+
origin = get_origin(alias)
27+
28+
# Base case: Literal types
29+
if origin is Literal:
30+
for literal_val in get_args(alias):
31+
yield str(literal_val)
32+
# Unwrap Annotated type
33+
elif origin is Annotated:
34+
yield from resolve(get_args(alias)[0])
35+
# Unwrap TypeAliasTypes
36+
elif isinstance(alias, TypeAliasType):
37+
yield from resolve(alias.__value__)
38+
# Iterate over unions
39+
elif origin in (Union, UnionType):
40+
for arg in get_args(alias):
41+
yield from resolve(arg)
42+
# Fallback
43+
else:
44+
yield str(alias)
45+
46+
return frozenset(resolve(alias))

tests/unit/utils/test_typing.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""
2+
Test suite for the typing utilities module.
3+
"""
4+
5+
from typing import Annotated, Literal, Union
6+
7+
import pytest
8+
from typing_extensions import TypeAlias
9+
10+
from guidellm.utils.typing import get_literal_vals
11+
12+
# Local type definitions to avoid imports from other modules
13+
LocalProfileType = Literal["synchronous", "async", "concurrent", "throughput", "sweep"]
14+
LocalStrategyType = Annotated[
15+
Literal["synchronous", "concurrent", "throughput", "constant", "poisson"],
16+
"Valid strategy type identifiers for scheduling request patterns",
17+
]
18+
StrategyProfileType: TypeAlias = Union[LocalStrategyType, LocalProfileType]
19+
20+
21+
class TestGetLiteralVals:
22+
"""Test cases for the get_literal_vals function."""
23+
24+
@pytest.mark.sanity
25+
def test_profile_type(self):
26+
"""
27+
Test extracting values from ProfileType.
28+
29+
### WRITTEN BY AI ###
30+
"""
31+
result = get_literal_vals(LocalProfileType)
32+
expected = frozenset(
33+
{"synchronous", "async", "concurrent", "throughput", "sweep"}
34+
)
35+
assert result == expected
36+
37+
@pytest.mark.sanity
38+
def test_strategy_type(self):
39+
"""
40+
Test extracting values from StrategyType.
41+
42+
### WRITTEN BY AI ###
43+
"""
44+
result = get_literal_vals(LocalStrategyType)
45+
expected = frozenset(
46+
{"synchronous", "concurrent", "throughput", "constant", "poisson"}
47+
)
48+
assert result == expected
49+
50+
@pytest.mark.smoke
51+
def test_inline_union_type(self):
52+
"""
53+
Test extracting values from inline union of ProfileType | StrategyType.
54+
55+
### WRITTEN BY AI ###
56+
"""
57+
result = get_literal_vals(Union[LocalProfileType, LocalStrategyType])
58+
expected = frozenset(
59+
{
60+
"synchronous",
61+
"async",
62+
"concurrent",
63+
"throughput",
64+
"constant",
65+
"poisson",
66+
"sweep",
67+
}
68+
)
69+
assert result == expected
70+
71+
@pytest.mark.smoke
72+
def test_type_alias(self):
73+
"""
74+
Test extracting values from type alias union.
75+
76+
### WRITTEN BY AI ###
77+
"""
78+
result = get_literal_vals(StrategyProfileType)
79+
expected = frozenset(
80+
{
81+
"synchronous",
82+
"async",
83+
"concurrent",
84+
"throughput",
85+
"constant",
86+
"poisson",
87+
"sweep",
88+
}
89+
)
90+
assert result == expected
91+
92+
@pytest.mark.sanity
93+
def test_single_literal(self):
94+
"""
95+
Test extracting values from single Literal type.
96+
97+
### WRITTEN BY AI ###
98+
"""
99+
result = get_literal_vals(Literal["test"])
100+
expected = frozenset({"test"})
101+
assert result == expected
102+
103+
@pytest.mark.sanity
104+
def test_multi_literal(self):
105+
"""
106+
Test extracting values from multi-value Literal type.
107+
108+
### WRITTEN BY AI ###
109+
"""
110+
result = get_literal_vals(Literal["test", "test2"])
111+
expected = frozenset({"test", "test2"})
112+
assert result == expected
113+
114+
@pytest.mark.smoke
115+
def test_literal_union(self):
116+
"""
117+
Test extracting values from union of Literal types.
118+
119+
### WRITTEN BY AI ###
120+
"""
121+
result = get_literal_vals(Union[Literal["test", "test2"], Literal["test3"]])
122+
expected = frozenset({"test", "test2", "test3"})
123+
assert result == expected

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ description = Run type checks
6666
deps =
6767
.[dev]
6868
commands =
69-
mypy --check-untyped-defs
69+
mypy --check-untyped-defs {posargs}
7070

7171

7272
[testenv:links]

0 commit comments

Comments
 (0)