Skip to content

Commit f4e007f

Browse files
authored
feat(ibis): Athena default credential chain authentication support (#1362)
1 parent c53150e commit f4e007f

File tree

5 files changed

+179
-18
lines changed

5 files changed

+179
-18
lines changed

ibis-server/app/model/__init__.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -118,19 +118,55 @@ class AthenaConnectionInfo(BaseConnectionInfo):
118118
description="S3 staging directory for Athena queries",
119119
examples=["s3://my-bucket/athena-staging/"],
120120
)
121-
aws_access_key_id: SecretStr = Field(
122-
description="AWS access key ID", examples=["AKIA..."]
121+
122+
# ── Standard AWS credential chain (optional) ─────────────
123+
aws_access_key_id: SecretStr | None = Field(
124+
description="AWS access key ID. Optional if using IAM role, web identity token, or default credential chain.",
125+
examples=["AKIA..."],
126+
default=None,
127+
)
128+
aws_secret_access_key: SecretStr | None = Field(
129+
description="AWS secret access key. Optional if using IAM role, web identity token, or default credential chain.",
130+
examples=["my-secret-key"],
131+
default=None,
132+
)
133+
aws_session_token: SecretStr | None = Field(
134+
description="AWS session token (used for temporary credentials)",
135+
examples=["IQoJb3JpZ2luX2VjEJz//////////wEaCXVzLWVhc3QtMSJHMEUCIQD..."],
136+
default=None,
137+
)
138+
139+
# ── Web identity federation (OIDC/JWT-based) ─────────────
140+
web_identity_token: SecretStr | None = Field(
141+
description=(
142+
"OIDC web identity token (JWT) used for AssumeRoleWithWebIdentity authentication. "
143+
"If provided, PyAthena will call STS to exchange it for temporary credentials."
144+
),
145+
examples=["eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9..."],
146+
default=None,
123147
)
124-
aws_secret_access_key: SecretStr = Field(
125-
description="AWS secret access key", examples=["my-secret-key"]
148+
role_arn: SecretStr | None = Field(
149+
description="The ARN of the role to assume with the web identity token.",
150+
examples=["arn:aws:iam::123456789012:role/YourAthenaRole"],
151+
default=None,
126152
)
153+
role_session_name: SecretStr | None = Field(
154+
description="The session name when assuming a role (optional).",
155+
examples=["PyAthena-session"],
156+
default=None,
157+
)
158+
159+
# ── Regional and database settings ───────────────────────
127160
region_name: SecretStr = Field(
128-
description="AWS region for Athena", examples=["us-west-2", "us-east-1"]
161+
description="AWS region for Athena. Optional; will use default region if not provided.",
162+
examples=["us-west-2", "us-east-1"],
163+
default=None,
129164
)
130-
schema_name: SecretStr = Field(
165+
schema_name: SecretStr | None = Field(
131166
alias="schema_name",
132-
description="The database name in Athena",
167+
description="The database name in Athena. Defaults to 'default'.",
133168
examples=["default"],
169+
default=SecretStr("default"),
134170
)
135171

136172

ibis-server/app/model/data_source.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from typing import Any
99
from urllib.parse import unquote_plus
1010

11+
import boto3
1112
import ibis
1213
from google.cloud import bigquery
1314
from google.oauth2 import service_account
@@ -254,13 +255,55 @@ def get_connection(self, info: ConnectionInfo) -> BaseBackend:
254255

255256
@staticmethod
256257
def get_athena_connection(info: AthenaConnectionInfo) -> BaseBackend:
257-
return ibis.athena.connect(
258-
s3_staging_dir=info.s3_staging_dir.get_secret_value(),
259-
aws_access_key_id=info.aws_access_key_id.get_secret_value(),
260-
aws_secret_access_key=info.aws_secret_access_key.get_secret_value(),
261-
region_name=info.region_name.get_secret_value(),
262-
schema_name=info.schema_name.get_secret_value(),
263-
)
258+
kwargs: dict[str, Any] = {
259+
"s3_staging_dir": info.s3_staging_dir.get_secret_value(),
260+
"schema_name": info.schema_name.get_secret_value(),
261+
}
262+
263+
# ── Region ────────────────────────────────────────────────
264+
if info.region_name:
265+
kwargs["region_name"] = info.region_name.get_secret_value()
266+
267+
# ── Web Identity Token flow (Google OIDC → AWS STS) ───
268+
if info.web_identity_token and info.role_arn:
269+
oidc_token = info.web_identity_token.get_secret_value()
270+
role_arn = info.role_arn.get_secret_value()
271+
session_name = (
272+
info.role_session_name.get_secret_value()
273+
if info.role_session_name
274+
else "wren-oidc-session"
275+
)
276+
region = info.region_name.get_secret_value() if info.region_name else None
277+
sts = boto3.client("sts", region_name=region)
278+
279+
resp = sts.assume_role_with_web_identity(
280+
RoleArn=role_arn,
281+
RoleSessionName=session_name,
282+
WebIdentityToken=oidc_token,
283+
)
284+
285+
creds = resp["Credentials"]
286+
kwargs["aws_access_key_id"] = creds["AccessKeyId"]
287+
kwargs["aws_secret_access_key"] = creds["SecretAccessKey"]
288+
kwargs["aws_session_token"] = creds["SessionToken"]
289+
290+
# ── Standard Access/Secret Keys ───────────────────────
291+
elif info.aws_access_key_id and info.aws_secret_access_key:
292+
kwargs["aws_access_key_id"] = info.aws_access_key_id.get_secret_value()
293+
kwargs["aws_secret_access_key"] = (
294+
info.aws_secret_access_key.get_secret_value()
295+
)
296+
if info.aws_session_token:
297+
kwargs["aws_session_token"] = info.aws_session_token.get_secret_value()
298+
299+
# ── 3️⃣ Default AWS credential chain ───────────────────────
300+
# Nothing needed — PyAthena automatically falls back to:
301+
# - Environment variables
302+
# - ~/.aws/credentials
303+
# - IAM Role (EC2, ECS, Lambda)
304+
305+
# Now connect via Ibis wrapper
306+
return ibis.athena.connect(**kwargs)
264307

265308
@staticmethod
266309
def get_bigquery_connection(info: BigQueryConnectionInfo) -> BaseBackend:

ibis-server/tests/routers/v2/connector/test_athena.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
},
6060
{
6161
"name": "timestamptz",
62-
"expression": "TIMESTAMP '2024-01-01 23:59:59 UTC'",
62+
"expression": "CAST(TIMESTAMP '2024-01-01 23:59:59 UTC' AS timestamp)",
6363
"type": "timestamp",
6464
},
6565
{
@@ -113,7 +113,7 @@ async def test_query(client, manifest_str):
113113
"orderkey": "int64",
114114
"custkey": "int64",
115115
"orderstatus": "string",
116-
"totalprice": "decimal128(15, 2)",
116+
"totalprice": "decimal128(38, 9)",
117117
"orderdate": "date32[day]",
118118
"order_cust_key": "string",
119119
"timestamp": "timestamp[us]",
@@ -153,7 +153,7 @@ async def test_query_glue_database(client, manifest_str):
153153
"orderkey": "int64",
154154
"custkey": "int64",
155155
"orderstatus": "string",
156-
"totalprice": "decimal128(15, 2)",
156+
"totalprice": "decimal128(38, 9)",
157157
"orderdate": "date32[day]",
158158
"order_cust_key": "string",
159159
"timestamp": "timestamp[us]",

ibis-server/tests/routers/v3/connector/athena/conftest.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,39 @@ def connection_info():
3131
}
3232

3333

34+
@pytest.fixture(scope="session")
35+
def connection_info_default_credential_chain():
36+
# Use default authentication (e.g., from environment variables, shared config file, or EC2 instance profile)
37+
access_key = os.getenv("AWS_ACCESS_KEY_ID")
38+
secret_key = os.getenv("AWS_SECRET_ACCESS_KEY")
39+
if not access_key or not secret_key:
40+
pytest.skip(
41+
"Skipping default credential chain test: AWS credentials not set in environment"
42+
)
43+
return {
44+
"s3_staging_dir": os.getenv("TEST_ATHENA_S3_STAGING_DIR"),
45+
"region_name": os.getenv("TEST_ATHENA_REGION_NAME", "ap-northeast-1"),
46+
"schema_name": "test",
47+
}
48+
49+
50+
@pytest.fixture(scope="session")
51+
def connection_info_oidc():
52+
web_identity_token = os.getenv("TEST_ATHENA_WEB_IDENTITY_TOKEN")
53+
role_arn = os.getenv("TEST_ATHENA_ROLE_ARN")
54+
55+
if not web_identity_token or not role_arn:
56+
pytest.skip("Skipping OIDC test: web identity token or role ARN not set")
57+
58+
return {
59+
"s3_staging_dir": os.getenv("TEST_ATHENA_OIDC_S3_STAGING_DIR"),
60+
"region_name": os.getenv("TEST_ATHENA_OIDC_REGION_NAME", "us-west-1"),
61+
"schema_name": "test",
62+
"role_arn": role_arn,
63+
"web_identity_token": web_identity_token,
64+
}
65+
66+
3467
@pytest.fixture(autouse=True)
3568
def set_remote_function_list_path():
3669
config = get_config()

ibis-server/tests/routers/v3/connector/athena/test_query.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ async def test_query(client, manifest_str, connection_info):
9090
"orderkey": "int64",
9191
"custkey": "int64",
9292
"orderstatus": "string",
93-
"totalprice": "decimal128(15, 2)",
93+
"totalprice": "decimal128(38, 9)",
9494
"orderdate": "date32[day]",
9595
"order_cust_key": "string",
9696
"timestamp": "timestamp[us]",
@@ -211,3 +211,52 @@ async def test_query_with_dry_run_and_invalid_sql(
211211
)
212212
assert response.status_code == 422
213213
assert response.text is not None
214+
215+
216+
@pytest.mark.parametrize(
217+
"conn_fixture",
218+
[
219+
"connection_info",
220+
"connection_info_default_credential_chain",
221+
"connection_info_oidc",
222+
],
223+
)
224+
async def test_query_athena_modes(client, manifest_str, request, conn_fixture):
225+
connection_info = request.getfixturevalue(conn_fixture)
226+
227+
response = await client.post(
228+
url="/v3/connector/athena/query",
229+
json={
230+
"connectionInfo": connection_info,
231+
"manifestStr": manifest_str,
232+
"sql": "SELECT * FROM wren.public.orders LIMIT 1",
233+
},
234+
)
235+
assert response.status_code == 200
236+
result = response.json()
237+
assert len(result["columns"]) == len(manifest["models"][0]["columns"])
238+
assert len(result["data"]) == 1
239+
240+
assert result["data"][0] == [
241+
1,
242+
36901,
243+
"O",
244+
"173665.47",
245+
"1996-01-02",
246+
"1_36901",
247+
"2024-01-01 23:59:59.000000",
248+
"2024-01-01 23:59:59.000000",
249+
None,
250+
]
251+
252+
assert result["dtypes"] == {
253+
"orderkey": "int64",
254+
"custkey": "int64",
255+
"orderstatus": "string",
256+
"totalprice": "decimal128(38, 9)",
257+
"orderdate": "date32[day]",
258+
"order_cust_key": "string",
259+
"timestamp": "timestamp[us]",
260+
"timestamptz": "timestamp[us]",
261+
"test_null_time": "timestamp[us]",
262+
}

0 commit comments

Comments
 (0)