Skip to content

Commit f895e62

Browse files
wesleyw72mingyuliutw
authored andcommitted
Scripted pretrained model download (NVIDIA#24)
* Script to download models * Model download bash script, usage explained
1 parent 7508715 commit f895e62

File tree

3 files changed

+43
-0
lines changed

3 files changed

+43
-0
lines changed

USAGE.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,14 @@ These models are extracted from Torch7 models and currently used in the project.
3939

4040
**Original Torch7 models**
4141

42+
Manually download the model files.
4243
- Download pretrained networks via the following [link](https://drive.google.com/open?id=1ENgQm9TgabE1R99zhNf5q6meBvX6WFuq).
4344
- Unzip and store the model files under `models`.
4445

46+
Automatically downloads pretrained networks and unzips them.
47+
- Requires requests (`pip install requests`)
48+
- `bash download_models.sh`
49+
4550
`converter.py` shows how to convert Torch7 models to PyTorch models.
4651

4752
### Example 1: Transfer the style of a style photo to a content photo.

download_models.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Download code taken from Code taken from https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive/39225039#39225039
2+
import requests
3+
4+
def download_file_from_google_drive(id, destination):
5+
URL = "https://docs.google.com/uc?export=download"
6+
7+
session = requests.Session()
8+
9+
response = session.get(URL, params = { 'id' : id }, stream = True)
10+
token = get_confirm_token(response)
11+
12+
if token:
13+
params = { 'id' : id, 'confirm' : token }
14+
response = session.get(URL, params = params, stream = True)
15+
16+
save_response_content(response, destination)
17+
18+
def get_confirm_token(response):
19+
for key, value in response.cookies.items():
20+
if key.startswith('download_warning'):
21+
return value
22+
23+
return None
24+
25+
def save_response_content(response, destination):
26+
CHUNK_SIZE = 32768
27+
28+
with open(destination, "wb") as f:
29+
for chunk in response.iter_content(CHUNK_SIZE):
30+
if chunk: # filter out keep-alive new chunks
31+
f.write(chunk)
32+
33+
file_id = '1ENgQm9TgabE1R99zhNf5q6meBvX6WFuq'
34+
destination = './models.zip'
35+
download_file_from_google_drive(file_id, destination)

download_models.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#!/bin/bash
2+
python download_models.py
3+
unzip models.zip

0 commit comments

Comments
 (0)