
Commit 9d17ea6

Merge branch 'rom1504-custom_dataset'
2 parents: 049bd91 + 13e6230

File tree (3 files changed, +93, -0 lines):

- README.md
- configs/custom_vqgan.yaml
- taming/data/custom.py

README.md (+12 lines)
@@ -15,6 +15,7 @@

### News
- Thanks to [rom1504](https://github.com/rom1504) it is now easy to [train a VQGAN on your own datasets](#training-on-custom-data).
- Included a bugfix for the quantizer. For backward compatibility it is
  disabled by default (which corresponds to always training with `beta=1.0`).
  Use `legacy=False` in the quantizer config to enable it.

@@ -180,6 +181,17 @@ included in the repository, run

streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T21-45-44_ade20k_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.ade20k.Examples}}}"
```

## Training on custom data

Training on your own dataset can be beneficial to get better tokens and hence better images for your domain.
These are the steps to follow to make this work:
1. Install the repo with `conda env create -f environment.yaml`, `conda activate taming` and `pip install -e .`.
2. Put your .jpg files in a folder `your_folder`.
3. Create two text files, `xx_train.txt` and `xx_test.txt`, that point to the files in your training and test set respectively (for example `find $(pwd)/your_folder -name "*.jpg" > train.txt`).
4. Adapt `configs/custom_vqgan.yaml` to point to these two files.
5. Run `python main.py --base configs/custom_vqgan.yaml -t True --gpus 0,1` to train on two GPUs. Use `--gpus 0,` (with a trailing comma) to train on a single GPU.

## Data Preparation

### ImageNet
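A short Python script can serve the same purpose as the `find` one-liner in step 3 of the new README section, writing absolute image paths and splitting them into train and test lists in one go. This is only a convenience sketch; the folder name, the 90/10 split, and the output file names are placeholder choices, not part of this commit.

```python
# Sketch: build train.txt / test.txt from a folder of .jpg files.
# "your_folder", the 90/10 split and the output names are arbitrary choices.
import os
import random

folder = "your_folder"
paths = sorted(
    os.path.join(os.path.abspath(folder), name)
    for name in os.listdir(folder)
    if name.lower().endswith(".jpg")
)

random.seed(0)                      # reproducible split
random.shuffle(paths)
n_test = max(1, len(paths) // 10)   # hold out roughly 10% for testing

with open("test.txt", "w") as f:
    f.write("\n".join(paths[:n_test]) + "\n")
with open("train.txt", "w") as f:
    f.write("\n".join(paths[n_test:]) + "\n")
```

The resulting `train.txt` and `test.txt` are what `configs/custom_vqgan.yaml` (the next file in this commit) expects under `training_images_list_file` and `test_images_list_file`.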

configs/custom_vqgan.yaml (+43 lines)
@@ -0,0 +1,43 @@

model:
  base_learning_rate: 4.5e-6
  target: taming.models.vqgan.VQModel
  params:
    embed_dim: 256
    n_embed: 1024
    ddconfig:
      double_z: False
      z_channels: 256
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,1,2,2,4]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [16]
      dropout: 0.0

    lossconfig:
      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        disc_conditional: False
        disc_in_channels: 3
        disc_start: 10000
        disc_weight: 0.8
        codebook_weight: 1.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 5
    num_workers: 8
    train:
      target: taming.data.custom.CustomTrain
      params:
        training_images_list_file: some/training.txt
        size: 256
    validation:
      target: taming.data.custom.CustomTest
      params:
        test_images_list_file: some/test.txt
        size: 256
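Every `target`/`params` pair in this config names a class by its import path plus the keyword arguments it is built with; `main.DataModuleFromConfig` resolves the `train` and `validation` entries this way at startup. The helper below is a stand-alone sketch of that resolution pattern, not the repository's actual function (whose name and signature may differ):

```python
# Minimal sketch of the target/params instantiation pattern used in the config.
# build_from_config is illustrative only; the real resolution happens in main.py.
import importlib

def build_from_config(config):
    """Import config["target"] and construct it with config["params"]."""
    module_name, class_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**config.get("params", {}))

# The `train` entry of the data section expands to roughly:
#   taming.data.custom.CustomTrain(training_images_list_file="some/training.txt", size=256)
train_cfg = {
    "target": "taming.data.custom.CustomTrain",
    "params": {"training_images_list_file": "some/training.txt", "size": 256},
}
train_dataset = build_from_config(train_cfg)
```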

taming/data/custom.py (+38 lines)
@@ -0,0 +1,38 @@

import os
import numpy as np
import albumentations
from torch.utils.data import Dataset

from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex


class CustomBase(Dataset):
    # Thin Dataset wrapper; subclasses fill in self.data with an ImagePaths instance.
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.data = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        example = self.data[i]
        return example


class CustomTrain(CustomBase):
    # Training split: reads one image path per line from the given list file.
    def __init__(self, size, training_images_list_file):
        super().__init__()
        with open(training_images_list_file, "r") as f:
            paths = f.read().splitlines()
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)


class CustomTest(CustomBase):
    # Test split: same as CustomTrain but driven by the test list file.
    def __init__(self, size, test_images_list_file):
        super().__init__()
        with open(test_images_list_file, "r") as f:
            paths = f.read().splitlines()
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)
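For a quick sanity check before launching a full run, the new dataset classes can be used directly with the list file from the README steps. The snippet assumes `train.txt` exists and that `ImagePaths` (from `taming.data.base`) yields a dict per sample with an `"image"` array, as the other datasets in `taming.data` do; it is a usage sketch, not part of the commit:

```python
# Sketch: inspect a few samples from a custom file list
# (train.txt is assumed to exist, one absolute .jpg path per line).
from torch.utils.data import DataLoader
from taming.data.custom import CustomTrain

dataset = CustomTrain(size=256, training_images_list_file="train.txt")
print(len(dataset), "training images")

example = dataset[0]
# ImagePaths is expected to return a dict per sample; the "image" entry
# should be a (256, 256, 3) float array after resizing and cropping.
print(sorted(example.keys()))
print(example["image"].shape)

# During training this wiring is done by main.DataModuleFromConfig,
# using the batch_size/num_workers values from the config.
loader = DataLoader(dataset, batch_size=5, num_workers=0, shuffle=True)
batch = next(iter(loader))
print(batch["image"].shape)   # expected: (5, 256, 256, 3)
```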
