Skip to content

Commit ebd012a

Browse files
Update generated stream wrappers
1 parent 3e53341 commit ebd012a

File tree

3 files changed

+42
-35
lines changed

3 files changed

+42
-35
lines changed

codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java

+37-5
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,13 @@ def _classify_error(
200200
if (hasStreaming) {
201201
writer.addStdlibImports("typing", Set.of("Any", "Awaitable"));
202202
writer.addStdlibImport("asyncio");
203+
204+
writer.addImports("smithy_core.aio.eventstream",
205+
Set.of(
206+
"InputEventStream",
207+
"OutputEventStream",
208+
"DuplexEventStream"));
209+
writer.addImport("smithy_core.aio.interfaces.eventstream", "EventReceiver");
203210
writer.write(
204211
"""
205212
async def _input_stream[Input: SerializeableShape, Output: DeserializeableShape](
@@ -218,6 +225,10 @@ def _classify_error(
218225
))
219226
request_context = await request_future
220227
${5C|}
228+
return InputEventStream[Any, Any](
229+
input_stream=publisher,
230+
output_future=awaitable_output,
231+
)
221232
222233
async def _output_stream[Input: SerializeableShape, Output: DeserializeableShape](
223234
self,
@@ -236,6 +247,10 @@ def _classify_error(
236247
)
237248
transport_response = await response_future
238249
${6C|}
250+
return OutputEventStream[Any, Any](
251+
output_stream=receiver,
252+
output=output
253+
)
239254
240255
async def _duplex_stream[Input: SerializeableShape, Output: DeserializeableShape](
241256
self,
@@ -255,15 +270,34 @@ def _classify_error(
255270
response_future=response_future
256271
))
257272
request_context = await request_future
258-
${7C|}
273+
${5C|}
274+
output_future = asyncio.create_task(self._wrap_duplex_output(
275+
response_future, awaitable_output, config, operation_name,
276+
event_deserializer
277+
))
278+
return DuplexEventStream[Any, Any, Any](
279+
input_stream=publisher,
280+
output_future=output_future,
281+
)
282+
283+
async def _wrap_duplex_output(
284+
self,
285+
response_future: Future[$3T],
286+
awaitable_output: Future[Any],
287+
config: $4T,
288+
operation_name: str,
289+
event_deserializer: Callable[[ShapeDeserializer], Any],
290+
) -> tuple[Any, EventReceiver[Any]]:
291+
transport_response = await response_future
292+
${6C|}
293+
return await awaitable_output, receiver
259294
""",
260295
pluginSymbol,
261296
transportRequest,
262297
transportResponse,
263298
configSymbol,
264299
writer.consumer(w -> context.protocolGenerator().wrapInputStream(context, w)),
265-
writer.consumer(w -> context.protocolGenerator().wrapOutputStream(context, w)),
266-
writer.consumer(w -> context.protocolGenerator().wrapDuplexStream(context, w)));
300+
writer.consumer(w -> context.protocolGenerator().wrapOutputStream(context, w)));
267301
}
268302
writer.addStdlibImport("typing", "Any");
269303
writer.addStdlibImport("asyncio", "iscoroutine");
@@ -872,7 +906,6 @@ private void generateEventStreamOperation(PythonWriter writer, OperationShape op
872906

873907
if (inputStreamSymbol != null) {
874908
if (outputStreamSymbol != null) {
875-
writer.addImport("smithy_event_stream.aio.interfaces", "DuplexEventStream");
876909
writer.write("""
877910
async def ${operationName:L}(
878911
self,
@@ -922,7 +955,6 @@ raise NotImplementedError()
922955
""", writer.consumer(w -> writeSharedOperationInit(w, operation, input)));
923956
}
924957
} else {
925-
writer.addImport("smithy_event_stream.aio.interfaces", "OutputEventStream");
926958
writer.write("""
927959
async def ${operationName:L}(
928960
self,

codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/ProtocolGenerator.java

-2
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,4 @@ default void generateProtocolTests(GenerationContext context) {}
157157
default void wrapInputStream(GenerationContext context, PythonWriter writer) {}
158158

159159
default void wrapOutputStream(GenerationContext context, PythonWriter writer) {}
160-
161-
default void wrapDuplexStream(GenerationContext context, PythonWriter writer) {}
162160
}

codegen/core/src/main/java/software/amazon/smithy/python/codegen/integrations/RestJsonProtocolGenerator.java

+5-28
Original file line numberDiff line numberDiff line change
@@ -396,13 +396,12 @@ public void wrapInputStream(GenerationContext context, PythonWriter writer) {
396396
writer.addImport("smithy_json", "JSONCodec");
397397
writer.addImport("smithy_core.aio.types", "AsyncBytesReader");
398398
writer.addImport("smithy_core.types", "TimestampFormat");
399-
writer.addImport("aws_event_stream.aio", "AWSInputEventStream");
399+
writer.addImport("aws_event_stream.aio", "AWSEventPublisher");
400400
writer.write(
401401
"""
402402
codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS)
403-
return AWSInputEventStream[Any, Any](
403+
publisher = AWSEventPublisher[Any](
404404
payload_codec=codec,
405-
awaitable_output=awaitable_output,
406405
async_writer=request_context.transport_request.body, # type: ignore
407406
)
408407
""");
@@ -415,39 +414,17 @@ public void wrapOutputStream(GenerationContext context, PythonWriter writer) {
415414
writer.addImport("smithy_json", "JSONCodec");
416415
writer.addImport("smithy_core.aio.types", "AsyncBytesReader");
417416
writer.addImport("smithy_core.types", "TimestampFormat");
418-
writer.addImport("aws_event_stream.aio", "AWSOutputEventStream");
417+
writer.addImport("aws_event_stream.aio", "AWSEventReceiver");
419418
writer.write(
420419
"""
421420
codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS)
422-
return AWSOutputEventStream[Any, Any](
421+
receiver = AWSEventReceiver(
423422
payload_codec=codec,
424-
initial_response=output,
425-
async_reader=AsyncBytesReader(
423+
source=AsyncBytesReader(
426424
transport_response.body # type: ignore
427425
),
428426
deserializer=event_deserializer, # type: ignore
429427
)
430428
""");
431429
}
432-
433-
@Override
434-
public void wrapDuplexStream(GenerationContext context, PythonWriter writer) {
435-
writer.addDependency(SmithyPythonDependency.SMITHY_JSON);
436-
writer.addDependency(SmithyPythonDependency.AWS_EVENT_STREAM);
437-
writer.addImport("smithy_json", "JSONCodec");
438-
writer.addImport("smithy_core.aio.types", "AsyncBytesReader");
439-
writer.addImport("smithy_core.types", "TimestampFormat");
440-
writer.addImport("aws_event_stream.aio", "AWSDuplexEventStream");
441-
writer.write(
442-
"""
443-
codec = JSONCodec(default_timestamp_format=TimestampFormat.EPOCH_SECONDS)
444-
return AWSDuplexEventStream[Any, Any, Any](
445-
payload_codec=codec,
446-
async_writer=request_context.transport_request.body, # type: ignore
447-
awaitable_output=awaitable_output,
448-
awaitable_response=response_future,
449-
deserializer=event_deserializer, # type: ignore
450-
)
451-
""");
452-
}
453430
}

0 commit comments

Comments
 (0)