Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 79 additions & 12 deletions dspy/teleprompt/simba_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@
logger = logging.getLogger(__name__)

def prepare_models_for_resampling(program: dspy.Module, n: int, teacher_settings: dict | None = None):
"""Prepares a list of language models for resampling by assigning unique rollout IDs.

Creates n models with sequential rollout IDs. If teacher_settings is provided, the first
model uses the teacher's language model configuration. Remaining models are copies of the
base model with temperature set to 1.0.

Returns:
A list of language models configured for resampling with unique rollout IDs.
"""
lm = program.get_lm() or dspy.settings.lm

start_rollout_id = lm.kwargs.get("rollout_id", 0)
Expand All @@ -32,7 +41,26 @@ def prepare_models_for_resampling(program: dspy.Module, n: int, teacher_settings
return models

def wrap_program(program: dspy.Module, metric: Callable):
"""Wraps a program to capture its execution trace and evaluate it with a metric.

Returns a function that executes the program on an example, captures the trace,
evaluates the prediction using the metric, and returns a dictionary containing
the prediction, trace, score, example, and any additional metadata from the metric.
The metric can return a numeric score or a dspy.Prediction with a score field.

Returns:
A function that takes an example and returns a dictionary with prediction results,
trace, score, and metadata.
"""
def wrapped_program(example):
"""Executes the program on an example and captures its trace.

Runs the program with the given example, captures the execution trace, evaluates
the result using the metric, and packages everything into a result dictionary.

Returns:
A dictionary containing prediction, trace, score, example, and output_metadata.
"""
with dspy.context(trace=[]):
prediction, trace, score = None, None, 0.0
try:
Expand Down Expand Up @@ -71,7 +99,25 @@ def wrapped_program(example):
return wrapped_program

def append_a_demo(demo_input_field_maxlen):
"""Returns a function that appends demonstrations from a successful trajectory to predictors.

The returned function extracts demonstrations from the best trajectory in a bucket and
appends them to the corresponding predictors. Input fields longer than demo_input_field_maxlen
are truncated. Skips appending if the best score is at or below the 10th percentile.

Returns:
A function that processes a bucket and appends demonstrations to predictors.
"""
def append_a_demo_(bucket, system, **kwargs):
"""Extracts and appends demonstrations from the best trajectory to predictors.

Processes the highest-scoring trajectory in the bucket, creates demonstrations from
each step, and appends them to the corresponding predictors. Truncates long input
fields and skips if the score is too low.

Returns:
True if demonstrations were appended, False if skipped due to low score.
"""
predictor2name, name2predictor = kwargs["predictor2name"], kwargs["name2predictor"]
batch_10p_score = kwargs["batch_10p_score"]

Expand Down Expand Up @@ -104,6 +150,16 @@ def append_a_demo_(bucket, system, **kwargs):


def append_a_rule(bucket, system, **kwargs):
"""Generates and appends advice to predictor instructions by comparing good and bad trajectories.

Uses a language model to analyze the difference between a high-scoring and low-scoring
trajectory, generating module-specific advice. The advice is appended to each predictor's
instructions. Skips rule generation if the good score is too low or the bad score is too high
relative to batch percentiles.

Returns:
True if advice was generated and appended, False if skipped due to score thresholds.
"""
predictor2name = kwargs["predictor2name"]
batch_10p_score, batch_90p_score = kwargs["batch_10p_score"], kwargs["batch_90p_score"]
prompt_model = kwargs["prompt_model"] or dspy.settings.lm
Expand Down Expand Up @@ -168,18 +224,11 @@ def append_a_rule(bucket, system, **kwargs):
return True

class OfferFeedback(dspy.Signature):
"""
You will be given two trajectories of an LLM-driven program's execution. Your goal is to help the program's modules
build up experience on how to maximize the reward value assigned to the program's outputs if it were to receive
similar inputs in the future.

The module won't see its own history. It will rely on your advice balancing being concrete and being generalizable.

In your advice:
- Avoid boilerplate. Offer advice that would change the module's behavior for the better in the future.
- Ensure that advice offered to a module M is specific to that M's specific sub-task, not the overall program.
- Rely on contrasting the behavior of the worse trajectory against the better trajectory in making recommendations.
- Ensure each unique module name appears exactly once as a key in the advice dictionary.
"""Signature for generating module-specific advice by comparing successful and unsuccessful trajectories.

Analyzes two program execution trajectories with different reward values to generate
concrete, actionable advice for each module. The advice helps modules improve their
behavior by learning from the contrast between better and worse trajectories.
"""

program_code: str = InputField(desc="The code of the program that we are analyzing")
Expand Down Expand Up @@ -208,6 +257,15 @@ class OfferFeedback(dspy.Signature):
)

def inspect_modules(program):
"""Formats module information into a human-readable string representation.

Extracts and formats each predictor's name, input fields, output fields, and instructions
into a structured text format with separators. The output is suitable for inclusion in
prompts or logs.

Returns:
A formatted string containing module definitions with their fields and instructions.
"""
separator = "-" * 80
output = [separator]

Expand All @@ -228,6 +286,15 @@ def inspect_modules(program):


def recursive_mask(o):
"""Recursively masks non-serializable objects with placeholder strings.

Traverses the object structure and replaces any non-JSON-serializable values with
a placeholder string indicating the type. Handles dictionaries, lists, and tuples
recursively while preserving already-serializable values.

Returns:
The object with non-serializable values replaced by placeholder strings.
"""
# If the object is already serializable, return it.
try:
orjson.dumps(o)
Expand Down