Skip to content

fix(aws): improve credential handling and client lifecycle (#1719) - Version 0.x #1837

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: 0.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 44 additions & 18 deletions livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,66 @@
import base64
import inspect
import json
import logging
import os
from typing import Any, Dict, List, Optional, Tuple, get_args, get_origin

import boto3
import aioboto3
from botocore.exceptions import NoCredentialsError
from livekit import rtc
from livekit.agents import llm, utils
from livekit.agents.llm.function_context import _is_optional_type

__all__ = ["_build_aws_ctx", "_build_tools", "_get_aws_credentials"]
__all__ = ["_build_aws_ctx", "_build_tools", "_get_aws_async_session"]


def _get_aws_credentials(
api_key: Optional[str], api_secret: Optional[str], region: Optional[str]
):
region = region or os.environ.get("AWS_DEFAULT_REGION")
def _get_aws_async_session(
api_key: str | None = None,
api_secret: str | None = None,
region: str | None = None,
) -> aioboto3.Session:
"""Get an AWS session with the given credentials and region.

Args:
api_key: AWS access key id.
api_secret: AWS secret access key.
region: AWS region.

Returns:
An AWS session.

Raises:
NoCredentialsError: If no valid credentials are found.
"""
# Validate AWS region first
region = (
region or os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION")
)

if not region:
raise ValueError(
"AWS_DEFAULT_REGION must be set using the argument or by setting the AWS_DEFAULT_REGION environment variable."
"AWS region must be set using the argument or by setting the AWS_REGION environment variable."
)

# If API key and secret are provided, create a session with them
session_params = {"region_name": region}
if api_key and api_secret:
session = boto3.Session(
aws_access_key_id=api_key,
aws_secret_access_key=api_secret,
region_name=region,
session_params.update(
{
"aws_access_key_id": api_key,
"aws_secret_access_key": api_secret,
}
)
else:
session = boto3.Session(region_name=region)

credentials = session.get_credentials()
if not credentials or not credentials.access_key or not credentials.secret_key:
raise ValueError("No valid AWS credentials found.")
return credentials.access_key, credentials.secret_key
session = aioboto3.Session(**session_params)

# Validate session by checking if we can get credentials
try:
session.get_credentials()
except NoCredentialsError as e:
logging.error("Unable to locate AWS credentials")
raise e

return session


JSON_SCHEMA_TYPE_MAP: Dict[type, str] = {
Expand Down
75 changes: 30 additions & 45 deletions livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,12 @@
from dataclasses import dataclass
from typing import Any, Literal, MutableSet, Union

import boto3
from livekit.agents import (
APIConnectionError,
APIStatusError,
llm,
)
import aioboto3
from livekit.agents import APIConnectionError, APIStatusError, llm
from livekit.agents.llm import LLMCapabilities, ToolChoice, _create_ai_function_info
from livekit.agents.types import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions

from ._utils import _build_aws_ctx, _build_tools, _get_aws_credentials
from ._utils import _build_aws_ctx, _build_tools, _get_aws_async_session
from .log import logger

TEXT_MODEL = Literal["anthropic.claude-3-5-sonnet-20241022-v2:0"]
Expand Down Expand Up @@ -58,6 +54,7 @@ def __init__(
top_p: float | None = None,
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = "auto",
additional_request_fields: dict[str, Any] | None = None,
session: aioboto3.Session | None = None,
) -> None:
"""
Create a new instance of AWS Bedrock LLM.
Expand All @@ -77,16 +74,14 @@ def __init__(
top_p (float, optional): The nucleus sampling probability for response generation. Defaults to None.
tool_choice (ToolChoice or Literal["auto", "required", "none"], optional): Specifies whether to use tools during response generation. Defaults to "auto".
additional_request_fields (dict[str, Any], optional): Additional request fields to send to the AWS Bedrock Converse API. Defaults to None.
session (aioboto3.Session, optional): Optional aioboto3 session to use.
"""
super().__init__(
capabilities=LLMCapabilities(
supports_choices_on_int=True,
requires_persistent_functions=True,
)
)
self._api_key, self._api_secret = _get_aws_credentials(
api_key, api_secret, region
)

self._model = model or os.environ.get("BEDROCK_INFERENCE_PROFILE_ARN")
if not self._model:
Expand All @@ -103,6 +98,9 @@ def __init__(
)
self._region = region
self._running_fncs: MutableSet[asyncio.Task[Any]] = set()
self._session = session or _get_aws_async_session(
api_key=api_key, api_secret=api_secret, region=region
)

def chat(
self,
Expand All @@ -125,9 +123,6 @@ def chat(
return LLMStream(
self,
model=self._opts.model,
aws_access_key_id=self._api_key,
aws_secret_access_key=self._api_secret,
region_name=self._region,
max_output_tokens=self._opts.max_output_tokens,
top_p=self._opts.top_p,
additional_request_fields=self._opts.additional_request_fields,
Expand All @@ -145,9 +140,6 @@ def __init__(
llm: LLM,
*,
model: str | TEXT_MODEL,
aws_access_key_id: str | None,
aws_secret_access_key: str | None,
region_name: str,
chat_ctx: llm.ChatContext,
conn_options: APIConnectOptions,
fnc_ctx: llm.FunctionContext | None,
Expand All @@ -160,12 +152,7 @@ def __init__(
super().__init__(
llm, chat_ctx=chat_ctx, fnc_ctx=fnc_ctx, conn_options=conn_options
)
self._client = boto3.client(
"bedrock-runtime",
region_name=region_name,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
self._client = llm._session.client("bedrock-runtime")
self._model = model
self._llm: LLM = llm
self._max_output_tokens = max_output_tokens
Expand Down Expand Up @@ -222,29 +209,27 @@ def _get_tool_config() -> dict[str, Any] | None:
"topP": self._top_p,
}
)
response = self._client.converse_stream(
modelId=self._model,
messages=messages,
inferenceConfig=inference_config,
**_strip_nones(opts),
) # type: ignore

request_id = response["ResponseMetadata"]["RequestId"]
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise APIStatusError(
f"aws bedrock llm: error generating content: {response}",
retryable=False,
request_id=request_id,
)

for chunk in response["stream"]:
chat_chunk = self._parse_chunk(request_id, chunk)
if chat_chunk is not None:
retryable = False
self._event_ch.send_nowait(chat_chunk)

# Let other coroutines run
await asyncio.sleep(0)
async with self._client as client:
response = await client.converse_stream(
modelId=self._model,
messages=messages,
inferenceConfig=inference_config,
**_strip_nones(opts),
) # type: ignore

request_id = response["ResponseMetadata"]["RequestId"]
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise APIStatusError(
f"aws bedrock llm: error generating content: {response}",
retryable=False,
request_id=request_id,
)

async for chunk in response["stream"]:
chat_chunk = self._parse_chunk(request_id, chunk)
if chat_chunk is not None:
retryable = False
self._event_ch.send_nowait(chat_chunk)

except Exception as e:
raise APIConnectionError(
Expand Down
48 changes: 36 additions & 12 deletions livekit-plugins/livekit-plugins-aws/livekit/plugins/aws/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,13 @@
from dataclasses import dataclass
from typing import Optional

from amazon_transcribe.auth import StaticCredentialResolver
from amazon_transcribe.client import TranscribeStreamingClient
from amazon_transcribe.model import Result, TranscriptEvent
from livekit import rtc
from livekit.agents import (
DEFAULT_API_CONNECT_OPTIONS,
APIConnectOptions,
stt,
utils,
)

from ._utils import _get_aws_credentials
from livekit.agents import DEFAULT_API_CONNECT_OPTIONS, APIConnectOptions, stt, utils

from ._utils import _get_aws_async_session
from .log import logger


Expand Down Expand Up @@ -73,9 +69,12 @@ def __init__(
capabilities=stt.STTCapabilities(streaming=True, interim_results=True)
)

self._api_key, self._api_secret = _get_aws_credentials(
api_key, api_secret, speech_region
self._session = _get_aws_async_session(
api_key=api_key,
api_secret=api_secret,
region=speech_region,
)

self._config = STTOptions(
speech_region=speech_region,
language=language,
Expand Down Expand Up @@ -116,6 +115,19 @@ def stream(
opts=self._config,
)

async def _get_client(self) -> TranscribeStreamingClient:
"""Get a new TranscribeStreamingClient instance."""
credentials = await self._session.get_credentials()
frozen_credentials = await credentials.get_frozen_credentials()
self.cred_resolver = StaticCredentialResolver(
access_key_id=frozen_credentials.access_key,
secret_access_key=frozen_credentials.secret_key,
session_token=frozen_credentials.token,
)
return TranscribeStreamingClient(
region=self._config.speech_region, credential_resolver=self.cred_resolver
)


class SpeechStream(stt.SpeechStream):
def __init__(
Expand All @@ -128,10 +140,22 @@ def __init__(
stt=stt, conn_options=conn_options, sample_rate=opts.sample_rate
)
self._opts = opts
self._client = TranscribeStreamingClient(region=self._opts.speech_region)
self._stt = stt
self._client = None
self._last_credential_time = 0

async def _initialize_client(self):
# Check if we need to refresh credentials (every 10 minutes or if client is None)
current_time = asyncio.get_event_loop().time()
if self._client is None or (current_time - self._last_credential_time > 600):
self._client = await self._stt._get_client()
self._last_credential_time = current_time
return self._client

async def _run(self) -> None:
stream = await self._client.start_stream_transcription(
client = await self._initialize_client()

stream = await client.start_stream_transcription(
language_code=self._opts.language,
media_sample_rate_hz=self._opts.sample_rate,
media_encoding=self._opts.encoding,
Expand Down
Loading