Skip to content

Commit 06fb9ef

Browse files
committed
Refresh credentials on ExpiredTokenException
1 parent c559f14 commit 06fb9ef

File tree

1 file changed

+51
-11
lines changed

1 file changed

+51
-11
lines changed

lib/blazer/adapters/athena_adapter.rb

+51-11
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def run_statement(statement, comment, bind_params = [])
1414
request_token = Digest::MD5.hexdigest([statement, bind_params.to_json, data_source.id, settings["workgroup"]].compact.join("/"))
1515
statement_name = "blazer_#{request_token}"
1616
begin
17-
client.create_prepared_statement({
17+
create_prepared_statement({
1818
statement_name: statement_name,
1919
work_group: settings["workgroup"],
2020
query_statement: statement
@@ -45,15 +45,15 @@ def run_statement(statement, comment, bind_params = [])
4545
query_options[:work_group] = settings["workgroup"]
4646
end
4747

48-
resp = client.start_query_execution(**query_options)
48+
resp = start_query_execution(**query_options)
4949
query_execution_id = resp.query_execution_id
5050

5151
timeout = data_source.timeout || 300
5252
stop_at = Time.now + timeout
5353
resp = nil
5454

5555
begin
56-
resp = client.get_query_results(
56+
resp = get_query_results(
5757
query_execution_id: query_execution_id
5858
)
5959
rescue Aws::Athena::Errors::InvalidRequestException => e
@@ -119,11 +119,11 @@ def run_statement(statement, comment, bind_params = [])
119119
end
120120

121121
def tables
122-
glue.get_tables(database_name: database).table_list.map(&:name).sort
122+
get_tables(database_name: database).table_list.map(&:name).sort
123123
end
124124

125125
def schema
126-
glue.get_tables(database_name: database).table_list.map { |t| {table: t.name, columns: t.storage_descriptor.columns.map { |c| {name: c.name, data_type: c.type} }} }
126+
get_tables(database_name: database).table_list.map { |t| {table: t.name, columns: t.storage_descriptor.columns.map { |c| {name: c.name, data_type: c.type} }} }
127127
end
128128

129129
def preview_statement
@@ -154,11 +154,51 @@ def engine_version
154154
end
155155

156156
def fetch_error(query_execution_id)
157-
client.get_query_execution(
157+
get_query_execution(
158158
query_execution_id: query_execution_id
159159
).query_execution.status.state_change_reason
160160
end
161161

162+
def autorefresh_credentials
163+
yield
164+
rescue Aws::Athena::Errors::ExpiredTokenException, Aws::Glue::Errors::ExpiredTokenException
165+
# Clear our cached Athena & Glue clients to force fetching new credentials, and immediately retry
166+
@client = nil
167+
@glue = nil
168+
@client_credentials = nil
169+
yield
170+
end
171+
172+
def get_tables(**options)
173+
autorefresh_credentials do
174+
glue.get_tables(**options)
175+
end
176+
end
177+
178+
def create_prepared_statement(**options)
179+
autorefresh_credentials do
180+
client.create_prepared_statement(**options)
181+
end
182+
end
183+
184+
def start_query_execution(**options)
185+
autorefresh_credentials do
186+
client.start_query_execution(**options)
187+
end
188+
end
189+
190+
def get_query_results(**options)
191+
autorefresh_credentials do
192+
client.get_query_results(**options)
193+
end
194+
end
195+
196+
def get_query_execution(**options)
197+
autorefresh_credentials do
198+
client.get_query_execution(**options)
199+
end
200+
end
201+
162202
def client
163203
@client ||= Aws::Athena::Client.new(**client_options)
164204
end
@@ -168,12 +208,12 @@ def glue
168208
end
169209

170210
def client_options
171-
@client_options ||= begin
172-
options = {}
173-
options[:credentials] = client_credentials if client_credentials
174-
options[:region] = settings["region"] if settings["region"]
175-
options
211+
options = {}
212+
if credentials = client_credentials
213+
options[:credentials] = credentials
176214
end
215+
options[:region] = settings["region"] if settings["region"]
216+
options
177217
end
178218

179219
def client_credentials

0 commit comments

Comments
 (0)