Skip to content

Commit a82b7ba

Browse files
author
Michael Baumann
committed
Only enable requester pays when necessary #400
Only enable (Google) requester pays data access when the given DRS URIs require it and the platform TNU is running on supports it.
1 parent d84d3e5 commit a82b7ba

File tree

1 file changed

+55
-12
lines changed

1 file changed

+55
-12
lines changed

terra_notebook_utils/drs.py

+55-12
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
from terra_notebook_utils import WORKSPACE_BUCKET, WORKSPACE_NAME, DRS_RESOLVER_URL, WORKSPACE_NAMESPACE, \
1111
WORKSPACE_GOOGLE_PROJECT
12-
from terra_notebook_utils import workspace, gs, tar_gz, TERRA_DEPLOYMENT_ENV, _GS_SCHEMA
13-
from terra_notebook_utils.utils import is_notebook
12+
from terra_notebook_utils import workspace, gs, tar_gz, TERRA_DEPLOYMENT_ENV, _GS_SCHEMA, ExecutionPlatform
13+
from terra_notebook_utils.utils import is_notebook, get_execution_context
1414
from terra_notebook_utils.http import http
1515
from terra_notebook_utils.blobstore.gs import GSBlob
1616
from terra_notebook_utils.blobstore.local import LocalBlob
@@ -26,16 +26,46 @@
2626
class DRSResolutionError(Exception):
2727
pass
2828

29+
class RequesterPaysNotSupported(Exception):
30+
pass
31+
2932
def _parse_gs_url(gs_url: str) -> Tuple[str, str]:
3033
if gs_url.startswith(_GS_SCHEMA):
3134
bucket_name, object_key = gs_url[len(_GS_SCHEMA):].split("/", 1)
3235
return bucket_name, object_key
3336
else:
3437
raise RuntimeError(f'Invalid gs url schema. {gs_url} does not start with {_GS_SCHEMA}')
3538

39+
40+
def is_requester_pays(drs_urls: Iterable[str]) -> bool:
41+
"""
42+
Identify if any of the given DRS URIs require Google requester pays access
43+
44+
:raises: RequesterPaysNotSupported
45+
"""
46+
for drs_url in drs_urls:
47+
# Currently (1/2023), Gen3-hosted AnVIL data in GCS is the only
48+
# DRS data requiring requester pays access.
49+
# Even this will end in when the AnVIL data is hosted in TDR.
50+
# Note: If the Gen3 AnVIL DRS URI format is retained by TDR, this function must be updated.
51+
ANVIL_DRS_URI_PREFIX = "drs://dg.ANV0"
52+
if drs_url.strip().startswith(ANVIL_DRS_URI_PREFIX):
53+
if get_execution_context().execution_platform == ExecutionPlatform.AZURE:
54+
raise RequesterPaysNotSupported(
55+
f"Requester pays data access is not supported on the Azure platform. Cannot access: {drs_url}"
56+
)
57+
else:
58+
return True
59+
return False
60+
61+
3662
@lru_cache()
3763
def enable_requester_pays(workspace_name: Optional[str]=WORKSPACE_NAME,
3864
workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE):
65+
66+
assert get_execution_context().execution_platform != ExecutionPlatform.AZURE, \
67+
"Requester pays data access is not supported on the Terra Azure platform."
68+
3969
if not workspace_name:
4070
raise RuntimeError('Workspace name is not set. Please set the environment variable '
4171
'WORKSPACE_NAME with the name of a valid Terra Workspace.')
@@ -92,11 +122,14 @@ def access(drs_url: str,
92122
workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE,
93123
billing_project: Optional[str]=WORKSPACE_GOOGLE_PROJECT) -> str:
94124
"""Return a signed url for a drs:// URI, if available."""
95-
# We enable requester pays by specifying the workspace/namespace combo, not
96-
# with the billing project. Rawls then enables requester pays for the attached
97-
# project, but this won't work if a user specifies a project unattached to
98-
# the Terra workspace.
99-
enable_requester_pays(workspace_name, workspace_namespace)
125+
126+
if is_requester_pays([drs_url]):
127+
# We enable requester pays by specifying the workspace/namespace combo, not
128+
# with the billing project. Rawls then enables requester pays for the attached
129+
# project, but this won't work if a user specifies a project unattached to
130+
# the Terra workspace.
131+
enable_requester_pays(workspace_name, workspace_namespace)
132+
100133
info = get_drs_info(drs_url, access_url=True)
101134

102135
if info.access_url:
@@ -232,7 +265,9 @@ def head(drs_url: str,
232265
workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE,
233266
billing_project: Optional[str]=WORKSPACE_GOOGLE_PROJECT):
234267
"""Head a DRS object by byte."""
235-
enable_requester_pays(workspace_name, workspace_namespace)
268+
269+
if is_requester_pays([drs_url]):
270+
enable_requester_pays(workspace_name, workspace_namespace)
236271
try:
237272
blob = get_drs_blob(drs_url, billing_project)
238273
with blob.open(chunk_size=num_bytes) as fh:
@@ -295,7 +330,8 @@ def copy(drs_uri: str,
295330
"""Copy a DRS object to either the local filesystem, or to a Google Storage location if `dst` starts with
296331
"gs://".
297332
"""
298-
enable_requester_pays(workspace_name, workspace_namespace)
333+
if is_requester_pays([drs_uri]):
334+
enable_requester_pays(workspace_name, workspace_namespace)
299335
with DRSCopyClient(raise_on_error=True, indicator_type=indicator_type) as cc:
300336
cc.workspace = workspace_name
301337
cc.workspace_namespace = workspace_namespace
@@ -340,7 +376,8 @@ def copy_batch_urls(drs_urls: Iterable[str],
340376
indicator_type: Indicator = Indicator.notebook_bar if is_notebook() else Indicator.log,
341377
workspace_name: Optional[str] = WORKSPACE_NAME,
342378
workspace_namespace: Optional[str] = WORKSPACE_NAMESPACE):
343-
enable_requester_pays(workspace_name, workspace_namespace)
379+
if is_requester_pays(drs_urls):
380+
enable_requester_pays(workspace_name, workspace_namespace)
344381
with DRSCopyClient(indicator_type=indicator_type) as cc:
345382
cc.workspace = workspace_name
346383
cc.workspace_namespace = workspace_namespace
@@ -365,7 +402,11 @@ def copy_batch_manifest(manifest: List[Dict[str, str]],
365402
workspace_namespace: Optional[str]=WORKSPACE_NAMESPACE):
366403
from jsonschema import validate
367404
validate(instance=manifest, schema=manifest_schema)
368-
enable_requester_pays(workspace_name, workspace_namespace)
405+
406+
drs_uri_list = [item['drs_uri'] for item in manifest]
407+
if is_requester_pays(drs_uri_list):
408+
enable_requester_pays(workspace_name, workspace_namespace)
409+
369410
with DRSCopyClient(indicator_type=indicator_type) as cc:
370411
cc.workspace = workspace_name
371412
cc.workspace_namespace = workspace_namespace
@@ -381,7 +422,9 @@ def extract_tar_gz(drs_url: str,
381422
Default extraction is to the bucket for 'workspace'.
382423
"""
383424
dst = dst or f"gs://{workspace.get_workspace_bucket(workspace_name)}"
384-
enable_requester_pays(workspace_name, workspace_namespace)
425+
426+
if is_requester_pays([drs_url]):
427+
enable_requester_pays(workspace_name, workspace_namespace)
385428
blob = get_drs_blob(drs_url, billing_project)
386429
with blob.open() as fh:
387430
tar_gz.extract(fh, dst)

0 commit comments

Comments
 (0)