|
| 1 | +# Mini Continuous Diffusion From Categorical Data |
| 2 | + |
| 3 | +This repository aims to reproduce the [Continuous Diffusion from Categorical Data paper by Dieleman et al](https://arxiv.org/pdf/2211.15089.pdf) where the authors managed to generate coherent text using a non-autoregressive diffusion model. |
| 4 | + |
| 5 | +It is inspired by Karpathy's [nanoGPT](https://github.com/karpathy/nanoGPT) where he was able to generate coherent text with ~100M parameters. |
| 6 | + |
| 7 | + |
| 8 | + |
| 9 | +## The Goal |
| 10 | + |
| 11 | +The goal of this repository is to give the simplest possible reproduction of the paper. Here are some choices we made to make things simple |
| 12 | + |
| 13 | +- The source code is < 500 lines of code |
| 14 | +- We trained models ranging from 500k~100M parameters |
| 15 | +- The dataset used is [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) (~1Gb of data) |
| 16 | +- During the noising process the noise is added to all the tokens |
| 17 | +- The tokenizer used is the BERT tokenizer (~30k vocab size) |
| 18 | +- No self-conditioning |
| 19 | +- No wierd ODE solvers. Euler is enough |
| 20 | + |
| 21 | +# Results |
| 22 | +Here is the output of a 64 tokens generation of a ~600k parameter model trained on a RTX 3090 for ~3 min |
| 23 | + |
| 24 | +>[CLS] once upon was time, he was a rabbit bell to visit his lid. he knocked, there wanted to run a man of his prohibits. one day, the mommy day, brown look airport other, he. dark, and where'this t to careful molly when she on an book it kept and smiled course [SEP] |
| 25 | +
|
| 26 | +And here is the output of a 128 tokens generation of a ~140k parameter model trained on a H100 for ~1 day, however this can be significantly improved as we didn't bother tuning any hyperparameter |
| 27 | +>[CLS] one day, tom called tommy who loved had a house with park. her liked when living cook over the garden fun walking and and small but said out. mommy teach,ggles smiled run weeping was a whileyfixed. as, swimming stuffing flew to sock machine watch fast went good house. but is his moving his offer but each rolled. as smiled my it he and said, it then max said it tom arrived ta sock! the frog was found a noise for in the tree he he tapping a piece piece of anyway and could read he dodge throw and lots around the hole. jen you enjoyed to! the floor and then both [SEP] |
| 28 | +
|
| 29 | +**_Note:_** the results can be improved with more compute, data, self-conditioning, better ODE-solvers and so on, but for the sake of this repository this is a win. |
| 30 | + |
| 31 | +### Noise scheduling |
| 32 | +For the noise scheduling they use use a linear schedule $\sigma(t)=t$ just as explained in [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/pdf/2206.00364.pdf) |
| 33 | + |
| 34 | +For sampling the timesteps in the [CDCD paper](https://arxiv.org/pdf/2211.15089.pdf) they use a monotonic [piece-wise linear function](https://en.wikipedia.org/wiki/Piecewise_linear_function) to fit the model prediction entropy $S$ as a function of the time $t$ and use it as a unormalized Cumulative Density Function (CDF) $F(t)$ |
| 35 | + |
| 36 | +We instead fit $F(t)$ with a [Cauchy-like](https://en.wikipedia.org/wiki/Cauchy_distribution) cumulative distribution function. It is simpler, more flexible and efficient. Overall it's just better. |
| 37 | + |
| 38 | + |
| 39 | +### Preconditioning |
| 40 | + |
| 41 | +In [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/pdf/2206.00364.pdf) by Karras et al. and they define the output of the model $D_\theta(\boldsymbol x,\sigma)$ as following (eq. 7 of the paper) |
| 42 | + |
| 43 | +$$D_\theta(\boldsymbol x,\sigma)=c_\textrm{skip}(\sigma)\boldsymbol x + c_\textrm{out}(\sigma)F_\theta(c_\textrm{in}(\sigma)\boldsymbol x,c_\textrm{noise}(\sigma))$$ |
| 44 | + |
| 45 | +Where $F_\theta(\cdot)$ is the the actual Transformer and $c_\textrm{skip},c_\textrm{out},c_\textrm{in},c_\textrm{noise}$ are non-trainable modulation functions |
| 46 | + |
| 47 | +|modulation |Karras |CDCD |ours | |
| 48 | +|---|---|---|---| |
| 49 | +|$c_\textrm{skip}(\sigma)$ | $1/ (1+\sigma^2)$| ? | $0$ | |
| 50 | +|$c_\textrm{out}(\sigma)$ | $\sigma/\sqrt{1+\sigma^2}$ | ? | $1$ | |
| 51 | +|$c_\textrm{in}(\sigma)$ | $1/\sqrt{1+\sigma^2}$ | $1/\sqrt{1+\sigma^2}$ |$1/\sqrt{1+\sigma^2}$ | |
| 52 | +|$c_\textrm{noise}(\sigma)$ | $\ln(\sigma)/4$ | ? | $\ln(\sigma)/4$ | |
| 53 | +> Sources: [Details in section 6.1 of the CDCD paper](https://arxiv.org/pdf/2211.15089.pdf) and [table 1 of Karras paper](https://arxiv.org/pdf/2206.00364.pdf) |
| 54 | +> Note: Any discrepancies with the Karras paper are due to the fact that we have $\sigma_\textrm{data}=1$ because on how we initialize the input embeddings. |
| 55 | +
|
| 56 | +**_Important Note_** |
| 57 | +We found that the choice of the modulation function has a big effect on the outcome of the training |
| 58 | + |
| 59 | +# Training |
| 60 | +```bash |
| 61 | +pip install -r requirements.txt |
| 62 | +composer training.py |
| 63 | +``` |
| 64 | +alternatively a equivalent but slower and more detailed training loop is available in the [`training.ipynb`](https://github.com/markov-bio/cdcd/blob/master/training.ipynb) notebook. Here is a quick explanation of what it does |
| 65 | + |
| 66 | +The first cell has to do with downloading the dataset and the tokenizer |
| 67 | +```python |
| 68 | +dataset = load_dataset("roneneldan/TinyStories") |
| 69 | +tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased") # or any suitable tokenizer |
| 70 | +[... other code ...] |
| 71 | +``` |
| 72 | + |
| 73 | +The second cell has to do with defining the model |
| 74 | +```python |
| 75 | + |
| 76 | +model=DiffusionModel(embed_dim,hidden_dim,qkv_dim,num_heads,cond_dim,n_blocks,tokenizer,p_self_cond,p_mask_cond,p_mask,prefix) |
| 77 | + |
| 78 | +``` |
| 79 | + |
| 80 | +Third cell has to do with defining the optimizer |
| 81 | +```python |
| 82 | +optimizer = torch.optim.AdamW(model.parameters(),lr=1e-4) |
| 83 | +lr_scheduler = [...] |
| 84 | +``` |
| 85 | + |
| 86 | +The fourth cell has the training loop |
| 87 | +```python |
| 88 | +for epoch in range(num_epochs): |
| 89 | + for i,tokens in enumerate(train_loader): |
| 90 | + |
| 91 | + optimizer.zero_grad() |
| 92 | + tokens = batch['input_ids'].to(device) |
| 93 | + prediction=model(tokens) |
| 94 | + |
| 95 | + loss = model.loss(prediction,tokens) |
| 96 | + loss.backward() |
| 97 | + optimizer.step() |
| 98 | + |
| 99 | + # Log, print, or save as needed |
| 100 | + if i%schedule_update_frequency==0 and i!=0: |
| 101 | + model.noise_schedule.update_optimal_parameters() |
| 102 | + |
| 103 | + if i%50==0 and i!=0: |
| 104 | + lr_scheduler.step() |
| 105 | + model.noise_schedule.plot_entropy_time_curve() |
| 106 | +``` |
| 107 | +And you should the most recent datapoints along with the last best-fit for the Unormalized Cumulative Density Function $F(t)$ |
| 108 | + |
| 109 | +It represents the crossentropy loss of the model as a function of the noise $\sigma$ added. The more recent datapoints are colored darker. |
| 110 | +The blue curve represents the fit of $F(t)$ (learnt unormalized CDF). |
| 111 | + |
| 112 | +The other curve that shows up is the one that represents how the best fit for $F(t)$ improves as the training progresses |
| 113 | + |
| 114 | +The more recent best-fitss are colored darker. |
| 115 | +As the curve shift to the right is idicates that it is learning how to denoise the signal better and better |
| 116 | + |
| 117 | +### Comparison of the result with the CDCD paper |
| 118 | +Checking with a ruler it seems that the curve obtained in our experiment is pretty much identical to the one obtained by the autors in the figure 2 of the CDCD paper |
| 119 | + |
| 120 | + |
| 121 | +# Pseudocode for Score interpolation |
| 122 | +Since in the original paper there is not any code explanation for the score interpolation here it is: |
| 123 | + |
| 124 | +--- |
| 125 | + |
| 126 | +**Generation**$(D_{\theta}(x;t)$, $e_{j\in \{0,\ldots,V-1\}}$, $t_\textrm{max},t_\textrm{min}, N)$ |
| 127 | + |
| 128 | + |
| 129 | +1. $S_i\gets \textrm {Uniform}(F(t_\textrm{max}),F(t_\textrm{min}), N)$ // Generate $N$ uniformly distributed samples $S_i$ between $F(t_\text{max})$ |
| 130 | +2. $t_i \leftarrow F^{-1}(S_i)$ // Inverse transform sampling to get times |
| 131 | +3. $x_0 \sim \mathcal{N}(0, t_0^2 I)$ // Initialize $x_0$ with noise based on max time variance |
| 132 | +4. **For** $i \in \{0,\dots, N-1\}$ **do**: |
| 133 | + - $\hat x_0 \leftarrow D_{\theta}(x_i; t_i)$ // Apply model to estimate completely denoised image $\hat x_0$ |
| 134 | + - $p_j(\hat x_0) \leftarrow \text{Softmax}(\hat x_0 \cdot e_j)$ // Softmax to get probabilities of embeddings |
| 135 | + - $\mathbf E_{p} [\hat x_0] \leftarrow \sum_{j}e_jp_j(\hat x_0)$ // Calculate expected embedding |
| 136 | + - $d_i \leftarrow \frac{x_i - \mathbf E_{p} [\hat x_0]}{t_i}$ // Compute derivative |
| 137 | + - $x_{i+1} \leftarrow x_i + (t_{i+1} - t_i) d_i$ // Euler step for next sample |
| 138 | +5. **Return** $x_N$ // return generated sample |
0 commit comments