Skip to content
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions src/agents/voice/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,25 +243,50 @@ async def _wait_for_completion(self):
tasks.append(self._dispatcher_task)
await asyncio.gather(*tasks)

def _cleanup_tasks(self):
async def _cleanup_tasks(self):
"""Cancel all pending tasks and wait for them to complete.

This ensures that any exceptions raised by the tasks are properly handled
and prevents warnings about unhandled task exceptions.
"""
self._finish_turn()

tasks = []
for task in self._tasks:
if not task.done():
task.cancel()
if isinstance(task, asyncio.Task):
tasks.append(task)

if self._dispatcher_task and not self._dispatcher_task.done():
self._dispatcher_task.cancel()
if isinstance(self._dispatcher_task, asyncio.Task):
tasks.append(self._dispatcher_task)

if self.text_generation_task and not self.text_generation_task.done():
self.text_generation_task.cancel()
if isinstance(self.text_generation_task, asyncio.Task):
tasks.append(self.text_generation_task)

# Wait for all cancelled tasks to complete and collect exceptions
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)

def _check_errors(self):
"""Check for exceptions in completed tasks.

Note: CancelledError is not checked as it's expected during cleanup.
"""
for task in self._tasks:
if task.done():
if task.exception():
self._stored_exception = task.exception()
break
if task.done() and not task.cancelled():
try:
exc = task.exception()
if exc:
self._stored_exception = exc
break
except asyncio.CancelledError:
# Task was cancelled, skip it
pass

async def stream(self) -> AsyncIterator[VoiceStreamEvent]:
"""Stream the events and audio data as they're generated."""
Expand All @@ -281,7 +306,7 @@ async def stream(self) -> AsyncIterator[VoiceStreamEvent]:
break

self._check_errors()
self._cleanup_tasks()
await self._cleanup_tasks()

if self._stored_exception:
raise self._stored_exception