Skip to content

Commit 44bf250

Browse files
author
Clara Brasseur
committed
additional caching functionality
1 parent 9f3ed5b commit 44bf250

File tree

3 files changed

+36
-23
lines changed

3 files changed

+36
-23
lines changed

astroquery/__init__.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from ._astropy_init import *
1313
# ----------------------------------------------------------------------------
1414

15-
1615
import os
1716
from astropy import config as _config
1817

@@ -34,9 +33,16 @@ def _get_bibtex():
3433
class Conf(_config.ConfigNamespace):
3534

3635
default_cache_timeout = _config.ConfigItem(
37-
60.0*60.0*24.0,
36+
86400, # 24 hours
3837
'Astroquery-wide default cache timeout (seconds).'
3938
)
40-
39+
cache_location = _config.ConfigItem(
40+
os.path.join(_config.paths.get_cache_dir(), 'astroquery'),
41+
'Astroquery default cache location (within astropy cache).'
42+
)
43+
use_cache = _config.ConfigItem(
44+
True,
45+
"Astroquery global cache usage, False turns off all caching."
46+
)
4147

4248
conf = Conf()

astroquery/query.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,12 @@ def request_file(self, cache_location):
107107
def from_cache(self, cache_location):
108108
request_file = self.request_file(cache_location)
109109
try:
110-
current_time = datetime.utcnow()
111-
cache_time = datetime.utcfromtimestamp(os.path.getmtime(request_file))
112-
expired = ((current_time-cache_time) > timedelta(seconds=conf.default_cache_timeout))
110+
if conf.default_cache_timeout is None:
111+
expired = False
112+
else:
113+
current_time = datetime.utcnow()
114+
cache_time = datetime.utcfromtimestamp(os.path.getmtime(request_file))
115+
expired = ((current_time-cache_time) > timedelta(seconds=conf.default_cache_timeout))
113116
if not expired:
114117
with open(request_file, "rb") as f:
115118
response = pickle.load(f)
@@ -166,20 +169,15 @@ def __init__(self):
166169
'astroquery/{vers} {olduseragent}'
167170
.format(vers=version.version,
168171
olduseragent=S.headers['User-Agent']))
169-
170-
self.cache_location = os.path.join(
171-
paths.get_cache_dir(), 'astroquery',
172-
self.__class__.__name__.split("Class")[0])
173-
if not os.path.exists(self.cache_location):
174-
os.makedirs(self.cache_location)
175-
self._cache_active = True
172+
self.name = self.__class__.__name__.split("Class")[0]
173+
self._cache_active = conf.use_cache
176174

177175
def __call__(self, *args, **kwargs):
178176
""" init a fresh copy of self """
179177
return self.__class__(*args, **kwargs)
180178

181179
def _request(self, method, url, params=None, data=None, headers=None,
182-
files=None, save=False, savedir='', timeout=None, cache=True,
180+
files=None, save=False, savedir='', timeout=None, cache=None,
183181
stream=False, auth=None, continuation=True, verify=True):
184182
"""
185183
A generic HTTP request method, similar to `requests.Session.request`
@@ -210,6 +208,7 @@ def _request(self, method, url, params=None, data=None, headers=None,
210208
somewhere other than `BaseQuery.cache_location`
211209
timeout : int
212210
cache : bool
211+
Override global cache settings.
213212
verify : bool
214213
Verify the server's TLS certificate?
215214
(see http://docs.python-requests.org/en/master/_modules/requests/sessions/?highlight=verify)
@@ -234,32 +233,40 @@ def _request(self, method, url, params=None, data=None, headers=None,
234233
files=files,
235234
timeout=timeout
236235
)
236+
237+
# Set up cache
238+
if (cache is True) or ((cache is not False) and conf.use_cache):
239+
cache_location = os.path.join(conf.cache_location, self.name)
240+
cache = True
241+
else:
242+
cache_location = None
243+
cache = False
244+
237245
if save:
238246
local_filename = url.split('/')[-1]
239247
if os.name == 'nt':
240248
# Windows doesn't allow special characters in filenames like
241249
# ":" so replace them with an underscore
242250
local_filename = local_filename.replace(':', '_')
243-
local_filepath = os.path.join(self.cache_location or savedir or '.', local_filename)
251+
local_filepath = os.path.join(savedir or cache_location or '.', local_filename)
244252
self._download_file(url, local_filepath, cache=cache,
245253
continuation=continuation, method=method,
246254
auth=auth, **req_kwargs)
247255
return local_filepath
248256
else:
249257
query = AstroQuery(method, url, **req_kwargs)
250-
if ((self.cache_location is None) or (not self._cache_active) or (not cache)):
251-
with suspend_cache(self):
252-
response = query.request(self._session, stream=stream,
253-
auth=auth, verify=verify)
258+
if ((cache_location is None) or (not cache)):
259+
response = query.request(self._session, stream=stream,
260+
auth=auth, verify=verify)
254261
else:
255-
response = query.from_cache(self.cache_location)
262+
response = query.from_cache(cache_location)
256263
if not response:
257264
response = query.request(self._session,
258-
self.cache_location,
265+
cache_location,
259266
stream=stream,
260267
auth=auth,
261268
verify=verify)
262-
to_cache(response, query.request_file(self.cache_location))
269+
to_cache(response, query.request_file(cache_location))
263270
self._last_query = query
264271
return response
265272

@@ -281,6 +288,7 @@ def _download_file(self, url, local_filepath, timeout=None, auth=None,
281288
supports HTTP "range" requests, the download will be continued
282289
where it left off.
283290
cache : bool
291+
Cache downloaded file. Defaults to False.
284292
method : "GET" or "POST"
285293
head_safe : bool
286294
"""

astroquery/setup_package.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,3 @@
44

55
def get_package_data():
66
return {'astroquery': ['astroquery.cfg', 'CITATION']}
7-

0 commit comments

Comments
 (0)