Skip to content

Commit b09a3c9

Browse files
committed
Refresh credentials on ExpiredTokenException
1 parent 0c97d0e commit b09a3c9

File tree

1 file changed

+69
-32
lines changed

1 file changed

+69
-32
lines changed

lib/blazer/adapters/athena_adapter.rb

+69-32
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def run_statement(statement, comment, bind_params = [])
1414
request_token = Digest::MD5.hexdigest([statement, bind_params.to_json, data_source.id, settings["workgroup"]].compact.join("/"))
1515
statement_name = "blazer_#{request_token}"
1616
begin
17-
client.create_prepared_statement({
17+
create_prepared_statement({
1818
statement_name: statement_name,
1919
work_group: settings["workgroup"],
2020
query_statement: statement
@@ -45,15 +45,15 @@ def run_statement(statement, comment, bind_params = [])
4545
query_options[:work_group] = settings["workgroup"]
4646
end
4747

48-
resp = client.start_query_execution(**query_options)
48+
resp = start_query_execution(**query_options)
4949
query_execution_id = resp.query_execution_id
5050

5151
timeout = data_source.timeout || 300
5252
stop_at = Time.now + timeout
5353
resp = nil
5454

5555
begin
56-
resp = client.get_query_results(
56+
resp = get_query_results(
5757
query_execution_id: query_execution_id
5858
)
5959
rescue Aws::Athena::Errors::InvalidRequestException => e
@@ -119,11 +119,11 @@ def run_statement(statement, comment, bind_params = [])
119119
end
120120

121121
def tables
122-
glue.get_tables(database_name: database).table_list.map(&:name).sort
122+
get_tables(database_name: database).table_list.map(&:name).sort
123123
end
124124

125125
def schema
126-
glue.get_tables(database_name: database).table_list.map { |t| {table: t.name, columns: t.storage_descriptor.columns.map { |c| {name: c.name, data_type: c.type} }} }
126+
get_tables(database_name: database).table_list.map { |t| {table: t.name, columns: t.storage_descriptor.columns.map { |c| {name: c.name, data_type: c.type} }} }
127127
end
128128

129129
def preview_statement
@@ -154,11 +154,50 @@ def engine_version
154154
end
155155

156156
def fetch_error(query_execution_id)
157-
client.get_query_execution(
157+
get_query_execution(
158158
query_execution_id: query_execution_id
159159
).query_execution.status.state_change_reason
160160
end
161161

162+
def autorefresh_credentials
163+
yield
164+
rescue Aws::Athena::Errors::ExpiredTokenException
165+
# Clear our cached Athena & Glue clients to force fetching new credentials, and immediately retry
166+
@client = nil
167+
@glue = nil
168+
yield
169+
end
170+
171+
def get_tables(**options)
172+
autorefresh_credentials do
173+
glue.get_tables(**options)
174+
end
175+
end
176+
177+
def create_prepared_statement(**options)
178+
autorefresh_credentials do
179+
client.create_prepared_statement(**options)
180+
end
181+
end
182+
183+
def start_query_execution(**options)
184+
autorefresh_credentials do
185+
client.start_query_execution(**options)
186+
end
187+
end
188+
189+
def get_query_results(**options)
190+
autorefresh_credentials do
191+
client.get_query_results(**options)
192+
end
193+
end
194+
195+
def get_query_execution(**options)
196+
autorefresh_credentials do
197+
client.get_query_execution(**options)
198+
end
199+
end
200+
162201
def client
163202
@client ||= Aws::Athena::Client.new(**client_options)
164203
end
@@ -168,36 +207,34 @@ def glue
168207
end
169208

170209
def client_options
171-
@client_options ||= begin
172-
options = {}
173-
options[:credentials] = client_credentials if client_credentials
174-
options[:region] = settings["region"] if settings["region"]
175-
options
210+
options = {}
211+
if credentials = client_credentials
212+
options[:credentials] = credentials
176213
end
214+
options[:region] = settings["region"] if settings["region"]
215+
options
177216
end
178217

179218
def client_credentials
180-
@client_credentials ||= begin
181-
# Loading the access key & secret from the top-level settings is supported for backwards compatibility,
182-
# but prefer loading them from the 'credentials' sub-hash.
183-
creds = (settings["credentials"] || {}).with_defaults(settings.slice("access_key_id", "secret_access_key", "region"))
184-
access_key_id = creds["access_key_id"]
185-
secret_access_key = creds["secret_access_key"]
186-
region = creds["region"]
187-
role_arn = creds["role_arn"]
188-
role_session_name = creds["role_session_name"] || "blazer"
189-
if role_arn
190-
Aws::STS::Client.new(
191-
access_key_id: access_key_id,
192-
secret_access_key: secret_access_key,
193-
region: region,
194-
).assume_role(
195-
role_arn: role_arn,
196-
role_session_name: role_session_name,
197-
)
198-
elsif access_key_id && secret_access_key
199-
Aws::Credentials.new(access_key_id, secret_access_key)
200-
end
219+
# Loading the access key & secret from the top-level settings is supported for backwards compatibility,
220+
# but prefer loading them from the 'credentials' sub-hash.
221+
creds = (settings["credentials"] || {}).with_defaults(settings.slice("access_key_id", "secret_access_key", "region"))
222+
access_key_id = creds["access_key_id"]
223+
secret_access_key = creds["secret_access_key"]
224+
region = creds["region"]
225+
role_arn = creds["role_arn"]
226+
role_session_name = creds["role_session_name"] || "blazer"
227+
if role_arn
228+
Aws::STS::Client.new(
229+
access_key_id: access_key_id,
230+
secret_access_key: secret_access_key,
231+
region: region,
232+
).assume_role(
233+
role_arn: role_arn,
234+
role_session_name: role_session_name,
235+
)
236+
elsif access_key_id && secret_access_key
237+
Aws::Credentials.new(access_key_id, secret_access_key)
201238
end
202239
end
203240
end

0 commit comments

Comments
 (0)