2 changes: 2 additions & 0 deletions awswrangler/athena/__init__.py
@@ -4,6 +4,7 @@
get_query_execution,
stop_query_execution,
start_query_execution,
start_query_executions,
wait_query,
)
from awswrangler.athena._spark import create_spark_session, run_spark_calculation
@@ -53,6 +54,7 @@
"create_ctas_table",
"show_create_table",
"start_query_execution",
"start_query_executions",
"stop_query_execution",
"unload",
"wait_query",
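With the name added to `__all__`, the new function is reachable from the usual namespaces. A quick sketch (nothing assumed beyond the export lines above):

```python
import awswrangler as wr
from awswrangler.athena import start_query_executions

# Both spellings resolve to the same callable after this change.
assert wr.athena.start_query_executions is start_query_executions
```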
181 changes: 181 additions & 0 deletions awswrangler/athena/_executions.py
@@ -3,7 +3,9 @@
from __future__ import annotations

import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
Dict,
@@ -29,6 +31,8 @@

_logger: logging.Logger = logging.getLogger(__name__)

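# Default pool size for use_threads=True: at least 4 workers, scaling with CPU count.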
_DEFAULT_MAX_WORKERS = max(4, os.cpu_count() or 4)


@apply_configs
def start_query_execution(
@@ -169,6 +173,183 @@ def start_query_execution(
return query_execution_id


@apply_configs
def start_query_executions(
sqls: list[str],
database: str | None = None,
s3_output: str | None = None,
workgroup: str = "primary",
encryption: str | None = None,
kms_key: str | None = None,
    params: dict[str, typing.Any] | list[list[str]] | None = None,
    paramstyle: Literal["qmark", "named"] = "named",
    boto3_session: boto3.Session | None = None,
    client_request_token: str | list[str] | None = None,
athena_cache_settings: typing.AthenaCacheSettings | None = None,
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
data_source: str | None = None,
wait: bool = False,
check_workgroup: bool = True,
enforce_workgroup: bool = False,
as_iterator: bool = False,
use_threads: bool | int = False,
) -> list[str] | list[dict[str, typing.Any]]:
"""
Start multiple SQL queries against Amazon Athena.

This is the multi-query counterpart to ``start_query_execution``. It supports
per-query caching and idempotent client request tokens, optional workgroup
validation/enforcement, sequential or thread-pooled parallel dispatch, and
    either eager (list) or lazy (iterator) consumption. If ``wait=True``, each
    query is awaited to completion within its submission thread.

Parameters
----------
sqls : list[str]
List of SQL queries to execute.
database : str, optional
AWS Glue/Athena database name.
s3_output : str, optional
S3 path where query results will be stored.
workgroup : str, default 'primary'
Athena workgroup name.
encryption : str, optional
One of {'SSE_S3', 'SSE_KMS', 'CSE_KMS'}.
kms_key : str, optional
KMS key ARN/ID, required if using KMS-based encryption.
    params : dict | list[list[str]], optional
        Query parameters. With ``paramstyle='named'``, a single dict applied to
        every query; with ``paramstyle='qmark'``, a list containing one
        parameter list per query.
paramstyle : {'named', 'qmark'}, default 'named'
Parameter substitution style.
boto3_session : boto3.Session, optional
Existing boto3 session. A new session will be created if None.
    client_request_token : str | list[str], optional
        Idempotency token(s). A list supplies one token per query; a single
        string is suffixed with each query's index. Result caching is only
        considered for queries without a token.
athena_cache_settings : dict, optional
Wrangler cache settings for query result reuse.
athena_query_wait_polling_delay : float, default 1.0
Interval between status checks when waiting for queries.
data_source : str, optional
Data catalog name (default 'AwsDataCatalog').
wait : bool, default False
If True, block until each query completes.
check_workgroup : bool, default True
If True, fetch workgroup config from Athena.
enforce_workgroup : bool, default False
If True, enforce workgroup config even when skipping fetch.
as_iterator : bool, default False
If True, return an iterator instead of a list.
    use_threads : bool | int, default False
        Parallelism:
        - False: sequential execution
        - True: ``max(4, os.cpu_count() or 4)`` worker threads
        - int: that number of worker threads

    Returns
    -------
    list[str] | list[dict] | Iterator
        QueryExecutionIds, or query execution metadata dicts when ``wait=True``.
        With ``as_iterator=True``, a lazy iterator is returned instead of a list.
"""
session = boto3_session or boto3.Session()

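    # Normalize client request tokens: one idempotency token (or None) per query.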
if isinstance(client_request_token, list):
if len(client_request_token) != len(sqls):
raise ValueError("Length of client_request_token list must match number of queries in sqls")
tokens = client_request_token
elif isinstance(client_request_token, str):
tokens = (f"{client_request_token}-{i}" for i in range(len(sqls)))
else:
tokens = [None] * len(sqls)

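    # Pre-format the queries for the chosen paramstyle; generators keep this lazy.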
if paramstyle == "named":
formatted_queries = (_apply_formatter(q, params, "named") for q in sqls)
elif paramstyle == "qmark":
_params_list = params or [None] * len(sqls)
formatted_queries = (_apply_formatter(q, query_params, "qmark") for q, query_params in zip(sqls, _params_list))
else:
raise ValueError("paramstyle must be 'named' or 'qmark'")

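    # Resolve the workgroup config: fetch it from Athena, or build a local one when the check is skipped.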
if check_workgroup:
wg_config: _WorkGroupConfig = _get_workgroup_config(session=session, workgroup=workgroup)
else:
wg_config = _WorkGroupConfig(
enforced=enforce_workgroup,
s3_output=s3_output,
encryption=encryption,
kms_key=kms_key,
)

    def _submit(item: tuple[tuple[str, list[str] | None], str | None]) -> str | dict[str, typing.Any]:
        # Run a single query: serve it from the result cache when eligible,
        # otherwise start it, optionally waiting for completion.
(q, execution_params), token = item
_logger.debug("Executing query:\n%s", q)

if token is None and athena_cache_settings is not None:
cache_info = _check_for_cached_results(
sql=q,
boto3_session=session,
workgroup=workgroup,
athena_cache_settings=athena_cache_settings,
)
_logger.debug("Cache info:\n%s", cache_info)
if cache_info.has_valid_cache and cache_info.query_execution_id is not None:
_logger.debug("Valid cache found. Retrieving...")
return (
wait_query(
query_execution_id=cache_info.query_execution_id,
boto3_session=session,
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
)
if wait
else cache_info.query_execution_id
)

qid = _start_query_execution(
sql=q,
wg_config=wg_config,
database=database,
data_source=data_source,
s3_output=s3_output,
workgroup=workgroup,
encryption=encryption,
kms_key=kms_key,
execution_params=execution_params,
client_request_token=token,
boto3_session=session,
)

if wait:
return wait_query(
query_execution_id=qid,
boto3_session=session,
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
)

return qid

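    # Pair each formatted query with its token; evaluation stays lazy until consumed.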
items = zip(formatted_queries, tokens)

    if use_threads is False:
        # Sequential dispatch; ``map`` is lazy, so materialize the results
        # unless the caller asked for an iterator.
        results = map(_submit, items)
        return results if as_iterator else list(results)

    max_workers = _DEFAULT_MAX_WORKERS if use_threads is True else int(use_threads)

    if as_iterator:
        # ``Executor.map`` schedules all submissions eagerly but yields results
        # lazily; keep the pool alive until the caller exhausts the generator,
        # then shut it down.
        executor = ThreadPoolExecutor(max_workers=max_workers)
        it = executor.map(_submit, items)

        def _iter():
            try:
                yield from it
            finally:
                executor.shutdown(wait=True)

        return _iter()
    else:
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            return list(executor.map(_submit, items))


def stop_query_execution(query_execution_id: str, boto3_session: boto3.Session | None = None) -> None:
"""Stop a query execution.

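A minimal usage sketch of the new function as defined above — `my_table`, `my_database`, and the token are placeholders, and `paramstyle='qmark'` takes one parameter list per query:

```python
import awswrangler as wr

sqls = [
    "SELECT * FROM my_table WHERE year = ?",
    "SELECT COUNT(*) FROM my_table WHERE year = ?",
]

# One parameter list per query; the base token is suffixed with each query's
# index ("my-batch-0", "my-batch-1") for idempotent retries.
query_ids = wr.athena.start_query_executions(
    sqls=sqls,
    database="my_database",  # hypothetical
    params=[["2023"], ["2024"]],
    paramstyle="qmark",
    client_request_token="my-batch",
    wait=False,
)
```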
65 changes: 65 additions & 0 deletions awswrangler/athena/_executions.pyi
@@ -58,6 +58,71 @@ def start_query_execution(
data_source: str | None = ...,
wait: bool,
) -> str | dict[str, Any]: ...
@overload
def start_query_executions(
sqls: list[str],
database: str | None = ...,
s3_output: str | None = ...,
workgroup: str = ...,
encryption: str | None = ...,
kms_key: str | None = ...,
    params: dict[str, Any] | list[list[str]] | None = ...,
    paramstyle: Literal["qmark", "named"] = ...,
    boto3_session: boto3.Session | None = ...,
    client_request_token: str | list[str] | None = ...,
athena_cache_settings: typing.AthenaCacheSettings | None = ...,
athena_query_wait_polling_delay: float = ...,
data_source: str | None = ...,
wait: Literal[False] = ...,
check_workgroup: bool = ...,
enforce_workgroup: bool = ...,
as_iterator: bool = ...,
use_threads: bool | int = ...,
) -> list[str]: ...
@overload
def start_query_executions(
sqls: list[str],
*,
database: str | None = ...,
s3_output: str | None = ...,
workgroup: str = ...,
encryption: str | None = ...,
kms_key: str | None = ...,
    params: dict[str, Any] | list[list[str]] | None = ...,
    paramstyle: Literal["qmark", "named"] = ...,
    boto3_session: boto3.Session | None = ...,
    client_request_token: str | list[str] | None = ...,
athena_cache_settings: typing.AthenaCacheSettings | None = ...,
athena_query_wait_polling_delay: float = ...,
data_source: str | None = ...,
wait: Literal[True],
check_workgroup: bool = ...,
enforce_workgroup: bool = ...,
as_iterator: bool = ...,
use_threads: bool | int = ...,
) -> list[dict[str, Any]]: ...
@overload
def start_query_executions(
sqls: list[str],
*,
database: str | None = ...,
s3_output: str | None = ...,
workgroup: str = ...,
encryption: str | None = ...,
kms_key: str | None = ...,
    params: dict[str, Any] | list[list[str]] | None = ...,
    paramstyle: Literal["qmark", "named"] = ...,
    boto3_session: boto3.Session | None = ...,
    client_request_token: str | list[str] | None = ...,
athena_cache_settings: typing.AthenaCacheSettings | None = ...,
athena_query_wait_polling_delay: float = ...,
data_source: str | None = ...,
wait: bool,
check_workgroup: bool = ...,
enforce_workgroup: bool = ...,
as_iterator: bool = ...,
use_threads: bool | int = ...,
) -> list[str] | list[dict[str, Any]]: ...
def stop_query_execution(query_execution_id: str, boto3_session: boto3.Session | None = ...) -> None: ...
def wait_query(
query_execution_id: str,
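The overloads let a type checker narrow the return type on the `wait` literal — a small illustration (the queries are placeholders):

```python
from typing import Any

import awswrangler as wr

# wait=False (the default) narrows to list[str] of QueryExecutionIds.
ids: list[str] = wr.athena.start_query_executions(sqls=["SELECT 1"], wait=False)

# wait=True narrows to list[dict] of query execution metadata.
metadata: list[dict[str, Any]] = wr.athena.start_query_executions(sqls=["SELECT 1"], wait=True)
```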
59 changes: 59 additions & 0 deletions tests/unit/test_athena.py
@@ -1708,3 +1708,62 @@ def test_athena_date_recovery(path, glue_database, glue_table):
ctas_approach=False,
)
assert pandas_equals(df, df2)


def test_start_query_executions_ids_and_results(path, glue_database, glue_table):
# Prepare table
wr.s3.to_parquet(
df=get_df(),
path=path,
index=True,
dataset=True,
mode="overwrite",
database=glue_database,
table=glue_table,
partition_cols=["par0", "par1"],
)

sqls = [
f"SELECT * FROM {glue_table} LIMIT 1",
f"SELECT COUNT(*) FROM {glue_table}",
]

# Case 1: Sequential, return query IDs
qids = wr.athena.start_query_executions(sqls=sqls, database=glue_database, wait=False, use_threads=False)
assert isinstance(qids, list)
assert all(isinstance(qid, str) for qid in qids)
assert len(qids) == len(sqls)

# Case 2: Sequential, wait for results
results = wr.athena.start_query_executions(sqls=sqls, database=glue_database, wait=True, use_threads=False)
assert isinstance(results, list)
assert all(isinstance(r, dict) for r in results)
assert all("Status" in r for r in results)

# Case 3: Parallel execution with threads
results_parallel = wr.athena.start_query_executions(sqls=sqls, database=glue_database, wait=True, use_threads=True)
assert isinstance(results_parallel, list)
assert all(isinstance(r, dict) for r in results_parallel)


def test_start_query_executions_as_iterator(path, glue_database, glue_table):
# Prepare table
wr.s3.to_parquet(
df=get_df(),
path=path,
index=True,
dataset=True,
mode="overwrite",
database=glue_database,
table=glue_table,
partition_cols=["par0", "par1"],
)

sqls = [f"SELECT * FROM {glue_table} LIMIT 1"]

# Case: as_iterator=True should return a generator-like object
qids_iter = wr.athena.start_query_executions(sqls=sqls, database=glue_database, wait=False, as_iterator=True)
assert not isinstance(qids_iter, list)
qids = list(qids_iter)
assert len(qids) == 1
assert isinstance(qids[0], str)
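Sketched here as a possible follow-up case rather than part of this diff: combining `as_iterator` with a bounded thread pool, assuming the same fixtures and imports as the tests above.

```python
def test_start_query_executions_parallel_iterator(glue_database):
    sqls = [f"SELECT {i} AS n" for i in range(4)]

    # Worker threads submit the queries; the pool is shut down once the
    # iterator is exhausted.
    qids_iter = wr.athena.start_query_executions(
        sqls=sqls, database=glue_database, wait=False, as_iterator=True, use_threads=2
    )
    qids = list(qids_iter)
    assert len(qids) == len(sqls)
    assert all(isinstance(q, str) for q in qids)
```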