Skip to content

Commit 2c52af3

Browse files
committed
lets googit add -A
0 parents  commit 2c52af3

22 files changed

+4040
-0
lines changed

.gitignore

+129
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
build/
12+
develop-eggs/
13+
dist/
14+
downloads/
15+
eggs/
16+
.eggs/
17+
lib/
18+
lib64/
19+
parts/
20+
sdist/
21+
var/
22+
wheels/
23+
pip-wheel-metadata/
24+
share/python-wheels/
25+
*.egg-info/
26+
.installed.cfg
27+
*.egg
28+
MANIFEST
29+
30+
# PyInstaller
31+
# Usually these files are written by a python script from a template
32+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
33+
*.manifest
34+
*.spec
35+
36+
# Installer logs
37+
pip-log.txt
38+
pip-delete-this-directory.txt
39+
40+
# Unit test / coverage reports
41+
htmlcov/
42+
.tox/
43+
.nox/
44+
.coverage
45+
.coverage.*
46+
.cache
47+
nosetests.xml
48+
coverage.xml
49+
*.cover
50+
*.py,cover
51+
.hypothesis/
52+
.pytest_cache/
53+
54+
# Translations
55+
*.mo
56+
*.pot
57+
58+
# Django stuff:
59+
*.log
60+
local_settings.py
61+
db.sqlite3
62+
db.sqlite3-journal
63+
64+
# Flask stuff:
65+
instance/
66+
.webassets-cache
67+
68+
# Scrapy stuff:
69+
.scrapy
70+
71+
# Sphinx documentation
72+
docs/_build/
73+
74+
# PyBuilder
75+
target/
76+
77+
# Jupyter Notebook
78+
.ipynb_checkpoints
79+
80+
# IPython
81+
profile_default/
82+
ipython_config.py
83+
84+
# pyenv
85+
.python-version
86+
87+
# pipenv
88+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
90+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
91+
# install all needed dependencies.
92+
#Pipfile.lock
93+
94+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
95+
__pypackages__/
96+
97+
# Celery stuff
98+
celerybeat-schedule
99+
celerybeat.pid
100+
101+
# SageMath parsed files
102+
*.sage.py
103+
104+
# Environments
105+
.env
106+
.venv
107+
env/
108+
venv/
109+
ENV/
110+
env.bak/
111+
venv.bak/
112+
113+
# Spyder project settings
114+
.spyderproject
115+
.spyproject
116+
117+
# Rope project settings
118+
.ropeproject
119+
120+
# mkdocs documentation
121+
/site
122+
123+
# mypy
124+
.mypy_cache/
125+
.dmypy.json
126+
dmypy.json
127+
128+
# Pyre type checker
129+
.pyre/

Makefile

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
check_dirs := .
2+
3+
quality:
4+
black --check --preview $(check_dirs)
5+
isort --check-only $(check_dirs)
6+
flake8 $(check_dirs)
7+
8+
style:
9+
black --preview $(check_dirs)
10+
isort $(check_dirs)

README.md

