3
3
from collections .abc import Iterable , Iterator
4
4
from itertools import cycle
5
5
from pathlib import Path
6
- from typing import Any , Literal , Optional , Union
6
+ from typing import Any , Optional , TypedDict , Union
7
7
8
8
import yaml
9
9
from datasets import (
@@ -69,6 +69,26 @@ class SyntheticDatasetConfig(BaseModel):
69
69
gt = 0 ,
70
70
default = None ,
71
71
)
72
+ turns : int = Field (
73
+ description = "The number of turns in the conversation." ,
74
+ gt = 0 ,
75
+ default = 1 ,
76
+ )
77
+ turns_stdev : Optional [int ] = Field (
78
+ description = "The standard deviation of the number of turns." ,
79
+ gt = 0 ,
80
+ default = None ,
81
+ )
82
+ turns_min : Optional [int ] = Field (
83
+ description = "The minimum number of turns in the conversation." ,
84
+ gt = 0 ,
85
+ default = None ,
86
+ )
87
+ turns_max : Optional [int ] = Field (
88
+ description = "The maximum number of turns in the conversation." ,
89
+ gt = 0 ,
90
+ default = None ,
91
+ )
72
92
samples : int = Field (
73
93
description = "The number of samples to generate for the dataset." ,
74
94
gt = 0 ,
@@ -124,14 +144,13 @@ def parse_config_file(data: Union[str, Path]) -> "SyntheticDatasetConfig":
124
144
return SyntheticDatasetConfig (** config_dict )
125
145
126
146
127
- class SyntheticTextItemsGenerator (
128
- Iterable [
129
- dict [
130
- Literal ["prompt" , "prompt_tokens_count" , "output_tokens_count" ],
131
- Union [str , int ],
132
- ]
133
- ]
134
- ):
147
+ class SyntheticDatasetRow (TypedDict ):
148
+ prompt : list [str ]
149
+ prompt_tokens_count : list [int ]
150
+ output_tokens_count : list [int ]
151
+
152
+
153
+ class SyntheticTextItemsGenerator (Iterable [SyntheticDatasetRow ]):
135
154
def __init__ (
136
155
self ,
137
156
config : SyntheticDatasetConfig ,
@@ -147,12 +166,7 @@ def __init__(
147
166
148
167
def __iter__ (
149
168
self ,
150
- ) -> Iterator [
151
- dict [
152
- Literal ["prompt" , "prompt_tokens_count" , "output_tokens_count" ],
153
- Union [str , int ],
154
- ]
155
- ]:
169
+ ) -> Iterator [SyntheticDatasetRow ]:
156
170
prompt_tokens_sampler = IntegerRangeSampler (
157
171
average = self .config .prompt_tokens ,
158
172
variance = self .config .prompt_tokens_stdev ,
@@ -167,31 +181,56 @@ def __iter__(
167
181
max_value = self .config .output_tokens_max ,
168
182
random_seed = self .random_seed + 1 , # ensure diff dist from prompts
169
183
)
184
+ turns_sampler = IntegerRangeSampler (
185
+ average = self .config .turns ,
186
+ variance = self .config .turns_stdev ,
187
+ min_value = self .config .turns_min ,
188
+ max_value = self .config .turns_max ,
189
+ random_seed = self .random_seed + 7 , # ensure diff dist
190
+ )
170
191
# ensure diff distribution from output tokens
171
192
rand = random .Random (self .random_seed + 2 ) # noqa: S311
172
193
unique_prefix_iter = cycle (self .processor .get_vocab ().values ())
173
194
174
195
prefix_index = rand .randint (0 , len (self .text_creator .words ))
175
196
prefix_tokens = self ._create_prompt (self .config .prefix_tokens , prefix_index )
176
197
177
- for _ , prompt_tokens , output_tokens in zip (
178
- range (self .config .samples ),
179
- prompt_tokens_sampler ,
180
- output_tokens_sampler ,
181
- ):
182
- start_index = rand .randint (0 , len (self .text_creator .words ))
183
- prompt_text = self .processor .decode (
184
- prefix_tokens
185
- + self ._create_prompt (
186
- prompt_tokens , start_index , next (unique_prefix_iter )
187
- ),
188
- skip_special_tokens = True ,
189
- )
190
- yield {
191
- "prompt" : prompt_text ,
192
- "prompt_tokens_count" : self .config .prefix_tokens + prompt_tokens ,
193
- "output_tokens_count" : output_tokens ,
198
+ for _ , turns in zip (range (self .config .samples ), turns_sampler ):
199
+ row : SyntheticDatasetRow = {
200
+ "prompt" : [],
201
+ "prompt_tokens_count" : [],
202
+ "output_tokens_count" : [],
194
203
}
204
+ for i , prompt_tokens , output_tokens in zip (
205
+ range (turns ),
206
+ prompt_tokens_sampler ,
207
+ output_tokens_sampler ,
208
+ ):
209
+ start_index = rand .randint (0 , len (self .text_creator .words ))
210
+ # Append the prefix tokens only for the first turn
211
+ if i == 0 :
212
+ prompt_text = self .processor .decode (
213
+ prefix_tokens
214
+ + self ._create_prompt (
215
+ prompt_tokens , start_index , next (unique_prefix_iter )
216
+ ),
217
+ skip_special_tokens = True ,
218
+ )
219
+ row ["prompt" ].append (prompt_text )
220
+ row ["prompt_tokens_count" ].append (self .config .prefix_tokens + prompt_tokens )
221
+ row ["output_tokens_count" ].append (output_tokens )
222
+ else :
223
+ prompt_text = self .processor .decode (
224
+ self ._create_prompt (
225
+ prompt_tokens , start_index , next (unique_prefix_iter )
226
+ ),
227
+ skip_special_tokens = True ,
228
+ )
229
+ row ["prompt" ].append (prompt_text )
230
+ row ["prompt_tokens_count" ].append (prompt_tokens )
231
+ row ["output_tokens_count" ].append (output_tokens )
232
+
233
+ yield row
195
234
196
235
def _create_prompt (
197
236
self , prompt_tokens : int , start_index : int , unique_prefix : Optional [int ] = None
0 commit comments