Skip to content

Commit 1c002e1

Browse files
committed
Add turns support to synthetic dataset
1 parent 1f295f4 commit 1c002e1

File tree

1 file changed

+71
-32
lines changed

1 file changed

+71
-32
lines changed

src/guidellm/dataset/synthetic.py

Lines changed: 71 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Iterable, Iterator
44
from itertools import cycle
55
from pathlib import Path
6-
from typing import Any, Literal, Optional, Union
6+
from typing import Any, Optional, TypedDict, Union
77

88
import yaml
99
from datasets import (
@@ -69,6 +69,26 @@ class SyntheticDatasetConfig(BaseModel):
6969
gt=0,
7070
default=None,
7171
)
72+
turns: int = Field(
73+
description="The number of turns in the conversation.",
74+
gt=0,
75+
default=1,
76+
)
77+
turns_stdev: Optional[int] = Field(
78+
description="The standard deviation of the number of turns.",
79+
gt=0,
80+
default=None,
81+
)
82+
turns_min: Optional[int] = Field(
83+
description="The minimum number of turns in the conversation.",
84+
gt=0,
85+
default=None,
86+
)
87+
turns_max: Optional[int] = Field(
88+
description="The maximum number of turns in the conversation.",
89+
gt=0,
90+
default=None,
91+
)
7292
samples: int = Field(
7393
description="The number of samples to generate for the dataset.",
7494
gt=0,
@@ -124,14 +144,13 @@ def parse_config_file(data: Union[str, Path]) -> "SyntheticDatasetConfig":
124144
return SyntheticDatasetConfig(**config_dict)
125145

126146

127-
class SyntheticTextItemsGenerator(
128-
Iterable[
129-
dict[
130-
Literal["prompt", "prompt_tokens_count", "output_tokens_count"],
131-
Union[str, int],
132-
]
133-
]
134-
):
147+
class SyntheticDatasetRow(TypedDict):
148+
prompt: list[str]
149+
prompt_tokens_count: list[int]
150+
output_tokens_count: list[int]
151+
152+
153+
class SyntheticTextItemsGenerator(Iterable[SyntheticDatasetRow]):
135154
def __init__(
136155
self,
137156
config: SyntheticDatasetConfig,
@@ -147,12 +166,7 @@ def __init__(
147166

148167
def __iter__(
149168
self,
150-
) -> Iterator[
151-
dict[
152-
Literal["prompt", "prompt_tokens_count", "output_tokens_count"],
153-
Union[str, int],
154-
]
155-
]:
169+
) -> Iterator[SyntheticDatasetRow]:
156170
prompt_tokens_sampler = IntegerRangeSampler(
157171
average=self.config.prompt_tokens,
158172
variance=self.config.prompt_tokens_stdev,
@@ -167,31 +181,56 @@ def __iter__(
167181
max_value=self.config.output_tokens_max,
168182
random_seed=self.random_seed + 1, # ensure diff dist from prompts
169183
)
184+
turns_sampler = IntegerRangeSampler(
185+
average=self.config.turns,
186+
variance=self.config.turns_stdev,
187+
min_value=self.config.turns_min,
188+
max_value=self.config.turns_max,
189+
random_seed=self.random_seed + 7, # ensure diff dist
190+
)
170191
# ensure diff distribution from output tokens
171192
rand = random.Random(self.random_seed + 2) # noqa: S311
172193
unique_prefix_iter = cycle(self.processor.get_vocab().values())
173194

174195
prefix_index = rand.randint(0, len(self.text_creator.words))
175196
prefix_tokens = self._create_prompt(self.config.prefix_tokens, prefix_index)
176197

177-
for _, prompt_tokens, output_tokens in zip(
178-
range(self.config.samples),
179-
prompt_tokens_sampler,
180-
output_tokens_sampler,
181-
):
182-
start_index = rand.randint(0, len(self.text_creator.words))
183-
prompt_text = self.processor.decode(
184-
prefix_tokens
185-
+ self._create_prompt(
186-
prompt_tokens, start_index, next(unique_prefix_iter)
187-
),
188-
skip_special_tokens=True,
189-
)
190-
yield {
191-
"prompt": prompt_text,
192-
"prompt_tokens_count": self.config.prefix_tokens + prompt_tokens,
193-
"output_tokens_count": output_tokens,
198+
for _, turns in zip(range(self.config.samples), turns_sampler):
199+
row: SyntheticDatasetRow = {
200+
"prompt": [],
201+
"prompt_tokens_count": [],
202+
"output_tokens_count": [],
194203
}
204+
for i, prompt_tokens, output_tokens in zip(
205+
range(turns),
206+
prompt_tokens_sampler,
207+
output_tokens_sampler,
208+
):
209+
start_index = rand.randint(0, len(self.text_creator.words))
210+
# Prepend the prefix tokens only for the first turn
211+
if i == 0:
212+
prompt_text = self.processor.decode(
213+
prefix_tokens
214+
+ self._create_prompt(
215+
prompt_tokens, start_index, next(unique_prefix_iter)
216+
),
217+
skip_special_tokens=True,
218+
)
219+
row["prompt"].append(prompt_text)
220+
row["prompt_tokens_count"].append(self.config.prefix_tokens + prompt_tokens)
221+
row["output_tokens_count"].append(output_tokens)
222+
else:
223+
prompt_text = self.processor.decode(
224+
self._create_prompt(
225+
prompt_tokens, start_index, next(unique_prefix_iter)
226+
),
227+
skip_special_tokens=True,
228+
)
229+
row["prompt"].append(prompt_text)
230+
row["prompt_tokens_count"].append(prompt_tokens)
231+
row["output_tokens_count"].append(output_tokens)
232+
233+
yield row
195234

196235
def _create_prompt(
197236
self, prompt_tokens: int, start_index: int, unique_prefix: Optional[int] = None

0 commit comments

Comments
 (0)