Skip to content

Commit e081d99

Browse files
RobertCraigiestainless-app[bot]
authored andcommitted
Revert "chore(internal): streaming refactors (#2012)"
This reverts commit d76a748.
1 parent 82ccc98 commit e081d99

File tree

1 file changed

+72
-32
lines changed

1 file changed

+72
-32
lines changed

src/openai/_streaming.py

+72-32
Original file line numberDiff line numberDiff line change
@@ -59,22 +59,42 @@ def __stream__(self) -> Iterator[_T]:
5959
if sse.data.startswith("[DONE]"):
6060
break
6161

62-
data = sse.json()
63-
if is_mapping(data) and data.get("error"):
64-
message = None
65-
error = data.get("error")
66-
if is_mapping(error):
67-
message = error.get("message")
68-
if not message or not isinstance(message, str):
69-
message = "An error occurred during streaming"
70-
71-
raise APIError(
72-
message=message,
73-
request=self.response.request,
74-
body=data["error"],
75-
)
76-
77-
yield process_data(data=data, cast_to=cast_to, response=response)
62+
if sse.event is None:
63+
data = sse.json()
64+
if is_mapping(data) and data.get("error"):
65+
message = None
66+
error = data.get("error")
67+
if is_mapping(error):
68+
message = error.get("message")
69+
if not message or not isinstance(message, str):
70+
message = "An error occurred during streaming"
71+
72+
raise APIError(
73+
message=message,
74+
request=self.response.request,
75+
body=data["error"],
76+
)
77+
78+
yield process_data(data=data, cast_to=cast_to, response=response)
79+
80+
else:
81+
data = sse.json()
82+
83+
if sse.event == "error" and is_mapping(data) and data.get("error"):
84+
message = None
85+
error = data.get("error")
86+
if is_mapping(error):
87+
message = error.get("message")
88+
if not message or not isinstance(message, str):
89+
message = "An error occurred during streaming"
90+
91+
raise APIError(
92+
message=message,
93+
request=self.response.request,
94+
body=data["error"],
95+
)
96+
97+
yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
7898

7999
# Ensure the entire stream is consumed
80100
for _sse in iterator:
@@ -141,22 +161,42 @@ async def __stream__(self) -> AsyncIterator[_T]:
141161
if sse.data.startswith("[DONE]"):
142162
break
143163

144-
data = sse.json()
145-
if is_mapping(data) and data.get("error"):
146-
message = None
147-
error = data.get("error")
148-
if is_mapping(error):
149-
message = error.get("message")
150-
if not message or not isinstance(message, str):
151-
message = "An error occurred during streaming"
152-
153-
raise APIError(
154-
message=message,
155-
request=self.response.request,
156-
body=data["error"],
157-
)
158-
159-
yield process_data(data=data, cast_to=cast_to, response=response)
164+
if sse.event is None:
165+
data = sse.json()
166+
if is_mapping(data) and data.get("error"):
167+
message = None
168+
error = data.get("error")
169+
if is_mapping(error):
170+
message = error.get("message")
171+
if not message or not isinstance(message, str):
172+
message = "An error occurred during streaming"
173+
174+
raise APIError(
175+
message=message,
176+
request=self.response.request,
177+
body=data["error"],
178+
)
179+
180+
yield process_data(data=data, cast_to=cast_to, response=response)
181+
182+
else:
183+
data = sse.json()
184+
185+
if sse.event == "error" and is_mapping(data) and data.get("error"):
186+
message = None
187+
error = data.get("error")
188+
if is_mapping(error):
189+
message = error.get("message")
190+
if not message or not isinstance(message, str):
191+
message = "An error occurred during streaming"
192+
193+
raise APIError(
194+
message=message,
195+
request=self.response.request,
196+
body=data["error"],
197+
)
198+
199+
yield process_data(data={"data": data, "event": sse.event}, cast_to=cast_to, response=response)
160200

161201
# Ensure the entire stream is consumed
162202
async for _sse in iterator:

0 commit comments

Comments
 (0)