-
Notifications
You must be signed in to change notification settings - Fork 58
feat: add allow_resize for 1:N and N:1 generation patterns #286
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8d2fdce
7b31cea
399eb4a
1a16906
07d14bd
0ff511d
ab8ce11
8285b48
4394e76
a163673
ede1bd9
3c565ef
f2b4f2e
1e82937
fd2675f
c471a02
401ef32
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -93,6 +93,58 @@ This gives you direct access to all `ModelFacade` capabilities: custom parsers, | |
| | `generator_function` | Callable | Yes | Decorated function | | ||
| | `generation_strategy` | GenerationStrategy | No | `CELL_BY_CELL` or `FULL_COLUMN` | | ||
| | `generator_params` | BaseModel | No | Typed params passed to function | | ||
| | `allow_resize` | bool | No | Allow 1:N or N:1 generation | | ||
|
|
||
| ### Resizing (1:N and N:1) | ||
|
|
||
| **FULL_COLUMN:** Set `allow_resize=True` and return a DataFrame with more or fewer rows than the input: | ||
|
|
||
| ```python | ||
| @dd.custom_column_generator( | ||
| required_columns=["topic"], | ||
| side_effect_columns=["variation_id"], | ||
| ) | ||
| def expand_topics(df: pd.DataFrame, params: None, models: dict) -> pd.DataFrame: | ||
| rows = [] | ||
| for _, row in df.iterrows(): | ||
| for i in range(3): # Generate 3 variations per input | ||
| rows.append({ | ||
| "topic": row["topic"], | ||
| "question": f"Question {i+1} about {row['topic']}", | ||
| "variation_id": i, | ||
| }) | ||
| return pd.DataFrame(rows) | ||
|
|
||
| dd.CustomColumnConfig( | ||
| name="question", | ||
| generator_function=expand_topics, | ||
| generation_strategy=dd.GenerationStrategy.FULL_COLUMN, | ||
| allow_resize=True, | ||
| ) | ||
| ``` | ||
|
|
||
| **CELL_BY_CELL:** With `allow_resize=True`, your function may return a single row (`dict`) or multiple rows (`list[dict]`). Return `[]` to drop that input row. | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wasn't doing that at first, but I think it's useful, especially since functions with LLM calls will typically use the cell-by-cell strategy. |
||
|
|
||
| ```python | ||
| @dd.custom_column_generator(required_columns=["id"]) | ||
| def expand_row(row: dict) -> list[dict]: | ||
| return [ | ||
| {**row, "variant": "a"}, | ||
| {**row, "variant": "b"}, | ||
| ] | ||
|
|
||
| dd.CustomColumnConfig( | ||
| name="variant", | ||
| generator_function=expand_row, | ||
| generation_strategy=dd.GenerationStrategy.CELL_BY_CELL, | ||
| allow_resize=True, | ||
| ) | ||
| ``` | ||
|
|
||
| Use cases: | ||
|
|
||
| - **Expansion (1:N)**: Generate multiple variations per input | ||
| - **Retraction (N:1)**: Filter, aggregate, or deduplicate records (FULL_COLUMN) or return `[]` per row (CELL_BY_CELL) | ||
|
|
||
| ## Multi-Turn Example | ||
|
|
||
|
|
||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Example will be removed later as usual, just for reference
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. NB this example uses CustomColumns but plugins are also affected by this. I've added a comment about it on docs. I've also implemented a plugin locally to test it, seems to work fine (lots of LoCs so not adding to this PR) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,108 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """Example: Chaining expand -> retract -> expand resize operations. | ||
|
|
||
| Pipeline: 5 topics -> 15 questions (3 per topic) -> ~8 hard questions (filter easy) | ||
| -> ~24 answer variants (3 per question) | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import data_designer.config as dd | ||
| from data_designer.interface import DataDesigner | ||
| from data_designer.lazy_heavy_imports import pd | ||
|
|
||
|
|
||
| # Step 1: Expand — 1:N, generate 3 questions per topic | ||
| @dd.custom_column_generator(required_columns=["topic"], side_effect_columns=["question_id", "difficulty"]) | ||
| def expand_to_questions(df: pd.DataFrame) -> pd.DataFrame: | ||
| rows = [] | ||
| for _, row in df.iterrows(): | ||
| for i in range(3): | ||
| rows.append( | ||
| { | ||
| "topic": row["topic"], | ||
| "question": f"Q{i + 1} about {row['topic']}?", | ||
| "question_id": i, | ||
| "difficulty": ["easy", "medium", "hard"][i], | ||
| } | ||
| ) | ||
| return pd.DataFrame(rows) | ||
|
|
||
|
|
||
| # Step 2: Retract — N:1, keep only medium/hard questions | ||
| @dd.custom_column_generator(required_columns=["difficulty"]) | ||
| def filter_non_easy(df: pd.DataFrame) -> pd.DataFrame: | ||
| return df[df["difficulty"] != "easy"].copy().assign(filtered=True) | ||
|
|
||
|
|
||
| # Step 3: Expand again — 1:N, generate 3 answer variants per surviving question | ||
| @dd.custom_column_generator(required_columns=["question"], side_effect_columns=["variant"]) | ||
| def expand_to_answers(df: pd.DataFrame) -> pd.DataFrame: | ||
| rows = [] | ||
| for _, row in df.iterrows(): | ||
| for v in range(3): | ||
| rows.append({**row.to_dict(), "answer": f"Answer v{v} to: {row['question']}", "variant": v}) | ||
| return pd.DataFrame(rows) | ||
|
|
||
|
|
||
| def main() -> None: | ||
| data_designer = DataDesigner() | ||
| config_builder = dd.DataDesignerConfigBuilder() | ||
|
|
||
| # Seed: 5 topics | ||
| config_builder.add_column( | ||
| dd.SamplerColumnConfig( | ||
| name="topic", | ||
| sampler_type=dd.SamplerType.CATEGORY, | ||
| params=dd.CategorySamplerParams(values=["Python", "ML", "Data", "Stats", "SQL"]), | ||
| ) | ||
| ) | ||
|
|
||
| # Expand: 5 topics -> 15 questions | ||
| config_builder.add_column( | ||
| dd.CustomColumnConfig( | ||
| name="question", | ||
| generator_function=expand_to_questions, | ||
| generation_strategy=dd.GenerationStrategy.FULL_COLUMN, | ||
| allow_resize=True, | ||
| ) | ||
| ) | ||
|
|
||
| # Retract: 15 -> 10 (drop "easy" questions) | ||
| config_builder.add_column( | ||
| dd.CustomColumnConfig( | ||
| name="filtered", | ||
| generator_function=filter_non_easy, | ||
| generation_strategy=dd.GenerationStrategy.FULL_COLUMN, | ||
| allow_resize=True, | ||
| ) | ||
| ) | ||
|
|
||
| # Expand again: 10 -> 30 answer variants | ||
| config_builder.add_column( | ||
| dd.CustomColumnConfig( | ||
| name="answer", | ||
| generator_function=expand_to_answers, | ||
| generation_strategy=dd.GenerationStrategy.FULL_COLUMN, | ||
| allow_resize=True, | ||
| ) | ||
| ) | ||
|
|
||
| # Preview (single batch) | ||
| preview = data_designer.preview(config_builder=config_builder, num_records=5) | ||
| print(f"Preview: 5 topics -> {len(preview.dataset)} answer variants") | ||
| print(preview.dataset[["topic", "difficulty", "question", "variant", "answer"]].to_string()) | ||
| print() | ||
|
|
||
| # Build (multiple batches: 10 records with buffer_size=3 -> 4 batches) | ||
| data_designer.set_run_config(dd.RunConfig(buffer_size=3)) | ||
| results = data_designer.create(config_builder=config_builder, num_records=10) | ||
| df = results.load_dataset() | ||
| print(f"Build: 10 topics (4 batches of 3+3+3+1) -> {len(df)} answer variants") | ||
| print(df[["topic", "difficulty", "question", "variant"]].to_string()) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -517,19 +517,24 @@ def test_sampler_column_config_discriminated_union_wrong_params_type(): | |
| ) | ||
|
|
||
|
|
||
| def test_default_column_emoji_for_custom_column_type() -> None: | ||
| """Ensure the base get_column_emoji implementation is used when not overridden.""" | ||
| class StubColumnConfig(SingleColumnConfig): | ||
| column_type: Literal["stub"] = "stub" | ||
|
|
||
| @property | ||
| def required_columns(self) -> list[str]: | ||
| return [] | ||
|
|
||
| class StubColumnConfigWithoutEmoji(SingleColumnConfig): | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Moving this since the test for |
||
| column_type: Literal["stub-without-emoji"] = "stub-without-emoji" | ||
| value: str | ||
| @property | ||
| def side_effect_columns(self) -> list[str]: | ||
| return [] | ||
|
|
||
| @property | ||
| def required_columns(self) -> list[str]: | ||
| return [] | ||
|
|
||
| @property | ||
| def side_effect_columns(self) -> list[str]: | ||
| return [] | ||
| def test_default_column_emoji_for_custom_column_type() -> None: | ||
| """Ensure the base get_column_emoji implementation is used when not overridden.""" | ||
| assert StubColumnConfig.get_column_emoji() == "🎨" | ||
|
|
||
|
|
||
| assert StubColumnConfigWithoutEmoji.get_column_emoji() == "🎨" | ||
| def test_allow_resize_inherited_by_subclasses() -> None: | ||
| """Subclasses inherit allow_resize from SingleColumnConfig.""" | ||
| assert StubColumnConfig(name="test").allow_resize is False | ||
| assert StubColumnConfig(name="test", allow_resize=True).allow_resize is True | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Invalid parameter name in doc example
The
@custom_column_generatordecorator validates parameter names at decoration time and requires the second parameter to be namedgenerator_params, notparams. This example will raiseTypeError: param 2 must be 'generator_params'if a user copies and runs it (confirmed by the test attest_custom.py:427).Prompt To Fix With AI