+75
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# muse-open-reproduction
2+
A repo to train the best and fastest text2image model!
3+
4+
## Goal
5+
This repo is for reproduction of the [MUSE](https://arxiv.org/abs/2301.00704) model. The goal is to create a simple and scalable repo, to reproduce MUSE and build knowedge about VQ + transformers at scale.
6+
We will use deduped LAION-2B + COYO-700M dataset for training.
7+
8+
Project stages:
9+
1. Setup the codebase and train a class-conditional model on imagenet.
10+
2. Conduct text2image experiments on CC12M.
11+
3. Train improved VQGANs models.
12+
4. Train the full (base-256) model on LAION + COYO.
13+
5. Train the full (base-512) model on LAION + COYO.
14+
15+
16+
## Steps
17+
### Setup the codebase and train a class-conditional model no imagenet.
18+
- [x] Setup repo-structure
19+
- [x] Add transformers and VQGAN model.
20+
- [x] Add a generation support for the model.
21+
- [x] Port the VQGAN from [maskgit](https://github.com/google-research/maskgit) repo for imagenet experiment.
22+
- [ ] Finish and verify masking utils.
23+
- [ ] Add the masking arccos scheduling function from MUSE.
24+
- [x] Add EMA.
25+
- [ ] Suport OmegaConf for training configuration.
26+
- [ ] Add W&B logging utils.
27+
- [ ] Add WebDataset support. Not really needed for imagenet experiment but can work on this parallelly. (LAION is already available in this format so will be easier to use it).
28+
- [ ] Add a training script for class conditional generation using imagenet. (WIP)
29+
- [ ] Make the codebase ready for the cluster training.
30+
31+
### Conduct text2image experiments on CC12M.
32+
- [ ] Finish data loading, pre-processing utils.
33+
- [ ] Add CLIP and T5 support.
34+
- [ ] Add text2image training script.
35+
- [ ] Add eavluation scripts (FiD, CLIP score).
36+
- [ ] Train on CC12M. Here we could conduct different experiments:
37+
- [ ] Train on CC12M with T5 conditioning.
38+
- [ ] Train on CC12M with CLIP conditioning.
39+
- [ ] Train on CC12M with CLIP + T5 conditioning (probably costly during training and experiments).
40+
- [ ] Self conditioning from Bit Diffusion paper.
41+
- [ ] Collect different prompts for intermmediate evaluations (Can reuse the prompts for dalle-mini, parti-prompts).
42+
- [ ] Setup a space where people can play with the model and provide feedback, compare with other models etc.
43+
44+
### Train improved VQGANs models.
45+
- [ ] Add training component models for VQGAN (EMA, discriminator, LPIPS etc).
46+
- [ ] VGQAN training script.
47+
48+
49+
### Misc tasks
50+
- [ ] Create a space for visualizing exploring dataset
51+
- [ ] Create a space where people can try to find their own images and can opt-out of the dataset.
52+
53+
54+
## Repo structure (WIP)
55+
```
56+
├── README.md
57+
├── configs -> All training config files.
58+
│ └── dummy_config.yaml
59+
├── muse
60+
│ ├── __init__.py
61+
│ ├── data.py -> All data related utils. Can create a data folder if needed.
62+
│ ├── logging.py -> Misc logging utils.
63+
│ ├── maskgit_vqgan.py -> VQGAN model from maskgit repo.
64+
│ ├── modeling_utils.py -> All model related utils, like save_pretrained, from_pretrained from hub etc
65+
│ ├── sampling.py -> Sampling/Generation utils.
66+
│ ├── taming_vqgan.py -> VQGAN model from taming repo.
67+
│ ├── training_utils.py -> Common training utils.
68+
│ └── transformer.py -> The main transformer model.
69+
├── pyproject.toml
70+
├── setup.cfg
71+
├── setup.py
72+
├── test.py
73+
└── training -> All training scripts.
74+
├── train_muse.py
75+
└── train_vqgan.py

configs/dummy_config.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# A dir to store training configurations

configs/imagenet_test.yaml

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
experiment:
2+
name: "imagenet"
3+
project: "muse"
4+
output_dir: "imagenet"
5+
max_train_examples: 110000000
6+
num_eval_images: 1000
7+
save_every: 1000
8+
log_every: 50
9+
10+
11+
model:
12+
vq_model:
13+
pretrained: "path to vq model"
14+
15+
transformer:
16+
vocab_size: 2025 # (1024 + 1000 + 1 -> Vq + Imagenet + <mask>)
17+
hidden_size: 64
18+
num_hidden_layers: 2
19+
num_attention_heads: 4
20+
intermediate_size: 256
21+
hidden_dropout: 0.1
22+
attention_dropout: 0.1
23+
max_position_embeddings: 256
24+
initializer_range: 0.02
25+
layer_norm_eps: 1e-6
26+
use_bias: False
27+
28+
gradient_checkpointing: True
29+
30+
31+
dataset:
32+
params:
33+
path: "imagenet-1k-" # path to imagenet dataset
34+
streaming: True
35+
shuffle_buffer_size: 5000
36+
batch_size: ${training.batch_size}
37+
workers: 1
38+
class_mapping: "scripts/metadata/imagenet_idx_to_prompt.json"
39+
resolution: 256
40+
preprocessing:
41+
resolution: 256
42+
center_crop: True
43+
random_flip: True
44+
45+
46+
optimizer:
47+
name: adamw
48+
params:
49+
learning_rate: 0.0001
50+
beta1: 0.9
51+
beta2: 0.98
52+
weight_decay: 0.01
53+
epsilon: 0.00000001
54+
55+
56+
lr_scheduler:
57+
scheduler: "ConstantWithWarmup"
58+
params:
59+
learning_rate: ${optimizer.params.learning_rate}
60+
warmup_steps: 500
61+
62+
63+
training:
64+
gradient_accumulation_steps: 1
65+
batch_size: 16
66+
mixed_precision: bf16
67+
use_ema: False
68+
seed: 42
69+
max_train_steps: 1000
70+
num_epochs: 100

muse/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
__version__ = "0.0.1"
2+
3+
from .maskgit_vqgan import MaskGitVQGAN
4+
from .taming_vqgan import VQGANModel
5+
from .transformer import MaskGitTransformer

muse/data.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""All data related utilities and loaders are defined here."""

0 commit comments

Comments
 (0)