Skip to content

Commit 6fd2003

Browse files
committed
adding notebook and APT installation file
1 parent 6cbdbc2 commit 6fd2003

File tree

2 files changed

+179
-0
lines changed

2 files changed

+179
-0
lines changed

apt.txt

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
unzip

notebooks/demo.ipynb

+178
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Overview\n",
8+
"The default setup downloads and decompresses the model \n",
9+
"```bash\n",
10+
"curl -L \"https://www.dropbox.com/s/dlmpr7wabehq3x0/models.zip?dl=1\" > models.zip\n",
11+
"unzip models.zip\n",
12+
"```"
13+
]
14+
},
15+
{
16+
"cell_type": "code",
17+
"execution_count": 1,
18+
"metadata": {},
19+
"outputs": [],
20+
"source": [
21+
"import os\n",
22+
"import sys\n",
23+
"sys.path.append('../')"
24+
]
25+
},
26+
{
27+
"cell_type": "code",
28+
"execution_count": 3,
29+
"metadata": {},
30+
"outputs": [],
31+
"source": [
32+
"import torch\n",
33+
"from torch.autograd import Variable\n",
34+
"import torchvision.transforms as transforms\n",
35+
"import torchvision.utils as utils\n",
36+
"import argparse\n",
37+
"import time\n",
38+
"import numpy as np\n",
39+
"import cv2\n",
40+
"from PIL import Image\n",
41+
"from photo_wct import PhotoWCT\n",
42+
"from photo_smooth import Propagator\n",
43+
"#from smooth_filter import smooth_filter"
44+
]
45+
},
46+
{
47+
"cell_type": "code",
48+
"execution_count": 10,
49+
"metadata": {},
50+
"outputs": [],
51+
"source": [
52+
"from collections import namedtuple\n",
53+
"# prepare paths for model\n",
54+
"vgg_paths = {}\n",
55+
"for i in range(1,6):\n",
56+
" vgg_paths['vgg{}'.format(i)] = '../models/vgg_normalised_conv{}_1_mask.t7'.format(i) \n",
57+
" vgg_paths['decoder{}'.format(i)] = '../models/feature_invertor_conv{}_1_mask.t7'.format(i)\n",
58+
"\n",
59+
"vgg_key_list = list(vgg_paths.keys())\n",
60+
"vgg_arg_class = namedtuple('VggArgs', vgg_key_list)\n",
61+
"vgg_args = vgg_arg_class(*[vgg_paths[k] for k in vgg_key_list])"
62+
]
63+
},
64+
{
65+
"cell_type": "code",
66+
"execution_count": 11,
67+
"metadata": {},
68+
"outputs": [
69+
{
70+
"ename": "IOError",
71+
"evalue": "[Errno 2] No such file or directory: '../models/vgg_normalised_conv1_1_mask.t7'",
72+
"output_type": "error",
73+
"traceback": [
74+
"\u001b[0;31m\u001b[0m",
75+
"\u001b[0;31mIOError\u001b[0mTraceback (most recent call last)",
76+
"\u001b[0;32m<ipython-input-11-aed45ba8830c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# Load model\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mp_wct\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mPhotoWCT\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvgg_args\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0mp_pro\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mPropagator\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
77+
"\u001b[0;32m/home/jovyan/photo_wct.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, args)\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mPhotoWCT\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;31m# TODO: convert these torch models to pytorch models.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0mvgg1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_lua\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvgg1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 17\u001b[0m \u001b[0mdecoder1_torch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_lua\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecoder1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0mvgg2\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mload_lua\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvgg2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
78+
"\u001b[0;32m/srv/conda/envs/kernel/lib/python2.7/site-packages/torch/utils/serialization/read_lua_file.pyc\u001b[0m in \u001b[0;36mload_lua\u001b[0;34m(filename, **kwargs)\u001b[0m\n\u001b[1;32m 595\u001b[0m \u001b[0mto\u001b[0m \u001b[0;34m`\u001b[0m\u001b[0mT7Reader\u001b[0m\u001b[0;34m`\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 596\u001b[0m \"\"\"\n\u001b[0;32m--> 597\u001b[0;31m \u001b[0;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfilename\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'rb'\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 598\u001b[0m \u001b[0mreader\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mT7Reader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 599\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mreader\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
79+
"\u001b[0;31mIOError\u001b[0m: [Errno 2] No such file or directory: '../models/vgg_normalised_conv1_1_mask.t7'"
80+
]
81+
}
82+
],
83+
"source": [
84+
"# Load model\n",
85+
"p_wct = PhotoWCT(vgg_args)\n",
86+
"p_pro = Propagator()"
87+
]
88+
},
89+
{
90+
"cell_type": "markdown",
91+
"metadata": {},
92+
"source": [
93+
"# Style and Content Images\n",
94+
"## Content\n",
95+
"![](../images/content1.png)\n",
96+
"## Style\n",
97+
"![](../images/style1.png)"
98+
]
99+
},
100+
{
101+
"cell_type": "code",
102+
"execution_count": 14,
103+
"metadata": {},
104+
"outputs": [],
105+
"source": [
106+
"content_image_path = \"../images/content1.png\"\n",
107+
"content_seg_path = []\n",
108+
"style_image_path = \"../images/style1.png\"\n",
109+
"style_seg_path = []\n",
110+
"output_image_path = \"../results/example1.png\""
111+
]
112+
},
113+
{
114+
"cell_type": "code",
115+
"execution_count": 15,
116+
"metadata": {},
117+
"outputs": [
118+
{
119+
"ename": "NameError",
120+
"evalue": "name 'p_wct' is not defined",
121+
"output_type": "error",
122+
"traceback": [
123+
"\u001b[0;31m\u001b[0m",
124+
"\u001b[0;31mNameError\u001b[0mTraceback (most recent call last)",
125+
"\u001b[0;32m<ipython-input-15-8e5a025d7c26>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mp_wct\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;31m# Load image\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mcont_img\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mImage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontent_image_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'RGB'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mstyl_img\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mImage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstyle_image_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconvert\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'RGB'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
126+
"\u001b[0;31mNameError\u001b[0m: name 'p_wct' is not defined"
127+
]
128+
}
129+
],
130+
"source": [
131+
"p_wct.cuda(0)\n",
132+
"\n",
133+
"# Load image\n",
134+
"cont_img = Image.open(content_image_path).convert('RGB')\n",
135+
"styl_img = Image.open(style_image_path).convert('RGB')\n",
136+
"try:\n",
137+
" cont_seg = Image.open(content_seg_path)\n",
138+
" styl_seg = Image.open(style_seg_path)\n",
139+
"except:\n",
140+
" cont_seg = []\n",
141+
" styl_seg = []\n",
142+
"\n",
143+
"cont_img = transforms.ToTensor()(cont_img).unsqueeze(0)\n",
144+
"styl_img = transforms.ToTensor()(styl_img).unsqueeze(0)\n",
145+
"cont_img = Variable(cont_img.cuda(0), volatile=True)\n",
146+
"styl_img = Variable(styl_img.cuda(0), volatile=True)"
147+
]
148+
},
149+
{
150+
"cell_type": "code",
151+
"execution_count": null,
152+
"metadata": {},
153+
"outputs": [],
154+
"source": []
155+
}
156+
],
157+
"metadata": {
158+
"kernelspec": {
159+
"display_name": "Python 2",
160+
"language": "python",
161+
"name": "python2"
162+
},
163+
"language_info": {
164+
"codemirror_mode": {
165+
"name": "ipython",
166+
"version": 2
167+
},
168+
"file_extension": ".py",
169+
"mimetype": "text/x-python",
170+
"name": "python",
171+
"nbconvert_exporter": "python",
172+
"pygments_lexer": "ipython2",
173+
"version": "2.7.14"
174+
}
175+
},
176+
"nbformat": 4,
177+
"nbformat_minor": 2
178+
}

0 commit comments

Comments
 (0)