@@ -71,26 +71,18 @@ def forward(self, input):
71
71
res2 = self .predictor2 .batch ([input ] * 5 )
72
72
73
73
return (res1 , res2 )
74
-
75
- result , reason_result = MyModule ()(dspy .Example (input = "test input" ).with_inputs ("input" ))
76
74
77
- assert result [0 ].output == "test output 1"
78
- assert result [1 ].output == "test output 2"
79
- assert result [2 ].output == "test output 3"
80
- assert result [3 ].output == "test output 4"
81
- assert result [4 ].output == "test output 5"
75
+ result , reason_result = MyModule ()(dspy .Example (input = "test input" ).with_inputs ("input" ))
82
76
83
- assert reason_result [0 ].output == "test output 1"
84
- assert reason_result [1 ].output == "test output 2"
85
- assert reason_result [2 ].output == "test output 3"
86
- assert reason_result [3 ].output == "test output 4"
87
- assert reason_result [4 ].output == "test output 5"
77
+ # Check that we got all expected outputs without caring about order
78
+ expected_outputs = {f"test output { i } " for i in range (1 , 6 )}
79
+ assert {r .output for r in result } == expected_outputs
80
+ assert {r .output for r in reason_result } == expected_outputs
88
81
89
- assert reason_result [0 ].reasoning == "test reasoning 1"
90
- assert reason_result [1 ].reasoning == "test reasoning 2"
91
- assert reason_result [2 ].reasoning == "test reasoning 3"
92
- assert reason_result [3 ].reasoning == "test reasoning 4"
93
- assert reason_result [4 ].reasoning == "test reasoning 5"
82
+ # Check that reasoning matches outputs for reason_result
83
+ for r in reason_result :
84
+ num = r .output .split ()[- 1 ] # get the number from "test output X"
85
+ assert r .reasoning == f"test reasoning { num } "
94
86
95
87
96
88
def test_nested_parallel_module ():
@@ -120,7 +112,7 @@ def forward(self, input):
120
112
(self .predictor , input ),
121
113
]),
122
114
])
123
-
115
+
124
116
output = MyModule ()(dspy .Example (input = "test input" ).with_inputs ("input" ))
125
117
126
118
assert output [0 ].output == "test output 1"
@@ -148,7 +140,7 @@ def forward(self, input):
148
140
res = self .predictor .batch ([dspy .Example (input = input ).with_inputs ("input" )]* 2 )
149
141
150
142
return res
151
-
143
+
152
144
result = MyModule ().batch ([dspy .Example (input = "test input" ).with_inputs ("input" )]* 2 )
153
145
154
146
assert {result [0 ][0 ].output , result [0 ][1 ].output , result [1 ][0 ].output , result [1 ][1 ].output } \
0 commit comments