Skip to content

Commit 441e62f

Browse files
author
Luke Hinds
authored
Merge pull request #18 from StacklokLabs/add-sys-msg
Make system message optional
2 parents c17d9ad + d3aa919 commit 441e62f

16 files changed

+489
-23
lines changed

README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ just open an issue).
2626
- **YAML Configuration**: Define your generation tasks using YAML configuration files
2727
- **Command Line Interface**: Run generation tasks directly from the command line
2828
- **Push to Hugging Face**: Push the generated dataset to Hugging Face Hub with automatic dataset cards and tags
29+
- **System Message Control**: Choose whether to include system messages in the generated dataset
2930

3031
## Getting Started
3132

@@ -95,6 +96,7 @@ dataset:
9596
num_steps: 5
9697
batch_size: 1
9798
model_name: "ollama/mistral:latest"
99+
sys_msg: true # Include system message in dataset (default: true)
98100
save_as: "basic_prompt_dataset.jsonl"
99101

100102
# Optional Hugging Face Hub configuration
@@ -128,6 +130,7 @@ promptwright start config.yaml \
128130
--tree-depth 3 \
129131
--num-steps 10 \
130132
--batch-size 2 \
133+
--sys-msg true \ # Control system message inclusion (default: true)
131134
--hf-repo username/dataset-name \
132135
--hf-token your-token \
133136
--hf-tags tag1 --hf-tags tag2
@@ -185,6 +188,7 @@ engine = DataEngine(
185188
model_name="ollama/llama3",
186189
temperature=0.9,
187190
max_retries=2,
191+
sys_msg=True, # Include system message in dataset (default: true)
188192
)
189193
)
190194
```
@@ -218,6 +222,7 @@ make all
218222

219223
### Prompt Output Examples
220224

225+
With sys_msg=true (default):
221226
```json
222227
{
223228
"messages": [
@@ -237,6 +242,22 @@ make all
237242
}
238243
```
239244

245+
With sys_msg=false:
246+
```json
247+
{
248+
"messages": [
249+
{
250+
"role": "user",
251+
"content": "Create a descriptive passage about a character discovering their hidden talents."
252+
},
253+
{
254+
"role": "assistant",
255+
"content": "As she stared at the canvas, Emma's fingers hovered above the paintbrushes, as if hesitant to unleash the colors that had been locked within her. The strokes began with bold abandon, swirling blues and greens merging into a mesmerizing dance of light and shadow. With each passing moment, she felt herself becoming the art – her very essence seeping onto the canvas like watercolors in a spring storm. The world around her melted away, leaving only the vibrant symphony of color and creation."
256+
}
257+
]
258+
}
259+
```
260+
240261
## Model Compatibility
241262

242263
The library should work with most LLM models. It has been tested with the

examples/example_basic_prompt.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
model_name="ollama/mistral-nemo:latest", # Model name
3030
temperature=0.9, # Higher temperature for more creative variations
3131
max_retries=2, # Retry failed prompts up to 2 times
32+
sys_msg=True, # Include system message in dataset (default: true)
3233
)
3334
)
3435

examples/example_basic_prompt.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,5 @@ dataset:
2727
batch_size: 1
2828
provider: "ollama" # LLM provider
2929
model: "mistral-nemo:latest" # Model name
30+
sys_msg: true # Include system message in dataset (default: true)
3031
save_as: "basic_prompt_dataset.jsonl"

examples/example_culinary_database.yaml

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
# Example YAML configuration for basic prompt generation
2-
system_prompt: "You are a helpful assistant. You provide clear and concise answers to user questions."
2+
system_prompt: |
3+
You are a culinary expert who documents recipes and cooking techniques.
4+
Your entries should be detailed, precise, and include both traditional and modern cooking methods.
35
46
topic_tree:
57
args:
6-
root_prompt: "You are a culinary expert who documents recipes and cooking techniques.
7-
Your entries should be detailed, precise, and include both traditional and modern cooking methods."
8+
root_prompt: "Global Cuisine and Cooking Techniques"
89
model_system_prompt: "<system_prompt_placeholder>" # Will be replaced with system_prompt
9-
tree_degree: 5 # Different continents
10-
tree_depth: 3 # Deeper tree for more specific topics
10+
tree_degree: 5 # Different cuisine types
11+
tree_depth: 3 # Specific dishes and techniques
1112
temperature: 0.7 # Higher temperature for more creative variations
1213
provider: "ollama" # LLM provider
1314
model: "mistral-nemo:latest" # Model name
@@ -35,4 +36,5 @@ dataset:
3536
batch_size: 1
3637
provider: "ollama" # LLM provider
3738
model: "mistral-nemo:latest" # Model name
39+
sys_msg: true # Include system message in dataset (default: true)
3840
save_as: "culinary_database.jsonl"

examples/example_historic_figures.yaml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ system_prompt: |
55
66
topic_tree:
77
args:
8-
root_prompt: "Capital Cities of the World."
8+
root_prompt: "Notable Historical Figures Across Different Eras and Fields"
99
model_system_prompt: "<system_prompt_placeholder>" # Will be replaced with system_prompt
10-
tree_degree: 3 # Different continents
11-
tree_depth: 2 # Deeper tree for more specific topics
12-
temperature: 0.7 # Higher temperature for more creative variations
10+
tree_degree: 4 # Different categories
11+
tree_depth: 3 # Deeper tree for more specific figures
12+
temperature: 0.6 # Balanced temperature for creativity and accuracy
1313
provider: "ollama" # LLM provider
1414
model: "mistral-nemo:latest" # Model name
1515
save_as: "historical_figures_tree.jsonl"
@@ -23,7 +23,7 @@ data_engine:
2323
system_prompt: "<system_prompt_placeholder>" # Will be replaced with system_prompt
2424
provider: "ollama" # LLM provider
2525
model: "mistral-nemo:latest" # Model name
26-
temperature: 0.9 # Higher temperature for more creative variations
26+
temperature: 0.7 # Balance between creativity and accuracy
2727
max_retries: 2 # Retry failed prompts up to 2 times
2828

2929
dataset:
@@ -32,4 +32,5 @@ dataset:
3232
batch_size: 1
3333
provider: "ollama" # LLM provider
3434
model: "mistral-nemo:latest" # Model name
35-
save_as: "basic_prompt_dataset.jsonl"
35+
sys_msg: true # Include system message in dataset (default: true)
36+
save_as: "historical_figures_database.jsonl"

examples/example_programming_challenges.py.yaml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ topic_tree:
77
args:
88
root_prompt: "Programming Challenges Across Different Difficulty Levels and Concepts"
99
model_system_prompt: "<system_prompt_placeholder>" # Will be replaced with system_prompt
10-
tree_degree: 3 # Different continents
11-
tree_depth: 2 # Deeper tree for more specific topics
12-
temperature: 0.7 # Higher temperature for more creative variations
10+
tree_degree: 4 # Different programming concepts
11+
tree_depth: 2 # Various difficulty levels
12+
temperature: 0.7 # Higher temperature for creative problem scenarios
1313
provider: "ollama" # LLM provider
1414
model: "mistral-nemo:latest" # Model name
15-
save_as: "basic_prompt_topictree.jsonl"
15+
save_as: "programming_challenges_tree.jsonl"
1616

1717
data_engine:
1818
args:
@@ -27,7 +27,7 @@ data_engine:
2727
system_prompt: "<system_prompt_placeholder>" # Will be replaced with system_prompt
2828
provider: "ollama" # LLM provider
2929
model: "mistral-nemo:latest" # Model name
30-
temperature: 0.9 # Higher temperature for more creative variations
30+
temperature: 0.8 # Higher temperature for creative problem scenarios
3131
max_retries: 2 # Retry failed prompts up to 2 times
3232

3333
dataset:
@@ -36,4 +36,5 @@ dataset:
3636
batch_size: 1
3737
provider: "ollama" # LLM provider
3838
model: "mistral-nemo:latest" # Model name
39-
save_as: "basic_prompt_dataset.jsonl"
39+
sys_msg: true # Include system message in dataset (default: true)
40+
save_as: "programming_challenges.jsonl"

examples/example_with_hf.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ dataset:
2323
num_steps: 5
2424
batch_size: 1
2525
model_name: "ollama/mistral:latest"
26+
sys_msg: true # Include system message in dataset (default: true)
2627
save_as: "basic_prompt_dataset.jsonl"
2728

2829
# Hugging Face Hub configuration (optional)

promptwright/cli.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ def cli():
4747
multiple=True,
4848
help="Additional tags for the dataset (can be specified multiple times)",
4949
)
50+
@click.option(
51+
"--sys-msg",
52+
type=bool,
53+
help="Include system message in dataset (default: true)",
54+
)
5055
def start( # noqa: PLR0912
5156
config_file: str,
5257
topic_tree_save_as: str | None = None,
@@ -61,6 +66,7 @@ def start( # noqa: PLR0912
6166
hf_repo: str | None = None,
6267
hf_token: str | None = None,
6368
hf_tags: list[str] | None = None,
69+
sys_msg: bool | None = None,
6470
) -> None:
6571
"""Generate training data from a YAML configuration file."""
6672
try:
@@ -150,6 +156,7 @@ def start( # noqa: PLR0912
150156
batch_size=batch_size or dataset_params.get("batch_size", 1),
151157
topic_tree=tree,
152158
model_name=model_name,
159+
sys_msg=sys_msg, # Pass sys_msg to create_data
153160
)
154161
except Exception as e:
155162
handle_error(

promptwright/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,17 @@ def get_engine_args(self, **overrides) -> EngineArguments:
8686
# Construct full model string
8787
args["model_name"] = construct_model_string(provider, model)
8888

89+
# Get sys_msg from dataset config, defaulting to True
90+
dataset_config = self.get_dataset_config()
91+
sys_msg = dataset_config.get("creation", {}).get("sys_msg", True)
92+
8993
return EngineArguments(
9094
instructions=args.get("instructions", ""),
9195
system_prompt=args.get("system_prompt", ""),
9296
model_name=args["model_name"],
9397
temperature=args.get("temperature", 0.9),
9498
max_retries=args.get("max_retries", 2),
99+
sys_msg=sys_msg,
95100
)
96101

97102
def get_dataset_config(self) -> dict:

promptwright/engine.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ class EngineArguments:
5454
default_batch_size: int = 5
5555
default_num_examples: int = 3
5656
request_timeout: int = 30
57+
sys_msg: bool = True # Default to True for including system message
5758

5859

5960
class DataEngine:
@@ -81,7 +82,10 @@ def __init__(self, args: EngineArguments):
8182
"malformed_responses": [],
8283
"other_errors": [],
8384
}
84-
self.args.system_prompt = ENGINE_JSON_INSTRUCTIONS + self.args.system_prompt
85+
# Store original system prompt for dataset inclusion
86+
self.original_system_prompt = args.system_prompt
87+
# Use ENGINE_JSON_INSTRUCTIONS only for generation prompt
88+
self.generation_system_prompt = ENGINE_JSON_INSTRUCTIONS + args.system_prompt
8589

8690
def analyze_failure(self, response_content: str, error: Exception = None) -> str:
8791
"""Analyze the failure reason for a sample."""
@@ -134,6 +138,7 @@ def create_data( # noqa: PLR0912
134138
batch_size: int = 10,
135139
topic_tree: TopicTree = None,
136140
model_name: str = None,
141+
sys_msg: bool = None, # Allow overriding sys_msg from args
137142
):
138143
if num_steps is None:
139144
raise ValueError("num_steps must be specified") # noqa: TRY003
@@ -144,6 +149,9 @@ def create_data( # noqa: PLR0912
144149
if not self.model_name:
145150
raise ValueError("No valid model_name provided") # noqa: TRY003
146151

152+
# Use provided sys_msg or fall back to args.sys_msg
153+
include_sys_msg = sys_msg if sys_msg is not None else self.args.sys_msg
154+
147155
data_creation_prompt = SAMPLE_GENERATION_PROMPT
148156

149157
tree_paths = None
@@ -204,6 +212,17 @@ def create_data( # noqa: PLR0912
204212
response_content = r.choices[0].message.content
205213
parsed_json = validate_json_response(response_content)
206214

215+
if parsed_json and include_sys_msg: # noqa: SIM102
216+
# Add system message at the start if sys_msg is True
217+
if "messages" in parsed_json:
218+
parsed_json["messages"].insert(
219+
0,
220+
{
221+
"role": "system",
222+
"content": self.original_system_prompt,
223+
},
224+
)
225+
207226
if parsed_json:
208227
samples.append(parsed_json)
209228
else:
@@ -284,7 +303,7 @@ def build_prompt(
284303
subtopics_list: list[str] = None,
285304
) -> str:
286305
prompt = data_creation_prompt.replace(
287-
"{{{{system_prompt}}}}", self.build_system_prompt()
306+
"{{{{system_prompt}}}}", self.generation_system_prompt
288307
)
289308
prompt = prompt.replace(
290309
"{{{{instructions}}}}", self.build_custom_instructions_text()
@@ -297,7 +316,8 @@ def build_prompt(
297316
)
298317

299318
def build_system_prompt(self):
300-
return self.args.system_prompt
319+
"""Return the original system prompt for dataset inclusion."""
320+
return self.original_system_prompt
301321

302322
def build_custom_instructions_text(self) -> str:
303323
if self.args.instructions is None:

promptwright/hf_hub.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ def update_dataset_card(self, repo_id: str, tags: list[str] | None = None):
4848
try:
4949
card = DatasetCard.load(repo_id)
5050

51-
# Initialize tags if not a list
52-
if not isinstance(card.data.tags, list):
51+
# Initialize tags if not present
52+
if not hasattr(card.data, "tags") or not isinstance(card.data.tags, list):
5353
card.data.tags = []
5454

5555
# Add default promptwright tags

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "promptwright"
3-
version = "1.1.1"
3+
version = "1.2.1"
44
description = "LLM based Synthetic Data Generation"
55
authors = ["Luke Hinds <[email protected]>"]
66
readme = "README.md"

0 commit comments

Comments
 (0)