Skip to content

Commit 6606873

Browse files
committed
add automatic host guess in download_data
1 parent 61f78c2 commit 6606873

File tree

4 files changed

+57
-2
lines changed

4 files changed

+57
-2
lines changed

CHANGES.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ heasarc
4141
- Add ``query_by_column`` to allow querying of different catalog columns. [#3403]
4242
- Add support for uploading tables when using TAP directly through ``query_tap``. [#3403]
4343
- Improve how maxrec works. If it is bigger than the default server limit, add a TOP statement. [#3403]
44+
- Add automatic guessing for the data host in ``download_data``. [#3403]
4445

4546
alma
4647
^^^^

astroquery/heasarc/core.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,36 @@ def enable_cloud(self, provider='aws', profile=None):
749749

750750
self.s3_client = self.s3_resource.meta.client
751751

752+
def _guess_host(self, host):
753+
"""Guess the host to use for downloading data
754+
755+
Parameters
756+
----------
757+
host : str
758+
The host provided by the user
759+
760+
Returns
761+
-------
762+
host : str
763+
The guessed host
764+
765+
"""
766+
if host in ['heasarc', 'sciserver', 'aws']:
767+
return host
768+
elif host is not None:
769+
raise ValueError(
770+
'host has to be one of heasarc, sciserver, aws or None')
771+
772+
# host is None, so we guess
773+
if os.environ['HOME'] == '/home/idies' and os.path.exists('/FTP/'):
774+
# we are on idies, so we can use sciserver
775+
return 'sciserver'
776+
777+
for var in ['AWS_REGION', 'AWS_DEFAULT_REGION', 'AWS_ROLE_ARN']:
778+
if var in os.environ:
779+
return 'aws'
780+
return 'heasarc'
781+
752782
def download_data(self, links, host='heasarc', location='.'):
753783
"""Download data products in links with a choice of getting the
754784
data from either the heasarc server, sciserver, or the cloud in AWS.
@@ -780,8 +810,8 @@ def download_data(self, links, host='heasarc', location='.'):
780810
if isinstance(links, Row):
781811
links = links.table[[links.index]]
782812

783-
if host not in ['heasarc', 'sciserver', 'aws']:
784-
raise ValueError('host has to be one of heasarc, sciserver, aws')
813+
# guess the host if not provided
814+
host = self._guess_host(host)
785815

786816
host_column = 'access_url' if host == 'heasarc' else host
787817
if host_column not in links.colnames:

astroquery/heasarc/tests/test_heasarc.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,29 @@ def test_locate_data_row():
506506
Heasarc.locate_data(table[0:2], catalog_name="xray")
507507

508508

509+
def test__guess_host_default():
510+
# Use a new HeasarcClass object
511+
assert Heasarc._guess_host(host=None) == 'heasarc'
512+
513+
514+
@pytest.mark.parametrize("host", ["heasarc", "sciserver", "aws"])
515+
def test__guess_host_know(host):
516+
# Use a new HeasarcClass object
517+
assert Heasarc._guess_host(host=host) == host
518+
519+
520+
def test__guess_host_sciserver(monkeypatch):
521+
monkeypatch.setenv("HOME", "/home/idies")
522+
monkeypatch.setattr("os.path.exists", lambda path: path.startswith('/FTP'))
523+
assert Heasarc._guess_host(host=None) == 'sciserver'
524+
525+
526+
@pytest.mark.parametrize("var", ["AWS_REGION", "AWS_REGION_DEFAULT", "AWS_ROLE_ARN"])
527+
def test__guess_host_aws(monkeypatch, var):
528+
monkeypatch.setenv("AWS_REGION", var)
529+
assert Heasarc._guess_host(host=None) == 'aws'
530+
531+
509532
def test_download_data__empty():
510533
with pytest.raises(ValueError, match="Input links table is empty"):
511534
Heasarc.download_data(Table())

docs/heasarc/heasarc.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ You can specify where the data are to be downloaded using the ``location`` param
247247

248248
To download the data, you can pass ``links`` table (or row) to `~astroquery.heasarc.HeasarcClass.download_data`,
249249
specifying from where you want the data to be fetched by specifying the ``host`` parameter. By default,
250+
the function will try to guess the best host based on your environment. If it cannot guess, then
250251
the data is fetched from the main HEASARC servers.
251252
The recommendation is to use different hosts depending on where your code is running:
252253
* ``host='sciserver'``: Use this option if you running you analysis on Sciserver. Because

0 commit comments

Comments
 (0)