Commit 8ded673

vqgan-jax
0 parents  commit 8ded673

4 files changed: +913 -0 lines changed

.gitignore

+155
@@ -0,0 +1,155 @@
# Initially taken from Github's Python gitignore file

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# tests and logs
tests/fixtures/cached_*_text.txt
logs/
lightning_logs/
lang_code_data/

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# vscode
.vs
.vscode

# Pycharm
.idea

# TF code
tensorflow_code

# Models
proc_data

# data
/data
serialization_dir

# emacs
*.*~
debug.env

# vim
.*.swp

#ctags
tags

# pre-commit
.pre-commit*

# .lock
*.lock

configuration_vqgan.py

+40
@@ -0,0 +1,40 @@
from typing import Tuple

from transformers import PretrainedConfig


class VQGANConfig(PretrainedConfig):
    def __init__(
        self,
        ch: int = 128,
        out_ch: int = 3,
        in_channels: int = 3,
        num_res_blocks: int = 2,
        resolution: int = 256,
        z_channels: int = 256,
        ch_mult: Tuple = (1, 1, 2, 2, 4),
        attn_resolutions: Tuple = (16,),
        n_embed: int = 1024,
        embed_dim: int = 256,
        dropout: float = 0.0,
        double_z: bool = False,
        resamp_with_conv: bool = True,
        give_pre_end: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.ch = ch
        self.out_ch = out_ch
        self.in_channels = in_channels
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.z_channels = z_channels
        self.ch_mult = list(ch_mult)
        self.attn_resolutions = list(attn_resolutions)
        self.n_embed = n_embed
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.double_z = double_z
        self.resamp_with_conv = resamp_with_conv
        self.give_pre_end = give_pre_end
        self.num_resolutions = len(ch_mult)
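
For context, a minimal sketch of how this config class would typically be used; the directory name below is a placeholder, not part of this commit:

from configuration_vqgan import VQGANConfig

# Build a config with the defaults defined above and persist it to disk;
# the conversion script below reloads it via VQGANConfig.from_pretrained.
config = VQGANConfig()
config.save_pretrained("./vqgan-config")  # placeholder path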

convert_pt_model_to_jax.py

+109
@@ -0,0 +1,109 @@
import re

import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict

import torch

from modeling_flax_vqgan import VQModel
from configuration_vqgan import VQGANConfig


regex = r"\w+[.]\d+"


def rename_key(key):
    # Rewrite "<name>.<digits>" segments as "<name>_<digits>" (e.g. "up.0" -> "up_0")
    # so the PyTorch keys line up with the Flax parameter tree naming.
    pats = re.findall(regex, key)
    for pat in pats:
        key = key.replace(pat, "_".join(pat.split(".")))
    return key


# Adapted from https://github.com/huggingface/transformers/blob/ff5cdc086be1e0c3e2bbad8e3469b34cffb55a85/src/transformers/modeling_flax_pytorch_utils.py#L61
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
    # convert pytorch tensor to numpy
    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

    random_flax_state_dict = flatten_dict(flax_model.params)
    flax_state_dict = {}

    remove_base_model_prefix = (flax_model.base_model_prefix not in flax_model.params) and (
        flax_model.base_model_prefix in set([k.split(".")[0] for k in pt_state_dict.keys()])
    )
    add_base_model_prefix = (flax_model.base_model_prefix in flax_model.params) and (
        flax_model.base_model_prefix not in set([k.split(".")[0] for k in pt_state_dict.keys()])
    )

    # Need to change some parameters name to match Flax names so that we don't have to fork any layer
    for pt_key, pt_tensor in pt_state_dict.items():
        pt_tuple_key = tuple(pt_key.split("."))

        has_base_model_prefix = pt_tuple_key[0] == flax_model.base_model_prefix
        require_base_model_prefix = (flax_model.base_model_prefix,) + pt_tuple_key in random_flax_state_dict

        if remove_base_model_prefix and has_base_model_prefix:
            pt_tuple_key = pt_tuple_key[1:]
        elif add_base_model_prefix and require_base_model_prefix:
            pt_tuple_key = (flax_model.base_model_prefix,) + pt_tuple_key

        # Correctly rename weight parameters
        if (
            "norm" in pt_key
            and (pt_tuple_key[-1] == "bias")
            and (pt_tuple_key[:-1] + ("bias",) in random_flax_state_dict)
        ):
            pt_tensor = pt_tensor[None, None, None, :]
        elif (
            "norm" in pt_key
            and (pt_tuple_key[-1] == "bias")
            and (pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict)
        ):
            pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
            pt_tensor = pt_tensor[None, None, None, :]
        elif pt_tuple_key[-1] in ["weight", "gamma"] and pt_tuple_key[:-1] + ("scale",) in random_flax_state_dict:
            pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
            pt_tensor = pt_tensor[None, None, None, :]
        if pt_tuple_key[-1] == "weight" and pt_tuple_key[:-1] + ("embedding",) in random_flax_state_dict:
            pt_tuple_key = pt_tuple_key[:-1] + ("embedding",)
        elif pt_tuple_key[-1] == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
            # conv layer: PyTorch (out, in, kh, kw) -> Flax (kh, kw, in, out)
            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
            pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
        elif pt_tuple_key[-1] == "weight" and pt_tuple_key not in random_flax_state_dict:
            # linear layer: transpose (out, in) -> (in, out)
            pt_tuple_key = pt_tuple_key[:-1] + ("kernel",)
            pt_tensor = pt_tensor.T
        elif pt_tuple_key[-1] == "gamma":
            pt_tuple_key = pt_tuple_key[:-1] + ("weight",)
        elif pt_tuple_key[-1] == "beta":
            pt_tuple_key = pt_tuple_key[:-1] + ("bias",)

        if pt_tuple_key in random_flax_state_dict:
            if pt_tensor.shape != random_flax_state_dict[pt_tuple_key].shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
                )

        # also add unexpected weight so that warning is thrown
        flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)

    return unflatten_dict(flax_state_dict)


def convert_model(config_path, pt_state_dict_path, save_path):
    config = VQGANConfig.from_pretrained(config_path)
    model = VQModel(config)

    state_dict = torch.load(pt_state_dict_path, map_location="cpu")["state_dict"]
    keys = list(state_dict.keys())
    for key in keys:
        # drop loss-related (e.g. discriminator) weights, which are not part of the Flax model
        if key.startswith("loss"):
            state_dict.pop(key)
            continue
        renamed_key = rename_key(key)
        state_dict[renamed_key] = state_dict.pop(key)

    state = convert_pytorch_state_dict_to_flax(state_dict, model)
    # convert_pytorch_state_dict_to_flax already returns a nested (unflattened)
    # dict, so assign it directly instead of unflattening it a second time
    model.params = state
    model.save_pretrained(save_path)
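
For reference, a hypothetical invocation of the conversion entry point; all paths are placeholders, and the checkpoint is assumed to be a Lightning-style .ckpt whose weights live under the "state_dict" key, as the loading code above expects:

from convert_pt_model_to_jax import convert_model

convert_model(
    config_path="./vqgan-config",                   # readable by VQGANConfig.from_pretrained
    pt_state_dict_path="./checkpoints/last.ckpt",   # must contain a "state_dict" entry
    save_path="./vqgan-jax-model",                  # Flax weights are written here
)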
