 import pyarrow as pa  # type: ignore
 from pyarrow import parquet as pq  # type: ignore
 import tenacity  # type: ignore
+from s3fs import S3FileSystem  # type: ignore
 
 from awswrangler import data_types
 from awswrangler.exceptions import (UnsupportedWriteMode, UnsupportedFileFormat, AthenaQueryError, EmptyS3Object,
@@ -491,13 +492,13 @@ def _get_query_dtype(self, query_execution_id: str) -> Tuple[Dict[str, str], Lis
         return dtype, parse_timestamps, parse_dates, converters
 
     def read_sql_athena(self,
-                        sql,
-                        database=None,
-                        s3_output=None,
-                        max_result_size=None,
-                        workgroup=None,
-                        encryption=None,
-                        kms_key=None):
+                        sql: str,
+                        database: Optional[str] = None,
+                        s3_output: Optional[str] = None,
+                        max_result_size: Optional[int] = None,
+                        workgroup: Optional[str] = None,
+                        encryption: Optional[str] = None,
+                        kms_key: Optional[str] = None):
501502 """
502503 Executes any SQL query on AWS Athena and return a Dataframe of the result.
503504 P.S. If max_result_size is passed, then a iterator of Dataframes is returned.
@@ -512,18 +513,21 @@ def read_sql_athena(self,
         :param kms_key: For SSE-KMS and CSE-KMS, this is the KMS key ARN or ID.
         :return: Pandas DataFrame or Iterator of Pandas DataFrames if max_result_size != None
         """
-        if not s3_output:
-            s3_output = self._session.athena.create_athena_bucket()
-        query_execution_id = self._session.athena.run_query(query=sql,
-                                                            database=database,
-                                                            s3_output=s3_output,
-                                                            workgroup=workgroup,
-                                                            encryption=encryption,
-                                                            kms_key=kms_key)
-        query_response = self._session.athena.wait_query(query_execution_id=query_execution_id)
+        if s3_output is None:
+            if self._session.athena_s3_output is not None:
+                s3_output = self._session.athena_s3_output
+            else:
+                s3_output = self._session.athena.create_athena_bucket()
+        query_execution_id: str = self._session.athena.run_query(query=sql,
+                                                                 database=database,
+                                                                 s3_output=s3_output,
+                                                                 workgroup=workgroup,
+                                                                 encryption=encryption,
+                                                                 kms_key=kms_key)
+        query_response: Dict = self._session.athena.wait_query(query_execution_id=query_execution_id)
         if query_response["QueryExecution"]["Status"]["State"] in ["FAILED", "CANCELLED"]:
-            reason = query_response["QueryExecution"]["Status"]["StateChangeReason"]
-            message_error = f"Query error: {reason}"
+            reason: str = query_response["QueryExecution"]["Status"]["StateChangeReason"]
+            message_error: str = f"Query error: {reason}"
             raise AthenaQueryError(message_error)
         else:
             dtype, parse_timestamps, parse_dates, converters = self._get_query_dtype(
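
Usage note (not part of the diff): a minimal sketch of how the now-annotated read_sql_athena is called through the pre-1.0 awswrangler Session API. The database and table names are placeholders, and treating max_result_size as a byte budget per returned chunk is an assumption.

import awswrangler

session = awswrangler.Session()

# s3_output omitted: per the change above, the session-level athena_s3_output
# is used when set, otherwise a default bucket from create_athena_bucket().
df = session.pandas.read_sql_athena(
    sql="SELECT * FROM my_table LIMIT 10",  # placeholder query
    database="my_db",                       # placeholder database
)

# With max_result_size set, an iterator of DataFrames is returned instead.
for chunk in session.pandas.read_sql_athena(sql="SELECT * FROM my_table",
                                            database="my_db",
                                            max_result_size=128 * 1024 * 1024):
    print(len(chunk.index))
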
@@ -1133,7 +1137,7 @@ def read_parquet(self,
                      path: str,
                      columns: Optional[List[str]] = None,
                      filters: Optional[Union[List[Tuple[Any]], List[Tuple[Any]]]] = None,
-                     procs_cpu_bound: Optional[int] = None):
+                     procs_cpu_bound: Optional[int] = None) -> pd.DataFrame:
11371141 """
11381142 Read parquet data from S3
11391143
@@ -1145,7 +1149,7 @@ def read_parquet(self,
         path = path[:-1] if path[-1] == "/" else path
         procs_cpu_bound = 1 if self._session.procs_cpu_bound is None else self._session.procs_cpu_bound if procs_cpu_bound is None else procs_cpu_bound
         use_threads: bool = True if procs_cpu_bound > 1 else False
-        fs = s3.get_fs(session_primitives=self._session.primitives)
+        fs: S3FileSystem = s3.get_fs(session_primitives=self._session.primitives)
         fs = pa.filesystem._ensure_filesystem(fs)
         return pq.read_table(source=path, columns=columns, filters=filters,
                              filesystem=fs).to_pandas(use_threads=use_threads)
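
Usage note (not part of the diff): a sketch exercising read_parquet with its new pd.DataFrame return annotation; the bucket, prefix, and column names are placeholders.

import awswrangler

session = awswrangler.Session()

# columns/filters are pushed down to pyarrow.parquet.read_table; a trailing
# slash on path is stripped by the function itself.
df = session.pandas.read_parquet(path="s3://my-bucket/my-prefix/",
                                 columns=["id", "value"],
                                 procs_cpu_bound=4)  # >1 turns on use_threads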