1- from typing import Dict , List , Tuple , Optional , Any , Iterator , Union
1+ from typing import Dict , List , Tuple , Optional , Any , Iterator
22from time import sleep
33import logging
44import re
@@ -25,7 +25,6 @@ def get_query_columns_metadata(self, query_execution_id: str) -> Dict[str, str]:
2525 """
2626 response : Dict = self ._client_athena .get_query_results (QueryExecutionId = query_execution_id , MaxResults = 1 )
2727 col_info : List [Dict [str , str ]] = response ["ResultSet" ]["ResultSetMetadata" ]["ColumnInfo" ]
28- logger .debug (f"col_info: { col_info } " )
2928 return {x ["Name" ]: x ["Type" ] for x in col_info }
3029
3130 def create_athena_bucket (self ):
@@ -42,7 +41,13 @@ def create_athena_bucket(self):
4241 s3_resource .Bucket (s3_output )
4342 return s3_output
4443
45- def run_query (self , query : str , database : Optional [str ] = None , s3_output : Optional [str ] = None , workgroup : Optional [str ] = None , encryption : Optional [str ] = None , kms_key : Optional [str ] = None ) -> str :
44+ def run_query (self ,
45+ query : str ,
46+ database : Optional [str ] = None ,
47+ s3_output : Optional [str ] = None ,
48+ workgroup : Optional [str ] = None ,
49+ encryption : Optional [str ] = None ,
50+ kms_key : Optional [str ] = None ) -> str :
4651 """
4752 Run a SQL Query against AWS Athena
4853 P.S All default values will be inherited from the Session()
@@ -55,7 +60,7 @@ def run_query(self, query: str, database: Optional[str] = None, s3_output: Optio
5560 :param kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
5661 :return: Query execution ID
5762 """
58- args : Dict [str , Union [ str , Dict [ str , Union [ str , Dict [ str , str ]]]] ] = {"QueryString" : query }
63+ args : Dict [str , Any ] = {"QueryString" : query }
5964
6065 # s3_output
6166 if s3_output is None :
@@ -71,7 +76,9 @@ def run_query(self, query: str, database: Optional[str] = None, s3_output: Optio
7176 if kms_key is not None :
7277 args ["ResultConfiguration" ]["EncryptionConfiguration" ]["KmsKey" ] = kms_key
7378 elif self ._session .athena_encryption is not None :
74- args ["ResultConfiguration" ]["EncryptionConfiguration" ] = {"EncryptionOption" : self ._session .athena_encryption }
79+ args ["ResultConfiguration" ]["EncryptionConfiguration" ] = {
80+ "EncryptionOption" : self ._session .athena_encryption
81+ }
7582 if self ._session .athena_kms_key is not None :
7683 args ["ResultConfiguration" ]["EncryptionConfiguration" ]["KmsKey" ] = self ._session .athena_kms_key
7784
@@ -113,7 +120,13 @@ def wait_query(self, query_execution_id):
113120 raise QueryCancelled (response ["QueryExecution" ]["Status" ].get ("StateChangeReason" ))
114121 return response
115122
116- def repair_table (self , table : str , database : Optional [str ] = None , s3_output : Optional [str ] = None , workgroup : Optional [str ] = None , encryption : Optional [str ] = None , kms_key : Optional [str ] = None ):
123+ def repair_table (self ,
124+ table : str ,
125+ database : Optional [str ] = None ,
126+ s3_output : Optional [str ] = None ,
127+ workgroup : Optional [str ] = None ,
128+ encryption : Optional [str ] = None ,
129+ kms_key : Optional [str ] = None ):
117130 """
118131 Hive's metastore consistency check
119132 "MSCK REPAIR TABLE table;"
@@ -133,7 +146,12 @@ def repair_table(self, table: str, database: Optional[str] = None, s3_output: Op
133146 :return: Query execution ID
134147 """
135148 query = f"MSCK REPAIR TABLE { table } ;"
136- query_id = self .run_query (query = query , database = database , s3_output = s3_output , workgroup = workgroup , encryption = encryption , kms_key = kms_key )
149+ query_id = self .run_query (query = query ,
150+ database = database ,
151+ s3_output = s3_output ,
152+ workgroup = workgroup ,
153+ encryption = encryption ,
154+ kms_key = kms_key )
137155 self .wait_query (query_execution_id = query_id )
138156 return query_id
139157
@@ -174,7 +192,13 @@ def get_results(self, query_execution_id: str) -> Iterator[Dict[str, Any]]:
174192 yield row
175193 next_token = res .get ("NextToken" )
176194
177- def query (self , query : str , database : Optional [str ] = None , s3_output : Optional [str ] = None , workgroup : Optional [str ] = None , encryption : Optional [str ] = None , kms_key : Optional [str ] = None ) -> Iterator [Dict [str , Any ]]:
195+ def query (self ,
196+ query : str ,
197+ database : Optional [str ] = None ,
198+ s3_output : Optional [str ] = None ,
199+ workgroup : Optional [str ] = None ,
200+ encryption : Optional [str ] = None ,
201+ kms_key : Optional [str ] = None ) -> Iterator [Dict [str , Any ]]:
178202 """
179203 Run a SQL Query against AWS Athena and return the result as an Iterator of dicts
180204 P.S All default values will be inherited from the Session()
@@ -187,7 +211,12 @@ def query(self, query: str, database: Optional[str] = None, s3_output: Optional[
187211 :param kms_key: For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
188212 :return: Iterator over the query result rows (one dict per row)
189213 """
190- query_id : str = self .run_query (query = query , database = database , s3_output = s3_output , workgroup = workgroup , encryption = encryption , kms_key = kms_key )
214+ query_id : str = self .run_query (query = query ,
215+ database = database ,
216+ s3_output = s3_output ,
217+ workgroup = workgroup ,
218+ encryption = encryption ,
219+ kms_key = kms_key )
191220 self .wait_query (query_execution_id = query_id )
192221 return self .get_results (query_execution_id = query_id )
193222
0 commit comments