-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodel_downloader.py
94 lines (76 loc) · 3.3 KB
/
model_downloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
"""
Downloads the models required by nasbench so they can be used by example.py.
The downloaded zip archive is deleted automatically after it has been extracted.
Note: 'string {}'.format(arg) is used instead of f-strings to stay compatible with older Python versions.
"""
import os
import sys
from zipfile import ZipFile
import requests
from tqdm import tqdm
# Download URL of the models archive. This is the only archive URL defined in
# this file; download_models() always fetches it.
# Note: Update if the download location changes.
# Previous download location, kept for reference only:
# URL_MODELS_0_9 = 'https://ndownloader.figshare.com/files/21981038'
URL_MODELS_0_9 = 'https://ndownloader.figshare.com/files/40109821'
def download(url, path, block_size=1024, timeout=60):
    """Stream the resource at *url* into the file *path* with a progress bar.

    Adapted from: https://stackoverflow.com/a/37573701

    Args:
        url: HTTP(S) URL to fetch.
        path: Local file to write; overwritten if it already exists.
        block_size: Number of bytes requested per chunk while streaming.
        timeout: Seconds passed to ``requests.get`` as the connect/read
            timeout, so a stalled server cannot hang the script forever.

    Prints an error and exits the process with status 1 if the server
    answers with an HTTP error status, or if fewer/more bytes arrive than the
    server advertised (in which case the partial file is removed).
    """
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Don't save an HTML error page as if it were the requested file.
        if not response.ok:
            print('Error downloading from {} (HTTP {})'.format(
                url, response.status_code))
            sys.exit(1)
        # Measured in bytes; 0 means the server did not report a size.
        # The header is 'Content-Length': requests' header lookup ignores
        # case but not '-' vs '_', so 'content_length' would never match.
        file_size = int(response.headers.get('content-length', 0))
        # iter_content() yields *decoded* bytes while Content-Length counts
        # encoded ones, so the two are only comparable without an encoding.
        if response.headers.get('content-encoding', 'identity') != 'identity':
            file_size = 0
        # total=None tells tqdm the size is unknown.
        with tqdm(total=file_size or None, unit='iB',
                  unit_scale=True) as progress_bar:
            # Write the download to file
            with open(path, 'wb') as file:
                for data in response.iter_content(block_size):
                    progress_bar.update(len(data))
                    file.write(data)
            received = progress_bar.n
    # Only verify the size when the server told us what to expect.
    if file_size not in (0, received):
        print('Error downloading from {}'.format(url))
        os.remove(path)  # don't leave a truncated file behind
        sys.exit(1)
def download_models(version, delete_zip=True, download_dir=None):
    """Download the models archive and extract it into *download_dir*.

    Does nothing (besides printing a message) if the extracted models folder
    already exists there.

    Args:
        version: Version label used in the zip file name and log messages.
            NOTE: the archive URL and the models folder name are fixed to the
            0.9 release (``URL_MODELS_0_9`` / ``anb_models_0_9``), so the
            same models are fetched whatever value is passed.
        delete_zip: Remove the downloaded zip after extraction. If this
            library is used by another library, leftover zips would take up
            space in the virtual env where the user is unlikely to notice
            them.
        download_dir: Directory to download and extract into. Defaults to
            the current working directory at the time of the call.
    """
    # Resolve the default at call time: ``download_dir=os.getcwd()`` as a
    # default argument would be frozen to the cwd at import time.
    if download_dir is None:
        download_dir = os.getcwd()

    # Create paths and names
    zip_filename = 'models_{}.zip'.format(version)
    models_folder = 'anb_models_0_9'  # 'anb_models_{}'.format(version)
    zip_path = os.path.join(download_dir, zip_filename)
    models_dir = os.path.join(download_dir, models_folder)

    # Check if already exists
    if os.path.exists(models_dir):
        print('Models {} already at {}'.format(version, models_dir))
        return

    download_url = URL_MODELS_0_9
    print('Downloading models {} from {} to {}'.format(version,
                                                       download_url,
                                                       zip_path))
    download(download_url, zip_path)

    # The code assumes the archive's top-level folder is 'anb_models_0_9':
    # extract into download_dir, then move it to models_dir only if the two
    # names differ (renaming a folder onto itself is pointless).
    print('Extracting {} to {}'.format(zip_filename, models_dir))
    with ZipFile(zip_path, 'r') as archive:
        archive.extractall(download_dir)
    unzipped_folder = os.path.join(download_dir, 'anb_models_0_9')
    if unzipped_folder != models_dir:
        os.rename(unzipped_folder, models_dir)

    # Finally, remove the zip
    if delete_zip:
        print('Deleting downloaded zip at {}'.format(zip_path))
        os.remove(zip_path)
if __name__ == "__main__":
    # Parse args. Usage: python model_downloader.py [0.9 | 1.0]
    # Note: Would probably be easier to use a lib (e.g. argparse) for this.
    # The download directory is also not configurable from the command line.
    usage = 'Usage: python {} {}'.format(sys.argv[0], '[0.9 | 1.0]')
    version = '0.9'  # default when no version argument is given
    if len(sys.argv) == 2:
        # Validate the user-supplied argument, not the default value.
        if sys.argv[1] not in ('0.9', '1.0'):
            print(usage)
            sys.exit(1)
        version = sys.argv[1]
    elif len(sys.argv) > 2:
        print(usage)
        sys.exit(1)
    download_models(version, delete_zip=True)