Skip to content

Commit 73cf640

Browse files
authored
Fix signature field order mismatch (#7875)
* Fix signature field order mismatch * WIP update batch test
1 parent 8f17c16 commit 73cf640

File tree

3 files changed

+29
-20
lines changed

3 files changed

+29
-20
lines changed

dspy/signatures/signature.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,20 @@ def __call__(cls, *args, **kwargs): # noqa: ANN002
4545
return super().__call__(*args, **kwargs)
4646

4747
def __new__(mcs, signature_name, bases, namespace, **kwargs): # noqa: N804
48+
# At this point, the orders have been swapped already.
49+
field_order = [name for name, value in namespace.items() if isinstance(value, FieldInfo)]
4850
# Set `str` as the default type for all fields
4951
raw_annotations = namespace.get("__annotations__", {})
5052
for name, field in namespace.items():
5153
if not isinstance(field, FieldInfo):
5254
continue # Don't add types to non-field attributes
5355
if not name.startswith("__") and name not in raw_annotations:
5456
raw_annotations[name] = str
55-
namespace["__annotations__"] = raw_annotations
57+
# Create ordered annotations dictionary that preserves field order
58+
ordered_annotations = {name: raw_annotations[name] for name in field_order if name in raw_annotations}
59+
# Add any remaining annotations that weren't in field_order
60+
ordered_annotations.update({k: v for k, v in raw_annotations.items() if k not in ordered_annotations})
61+
namespace["__annotations__"] = ordered_annotations
5662

5763
# Let Pydantic do its thing
5864
cls = super().__new__(mcs, signature_name, bases, namespace, **kwargs)

tests/predict/test_parallel.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -71,26 +71,18 @@ def forward(self, input):
7171
res2 = self.predictor2.batch([input] * 5)
7272

7373
return (res1, res2)
74-
75-
result, reason_result = MyModule()(dspy.Example(input="test input").with_inputs("input"))
7674

77-
assert result[0].output == "test output 1"
78-
assert result[1].output == "test output 2"
79-
assert result[2].output == "test output 3"
80-
assert result[3].output == "test output 4"
81-
assert result[4].output == "test output 5"
75+
result, reason_result = MyModule()(dspy.Example(input="test input").with_inputs("input"))
8276

83-
assert reason_result[0].output == "test output 1"
84-
assert reason_result[1].output == "test output 2"
85-
assert reason_result[2].output == "test output 3"
86-
assert reason_result[3].output == "test output 4"
87-
assert reason_result[4].output == "test output 5"
77+
# Check that we got all expected outputs without caring about order
78+
expected_outputs = {f"test output {i}" for i in range(1, 6)}
79+
assert {r.output for r in result} == expected_outputs
80+
assert {r.output for r in reason_result} == expected_outputs
8881

89-
assert reason_result[0].reasoning == "test reasoning 1"
90-
assert reason_result[1].reasoning == "test reasoning 2"
91-
assert reason_result[2].reasoning == "test reasoning 3"
92-
assert reason_result[3].reasoning == "test reasoning 4"
93-
assert reason_result[4].reasoning == "test reasoning 5"
82+
# Check that reasoning matches outputs for reason_result
83+
for r in reason_result:
84+
num = r.output.split()[-1] # get the number from "test output X"
85+
assert r.reasoning == f"test reasoning {num}"
9486

9587

9688
def test_nested_parallel_module():
@@ -120,7 +112,7 @@ def forward(self, input):
120112
(self.predictor, input),
121113
]),
122114
])
123-
115+
124116
output = MyModule()(dspy.Example(input="test input").with_inputs("input"))
125117

126118
assert output[0].output == "test output 1"
@@ -148,7 +140,7 @@ def forward(self, input):
148140
res = self.predictor.batch([dspy.Example(input=input).with_inputs("input")]*2)
149141

150142
return res
151-
143+
152144
result = MyModule().batch([dspy.Example(input="test input").with_inputs("input")]*2)
153145

154146
assert {result[0][0].output, result[0][1].output, result[1][0].output, result[1][1].output} \

tests/signatures/test_signature.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,17 @@ class InitialSignature(Signature):
151151
assert "new_output_end" == list(S4.output_fields.keys())[-1]
152152

153153

154+
def test_order_preserved_with_mixed_annotations():
155+
class ExampleSignature(dspy.Signature):
156+
text: str = dspy.InputField()
157+
output = dspy.OutputField()
158+
pass_evaluation: bool = dspy.OutputField()
159+
160+
expected_order = ["text", "output", "pass_evaluation"]
161+
actual_order = list(ExampleSignature.fields.keys())
162+
assert actual_order == expected_order
163+
164+
154165
def test_infer_prefix():
155166
assert infer_prefix("someAttributeName42IsCool") == "Some Attribute Name 42 Is Cool"
156167
assert infer_prefix("version2Update") == "Version 2 Update"

0 commit comments

Comments
 (0)