7 | 7 | import os |
8 | 8 | import requests |
9 | 9 | import warnings |
10 | | - |
11 | 10 | import yaml |
| 11 | +from typing import Iterable, List, Dict, Any, Optional |
| 12 | + |
| 13 | + |
| 14 | + |
| 15 | +def _gather_files_from_response(resp: requests.Response) -> List[Dict[str, Any]]: |
| 16 | + """ |
| 17 | + Normalize Figshare API responses into a list of file dicts. |
| 18 | +
| 19 | + Supports: |
| 20 | + 1) Article endpoint: https://api.figshare.com/v2/articles/{id} |
| 21 | + -> JSON object with key 'files' (list) |
| 22 | +
| 23 | + 2) Files endpoint: https://api.figshare.com/v2/articles/{id}/files[?...] |
| 24 | + -> JSON list of file objects (possibly paginated with Link headers) |
| 25 | + """ |
| 26 | + data = resp.json() |
| 27 | + if isinstance(data, dict) and "files" in data and isinstance(data["files"], list): |
| 28 | + return data["files"] |
| 29 | + if isinstance(data, list): |
| 30 | + return data |
| 31 | + raise ValueError("Unexpected Figshare API response structure; expected dict with 'files' " |
| 32 | + "or a list of file objects.") |
| 33 | + |
| 34 | + |
| 35 | +def _iter_paginated_files(url: str, session: Optional[requests.Session] = None) -> Iterable[Dict[str, Any]]: |
| 36 | + """ |
| 37 | + Iterate over all files, following 'Link: <...>; rel=\"next\"' pagination if present. |
| 38 | + Works for both the article endpoint (no pagination) and the files endpoint (may paginate). |
| 39 | + """ |
| 40 | + sess = session or requests.Session() |
| 41 | + next_url = url |
| 42 | + |
| 43 | + while next_url: |
| 44 | + resp = sess.get(next_url) |
| 45 | + if resp.status_code != 200: |
| 46 | + raise Exception(f"Failed to get dataset details from Figshare: {resp.text}") |
| 47 | + |
| 48 | + for f in _gather_files_from_response(resp): |
| 49 | + yield f |
| 50 | + |
| 51 | + # RFC5988-style 'Link' header pagination |
| 52 | + link = resp.headers.get("Link") or resp.headers.get("link") |
| 53 | + next_url = None |
| 54 | + if link: |
| 55 | + parts = [p.strip() for p in link.split(",")] |
| 56 | + for part in parts: |
| 57 | + if 'rel="next"' in part: |
| 58 | + start = part.find("<") + 1 |
| 59 | + end = part.find(">", start) |
| 60 | + if start > 0 and end > start: |
| 61 | + next_url = part[start:end] |
| 62 | + break |
12 | 63 |
13 | 64 | def download( |
14 | 65 | name: str='all', |
@@ -46,81 +97,73 @@ def download( |
46 | 97 | local_path = Path(local_path) |
47 | 98 |
48 | 99 | if not local_path.exists(): |
49 | | - Path.mkdir(local_path) |
| 100 | + local_path.mkdir(parents=True, exist_ok=True) |
50 | 101 | # Get the dataset details |
51 | 102 | with resources.open_text('coderdata', 'dataset.yml') as f: |
52 | 103 | data_information = yaml.load(f, Loader=yaml.FullLoader) |
53 | 104 | url = data_information['figshare'] |
54 | | - |
55 | | - response = requests.get(url) |
56 | | - if response.status_code != 200: |
57 | | - raise Exception( |
58 | | - f"Failed to get dataset details from Figshare: {response.text}" |
59 | | - ) |
60 | | - |
61 | | - data = response.json() |
62 | 105 |
63 | | - # making sure that we are case insensitive |
64 | | - name = name.casefold() |
| 106 | + name = (name or "all").casefold() |
| 107 | + session = requests.Session() |
| 108 | + all_files = list(_iter_paginated_files(url, session=session)) |
65 | 109 |
|
66 | | - # Filter files by the specified prefix |
67 | 110 | if name != "all": |
68 | 111 | filtered_files = [ |
69 | | - file |
70 | | - for file |
71 | | - in data['files'] |
72 | | - if file['name'].startswith(name) or 'genes' in file['name'] |
73 | | - ] |
| 112 | + f for f in all_files |
| 113 | + if (f.get('name', '').casefold().startswith(name)) or ('genes' in f.get('name', '').casefold()) |
| 114 | + ] |
74 | 115 | else: |
75 | | - filtered_files = data['files'] |
| 116 | + filtered_files = all_files |
76 | 117 |
77 | | - # Group files by name and select the one with the highest ID |
78 | 118 | unique_files = {} |
79 | 119 | for file in filtered_files: |
80 | | - file_name = local_path.joinpath(file['name']) |
81 | | - file_id = file['id'] |
82 | | - if ( |
83 | | - file_name not in unique_files |
84 | | - or file_id > unique_files[file_name]['id'] |
85 | | - ): |
86 | | - unique_files[file_name] = {'file_info': file, 'id': file_id} |
| 120 | + fname = file.get('name') |
| 121 | + fid = file.get('id') |
| 122 | + if fname is None or fid is None: |
| 123 | + continue |
| 124 | + file_name = local_path.joinpath(fname) |
| 125 | + if (file_name not in unique_files) or (fid > unique_files[file_name]['id']): |
| 126 | + unique_files[file_name] = {'file_info': file, 'id': fid} |
87 | 127 |
88 | 128 | for file_name, file_data in unique_files.items(): |
89 | 129 | file_info = file_data['file_info'] |
90 | 130 | file_id = str(file_info['id']) |
91 | | - file_url = "https://api.figshare.com/v2/file/download/" + file_id |
92 | | - file_md5sum = file_info['supplied_md5'] |
| 131 | + file_url = f"https://api.figshare.com/v2/file/download/{file_id}" |
| 132 | + file_md5sum = file_info.get('supplied_md5') |
| 133 | + |
| 134 | + if file_name.exists() and not exist_ok: |
| 135 | + warnings.warn( |
| 136 | + f"{file_name} already exists. Use argument 'exist_ok=True' to overwrite the existing file." |
| 137 | + ) |
| 138 | + continue  # skip download; keep the existing file |
93 | 139 | retry_count = 10 |
94 | | - # Download the file |
95 | 140 | while retry_count > 0: |
96 | | - with requests.get(file_url, stream=True) as r: |
| 141 | + with session.get(file_url, stream=True) as r: |
97 | 142 | r.raise_for_status() |
98 | | - if file_name.exists() and not exist_ok: |
99 | | - warnings.warn( |
100 | | - f"{file_name} already exists. Use argument 'exist_ok=True'" |
101 | | - "to overwrite existing file." |
102 | | - ) |
| 143 | + with open(file_name, 'wb') as f: |
| 144 | + for chunk in r.iter_content(chunk_size=8192): |
| 145 | + f.write(chunk) |
| 146 | + |
| 147 | + if file_md5sum:  # verify only when Figshare supplied an MD5 |
| 148 | + with open(file_name, 'rb') as f: |
| 149 | + check_md5sum = md5(f.read()).hexdigest() |
| 150 | + if file_md5sum == check_md5sum: |
| 151 | + break |
103 | 152 | else: |
104 | | - with open(file_name, 'wb') as f: |
105 | | - for chunk in r.iter_content(chunk_size=8192): |
106 | | - f.write(chunk) |
107 | | - with open(file_name, 'rb') as f: |
108 | | - check_md5sum = md5(f.read()).hexdigest() |
109 | | - if file_md5sum == check_md5sum: |
| 153 | + retry_count -= 1  # checksum mismatch; count this attempt |
| 154 | + if retry_count > 0: |
| 155 | + warnings.warn( |
| 156 | + f"{file_name} failed MD5 verification " |
| 157 | + f"(expected: {file_md5sum}, got: {check_md5sum}). Retrying..." |
| 158 | + ) |
| 159 | + else:  # no supplied MD5 to check against; accept the download |
110 | 160 | break |
111 | | - elif retry_count > 0: |
112 | | - warnings.warn( |
113 | | - f"{file_name} could not be downloaded successfully. " |
114 | | - f"(expected md5sum: {file_md5sum} - " |
115 | | - f"calculated md5sum: {check_md5sum})... retrying..." |
116 | | - ) |
117 | | - retry_count = retry_count - 1 |
118 | | - if retry_count == 0: |
| 161 | + |
| 162 | + if retry_count == 0 and file_md5sum: |
119 | 163 | warnings.warn( |
120 | | - f"{file_name} could not be downloaded. Try again." |
121 | | - ) |
| 164 | + f"{file_name} could not be downloaded with a matching MD5 after retries." |
| 165 | + ) |
122 | 166 | else: |
123 | 167 | print(f"Downloaded '{file_url}' to '{file_name}'") |
124 | 168 |
125 | | - return |
126 | 169 |
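For reference, the article endpoint (`/v2/articles/{id}`) returns a JSON object containing a `files` list, while the files endpoint (`/v2/articles/{id}/files`) returns a bare list of file objects and may paginate via an RFC 5988 `Link` header. The sketch below exercises the same rel="next" extraction in isolation; the payloads and header value are made up for illustration, not real Figshare responses.

```python
# Illustrative only: made-up payloads and header value, not real Figshare data.

# Shape 1: article endpoint (/v2/articles/{id}) -> dict with a 'files' list.
article_json = {"id": 123, "title": "example", "files": [{"id": 1, "name": "a.csv"}]}

# Shape 2: files endpoint (/v2/articles/{id}/files) -> bare list of file objects.
files_json = [{"id": 1, "name": "a.csv"}, {"id": 2, "name": "b.csv"}]

# rel="next" extraction, mirroring the loop in _iter_paginated_files.
link = ('<https://api.figshare.com/v2/articles/123/files?page=2>; rel="next", '
        '<https://api.figshare.com/v2/articles/123/files?page=1>; rel="prev"')

next_url = None
for part in (p.strip() for p in link.split(",")):
    if 'rel="next"' in part:
        start = part.find("<") + 1
        end = part.find(">", start)
        if start > 0 and end > start:
            next_url = part[start:end]
        break

print(next_url)  # -> https://api.figshare.com/v2/articles/123/files?page=2
```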
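For context, a call to the updated function might look like the following. The top-level import path and the dataset prefix are assumptions, since neither appears in this hunk; only the `name`, `local_path`, and `exist_ok` parameters are visible above.

```python
# Hypothetical usage; the import path and the prefix below are assumptions,
# not taken from this diff.
from coderdata import download

# Download files whose names start with the prefix (plus any 'genes' files)
# into ./data, overwriting copies that are already there.
download(name="beataml", local_path="data", exist_ok=True)
```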