@@ -200,6 +200,13 @@ def _classify_error(
200
200
if (hasStreaming ) {
201
201
writer .addStdlibImports ("typing" , Set .of ("Any" , "Awaitable" ));
202
202
writer .addStdlibImport ("asyncio" );
203
+
204
+ writer .addImports ("smithy_core.aio.eventstream" ,
205
+ Set .of (
206
+ "InputEventStream" ,
207
+ "OutputEventStream" ,
208
+ "DuplexEventStream" ));
209
+ writer .addImport ("smithy_core.aio.interfaces.eventstream" , "EventReceiver" );
203
210
writer .write (
204
211
"""
205
212
async def _input_stream[Input: SerializeableShape, Output: DeserializeableShape](
@@ -218,6 +225,10 @@ def _classify_error(
218
225
))
219
226
request_context = await request_future
220
227
${5C|}
228
+ return InputEventStream[Any, Any](
229
+ input_stream=publisher,
230
+ output_future=awaitable_output,
231
+ )
221
232
222
233
async def _output_stream[Input: SerializeableShape, Output: DeserializeableShape](
223
234
self,
@@ -236,6 +247,10 @@ def _classify_error(
236
247
)
237
248
transport_response = await response_future
238
249
${6C|}
250
+ return OutputEventStream[Any, Any](
251
+ output_stream=receiver,
252
+ output=output
253
+ )
239
254
240
255
async def _duplex_stream[Input: SerializeableShape, Output: DeserializeableShape](
241
256
self,
@@ -255,15 +270,34 @@ def _classify_error(
255
270
response_future=response_future
256
271
))
257
272
request_context = await request_future
258
- ${7C|}
273
+ ${5C|}
274
+ output_future = asyncio.create_task(self._wrap_duplex_output(
275
+ response_future, awaitable_output, config, operation_name,
276
+ event_deserializer
277
+ ))
278
+ return DuplexEventStream[Any, Any, Any](
279
+ input_stream=publisher,
280
+ output_future=output_future,
281
+ )
282
+
283
+ async def _wrap_duplex_output(
284
+ self,
285
+ response_future: Future[$3T],
286
+ awaitable_output: Future[Any],
287
+ config: $4T,
288
+ operation_name: str,
289
+ event_deserializer: Callable[[ShapeDeserializer], Any],
290
+ ) -> tuple[Any, EventReceiver[Any]]:
291
+ transport_response = await response_future
292
+ ${6C|}
293
+ return await awaitable_output, receiver
259
294
""" ,
260
295
pluginSymbol ,
261
296
transportRequest ,
262
297
transportResponse ,
263
298
configSymbol ,
264
299
writer .consumer (w -> context .protocolGenerator ().wrapInputStream (context , w )),
265
- writer .consumer (w -> context .protocolGenerator ().wrapOutputStream (context , w )),
266
- writer .consumer (w -> context .protocolGenerator ().wrapDuplexStream (context , w )));
300
+ writer .consumer (w -> context .protocolGenerator ().wrapOutputStream (context , w )));
267
301
}
268
302
writer .addStdlibImport ("typing" , "Any" );
269
303
writer .addStdlibImport ("asyncio" , "iscoroutine" );
@@ -872,7 +906,6 @@ private void generateEventStreamOperation(PythonWriter writer, OperationShape op
872
906
873
907
if (inputStreamSymbol != null ) {
874
908
if (outputStreamSymbol != null ) {
875
- writer .addImport ("smithy_event_stream.aio.interfaces" , "DuplexEventStream" );
876
909
writer .write ("""
877
910
async def ${operationName:L}(
878
911
self,
@@ -922,7 +955,6 @@ raise NotImplementedError()
922
955
""" , writer .consumer (w -> writeSharedOperationInit (w , operation , input )));
923
956
}
924
957
} else {
925
- writer .addImport ("smithy_event_stream.aio.interfaces" , "OutputEventStream" );
926
958
writer .write ("""
927
959
async def ${operationName:L}(
928
960
self,
0 commit comments