Commit 91dacf9 ("initial commit")
0 parents, 36 files changed: +1390 / -0 lines

.DS_Store (6 KB, binary file not shown)

README.md (+138 lines)
# Mini Continuous Diffusion From Categorical Data

This repository aims to reproduce the [Continuous Diffusion from Categorical Data paper by Dieleman et al.](https://arxiv.org/pdf/2211.15089.pdf) where the authors managed to generate coherent text using a non-autoregressive diffusion model.

It is inspired by Karpathy's [nanoGPT](https://github.com/karpathy/nanoGPT) where he was able to generate coherent text with ~100M parameters.
## The Goal

The goal of this repository is to give the simplest possible reproduction of the paper. Here are some choices we made to keep things simple:

- The source code is < 500 lines of code
- We trained models ranging from 500k to ~100M parameters
- The dataset used is [TinyStories](https://huggingface.co/datasets/roneneldan/TinyStories) (~1 GB of data)
- During the noising process the noise is added to all the tokens
- The tokenizer used is the BERT tokenizer (~30k vocab size)
- No self-conditioning
- No weird ODE solvers. Euler is enough
# Results

Here is the output of a 64-token generation from a ~600k-parameter model trained on an RTX 3090 for ~3 min:

>[CLS] once upon was time, he was a rabbit bell to visit his lid. he knocked, there wanted to run a man of his prohibits. one day, the mommy day, brown look airport other, he. dark, and where'this t to careful molly when she on an book it kept and smiled course [SEP]

And here is the output of a 128-token generation from a ~140M-parameter model trained on an H100 for ~1 day, though this can be significantly improved as we didn't bother tuning any hyperparameters:

>[CLS] one day, tom called tommy who loved had a house with park. her liked when living cook over the garden fun walking and and small but said out. mommy teach,ggles smiled run weeping was a whileyfixed. as, swimming stuffing flew to sock machine watch fast went good house. but is his moving his offer but each rolled. as smiled my it he and said, it then max said it tom arrived ta sock! the frog was found a noise for in the tree he he tapping a piece piece of anyway and could read he dodge throw and lots around the hole. jen you enjoyed to! the floor and then both [SEP]

**_Note:_** the results can be improved with more compute, data, self-conditioning, better ODE solvers and so on, but for the sake of this repository this is a win.
### Noise scheduling

For the noise scheduling they use a linear schedule $\sigma(t)=t$, just as explained in [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/pdf/2206.00364.pdf).

For sampling the timesteps, the [CDCD paper](https://arxiv.org/pdf/2211.15089.pdf) uses a monotonic [piece-wise linear function](https://en.wikipedia.org/wiki/Piecewise_linear_function) to fit the model prediction entropy $S$ as a function of the time $t$ and uses it as an unnormalized cumulative distribution function (CDF) $F(t)$.

We instead fit $F(t)$ with a [Cauchy-like](https://en.wikipedia.org/wiki/Cauchy_distribution) cumulative distribution function. It is simpler, more flexible and more efficient. Overall it's just better.
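To make this concrete, here is a minimal sketch (not the repository's exact implementation) of fitting a Cauchy-like unnormalized CDF to (time, entropy) datapoints and then sampling timesteps from it by inverse transform sampling. The parameter names, the toy datapoints and the use of `scipy.optimize.curve_fit` are assumptions made purely for illustration:

```python
import numpy as np
from scipy.optimize import curve_fit

def cauchy_cdf(t, loc, scale, amplitude):
    # Cauchy-like cumulative distribution, scaled by `amplitude` so it can act
    # as an *unnormalized* CDF F(t) fitted to the entropy-vs-time datapoints.
    return amplitude * (np.arctan((t - loc) / scale) / np.pi + 0.5)

# hypothetical datapoints: model prediction entropy S at noise level t
t_data = np.array([0.1, 0.5, 1.0, 5.0, 20.0, 80.0])
S_data = np.array([0.2, 1.1, 2.3, 5.9, 8.8, 10.1])

(loc, scale, amplitude), _ = curve_fit(cauchy_cdf, t_data, S_data, p0=[1.0, 1.0, 10.0])

def sample_timesteps(n, t_min=1e-2, t_max=100.0):
    # inverse transform sampling: uniform samples in CDF space, mapped back to times
    u = np.random.uniform(cauchy_cdf(t_min, loc, scale, amplitude),
                          cauchy_cdf(t_max, loc, scale, amplitude), size=n)
    # the Cauchy-like CDF can be inverted analytically
    return loc + scale * np.tan(np.pi * (u / amplitude - 0.5))

print(sample_timesteps(4))
```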
### Preconditioning

In [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/pdf/2206.00364.pdf), Karras et al. define the output of the model $D_\theta(\boldsymbol x,\sigma)$ as follows (eq. 7 of the paper):

$$D_\theta(\boldsymbol x,\sigma)=c_\textrm{skip}(\sigma)\boldsymbol x + c_\textrm{out}(\sigma)F_\theta(c_\textrm{in}(\sigma)\boldsymbol x,c_\textrm{noise}(\sigma))$$

Where $F_\theta(\cdot)$ is the actual Transformer and $c_\textrm{skip},c_\textrm{out},c_\textrm{in},c_\textrm{noise}$ are non-trainable modulation functions.
|modulation |Karras |CDCD |ours |
|---|---|---|---|
|$c_\textrm{skip}(\sigma)$ | $1/ (1+\sigma^2)$| ? | $0$ |
|$c_\textrm{out}(\sigma)$ | $\sigma/\sqrt{1+\sigma^2}$ | ? | $1$ |
|$c_\textrm{in}(\sigma)$ | $1/\sqrt{1+\sigma^2}$ | $1/\sqrt{1+\sigma^2}$ |$1/\sqrt{1+\sigma^2}$ |
|$c_\textrm{noise}(\sigma)$ | $\ln(\sigma)/4$ | ? | $\ln(\sigma)/4$ |

> Sources: [details in section 6.1 of the CDCD paper](https://arxiv.org/pdf/2211.15089.pdf) and [table 1 of the Karras paper](https://arxiv.org/pdf/2206.00364.pdf)

> Note: Any discrepancies with the Karras paper are due to the fact that we have $\sigma_\textrm{data}=1$ because of how we initialize the input embeddings.

**_Important Note:_** We found that the choice of the modulation functions has a big effect on the outcome of the training.
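As a hedged sketch of how these modulation functions wrap the raw network, here is what the "ours" column above could look like in code. The `transformer` callable and its signature are placeholders, not this repository's actual API:

```python
import torch

def precondition(transformer, x, sigma):
    """Wrap the raw network F_theta with the modulation functions from the
    'ours' column: c_skip=0, c_out=1, c_in=1/sqrt(1+sigma^2), c_noise=ln(sigma)/4."""
    c_skip = 0.0
    c_out = 1.0
    c_in = 1.0 / torch.sqrt(1.0 + sigma**2)
    c_noise = torch.log(sigma) / 4.0

    # sigma has shape (batch,); broadcast it over the (batch, length, channels) input
    x_in = c_in[:, None, None] * x
    return c_skip * x + c_out * transformer(x_in, c_noise)
```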
# Training

```bash
pip install -r requirements.txt
composer training.py
```

Alternatively, an equivalent but slower and more detailed training loop is available in the [`training.ipynb`](https://github.com/markov-bio/cdcd/blob/master/training.ipynb) notebook. Here is a quick explanation of what it does.
The first cell downloads the dataset and the tokenizer:

```python
from datasets import load_dataset
from transformers import AutoTokenizer

dataset = load_dataset("roneneldan/TinyStories")
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")  # or any suitable tokenizer
[... other code ...]
```
The second cell defines the model:

```python
model = DiffusionModel(embed_dim, hidden_dim, qkv_dim, num_heads, cond_dim, n_blocks, tokenizer, p_self_cond, p_mask_cond, p_mask, prefix)
```
The third cell defines the optimizer and the learning-rate scheduler:

```python
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = [...]
```
The fourth cell has the training loop:

```python
for epoch in range(num_epochs):
    for i, batch in enumerate(train_loader):

        optimizer.zero_grad()
        tokens = batch['input_ids'].to(device)
        prediction = model(tokens)

        loss = model.loss(prediction, tokens)
        loss.backward()
        optimizer.step()

        # Log, print, or save as needed
        if i % schedule_update_frequency == 0 and i != 0:
            model.noise_schedule.update_optimal_parameters()

        if i % 50 == 0 and i != 0:
            lr_scheduler.step()
            model.noise_schedule.plot_entropy_time_curve()
```
And you should see the most recent datapoints along with the latest best fit for the unnormalized cumulative distribution function $F(t)$:

![et_140M](https://github.com/markov-bio/cdcd/assets/47751420/8d08f943-c1b3-49da-a113-eb65f13e1cac)

The plot shows the cross-entropy loss of the model as a function of the noise $\sigma$ added. The more recent datapoints are colored darker.
The blue curve represents the fit of $F(t)$ (the learned unnormalized CDF).

The other plot that shows up represents how the best fit for $F(t)$ improves as the training progresses:

![curves_140M](https://github.com/markov-bio/cdcd/assets/47751420/6d87546e-cd87-42a8-b16b-a3cf02da7116)

The more recent best fits are colored darker.
As the curves shift to the right, it indicates that the model is learning to denoise the signal better and better.
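For intuition, here is a toy sketch of what an adaptive schedule update could look like: keep the most recent $(\sigma, \textrm{cross-entropy})$ datapoints in a buffer and periodically refit the Cauchy-like $F(t)$. This is an illustration only, not the repository's `AdaptiveSchedule` from `continous_diffusion/scheduling.py`; the class name, buffer size and fitting routine are assumptions:

```python
import numpy as np
from scipy.optimize import curve_fit

class ToyAdaptiveSchedule:
    """Illustrative only: buffer recent (sigma, cross-entropy) points and refit F(t)."""

    def __init__(self, max_points=10_000):
        self.sigmas, self.entropies = [], []
        self.max_points = max_points
        self.params = (1.0, 1.0, 10.0)  # loc, scale, amplitude of the Cauchy-like CDF

    @staticmethod
    def F(t, loc, scale, amplitude):
        return amplitude * (np.arctan((t - loc) / scale) / np.pi + 0.5)

    def add_data(self, sigma, entropy):
        self.sigmas.append(float(sigma))
        self.entropies.append(float(entropy))
        # drop old datapoints so the fit tracks the current model
        self.sigmas = self.sigmas[-self.max_points:]
        self.entropies = self.entropies[-self.max_points:]

    def update_optimal_parameters(self):
        self.params, _ = curve_fit(self.F, np.array(self.sigmas),
                                   np.array(self.entropies), p0=self.params)
```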
### Comparison of the result with the CDCD paper

Checking with a ruler, it seems that the curve obtained in our experiment is pretty much identical to the one obtained by the authors in figure 2 of the CDCD paper.

![plot](cdcd_noise_schedule.png)
# Pseudocode for Score interpolation

Since the original paper does not give any code explanation for the score interpolation, here it is:

---

**Generation**$(D_{\theta}(x;t)$, $e_{j\in \{0,\ldots,V-1\}}$, $t_\textrm{max},t_\textrm{min}, N)$

1. $S_i\gets \textrm{Uniform}(F(t_\textrm{max}),F(t_\textrm{min}), N)$ // Generate $N$ uniformly distributed samples $S_i$ between $F(t_\textrm{max})$ and $F(t_\textrm{min})$
2. $t_i \leftarrow F^{-1}(S_i)$ // Inverse transform sampling to get the times
3. $x_0 \sim \mathcal{N}(0, t_0^2 I)$ // Initialize $x_0$ with noise based on the max time variance
4. **For** $i \in \{0,\dots, N-1\}$ **do**:
    - $\hat x_0 \leftarrow D_{\theta}(x_i; t_i)$ // Apply the model to estimate the completely denoised sample $\hat x_0$
    - $p_j(\hat x_0) \leftarrow \text{Softmax}(\hat x_0 \cdot e_j)$ // Softmax to get the probabilities of the embeddings
    - $\mathbf E_{p} [\hat x_0] \leftarrow \sum_{j}e_jp_j(\hat x_0)$ // Calculate the expected embedding
    - $d_i \leftarrow \frac{x_i - \mathbf E_{p} [\hat x_0]}{t_i}$ // Compute the derivative
    - $x_{i+1} \leftarrow x_i + (t_{i+1} - t_i) d_i$ // Euler step for the next sample
5. **Return** $x_N$ // Return the generated sample
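For concreteness, here is a minimal PyTorch sketch of the pseudocode above. The `model`, the embedding matrix `E` and the schedule functions `F`/`F_inv` are placeholders for the repository's actual objects, and the uniform spacing of the steps is an assumption, so treat this as an illustration rather than the reference implementation:

```python
import torch

@torch.no_grad()
def generate(model, E, F, F_inv, t_max, t_min, N, batch_size, seq_len):
    """Euler sampler with score interpolation (a sketch of the pseudocode above).

    model(x, t): estimate of the fully denoised embeddings, shape (b, l, c)
    E:           embedding matrix, shape (vocab_size, c)
    F, F_inv:    the learned unnormalized CDF and its inverse (accepting tensors)
    """
    # 1.-2. N timesteps, uniformly spaced in CDF space, from t_max down to t_min
    S = torch.linspace(float(F(t_max)), float(F(t_min)), N)
    t = F_inv(S)

    # 3. Initialize x_0 with Gaussian noise whose standard deviation is the largest time t_0
    x = t[0] * torch.randn(batch_size, seq_len, E.shape[-1])

    # 4. Euler integration with score interpolation
    for i in range(N - 1):
        x0_hat = model(x, t[i])                    # D_theta(x_i; t_i)
        p = torch.softmax(x0_hat @ E.T, dim=-1)    # probabilities over the vocabulary
        x0_exp = p @ E                             # expected embedding E_p[x0_hat]
        d = (x - x0_exp) / t[i]                    # derivative dx/dt
        x = x + (t[i + 1] - t[i]) * d              # Euler step (t decreases, so noise is removed)

    # 5. Final embeddings; decode tokens with an argmax / nearest-neighbour lookup against E
    return x
```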

cdcd_noise_schedule.png (180 KB, image not shown)

continous_diffusion/DiT_block.py (+88 lines)

import torch
from torch import nn, Tensor
from torch.nn import functional as F

import einops

from .RoPe import RotaryEmbedding
from .attention import SelfAttention


class DiTBlock(nn.Module):
    def __init__(self, embed_dim: int, qkv_dim: int, num_heads: int, cond_dim: int, rope: RotaryEmbedding, max_len=5000):
        super().__init__()
        assert embed_dim >= 2*num_heads and embed_dim % num_heads == 0, 'the embed_dim must be a multiple of the number of heads'
        self.embed_dim = embed_dim
        self.qkv_dim = qkv_dim
        self.cond_dim = cond_dim

        self.attention = SelfAttention(embed_dim, qkv_dim, num_heads, rope)
        self.rope = rope

        self.make_scale_shift = MakeScaleShift(cond_dim, embed_dim)
        self.layernorm1 = nn.LayerNorm(torch.broadcast_shapes((embed_dim,)))
        self.layernorm2 = nn.LayerNorm(torch.broadcast_shapes((embed_dim,)))

        self.feedforward = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim, bias=False),
            nn.GELU(),
            nn.Linear(4 * embed_dim, embed_dim, bias=False)
        )
        nn.init.zeros_(self.feedforward[-1].weight)
        # nn.init.zeros_(self.feedforward[-1].bias)

    def forward(self, x: Tensor, t_conditioning: Tensor, attn_mask: Tensor = None) -> Tensor:
        """
        Args:
            x (torch.Tensor): input tensor (b, l, c)
            t_conditioning (torch.Tensor): conditioning (b, cond_dim).
            attn_mask (torch.Tensor): masks the [PAD] tokens (b, 1, l, l).

        Returns:
            torch.Tensor: tensor of shape x.shape
        """

        # here we create the scale-shift parameters from the conditioning
        alpha_1, beta_1, gamma_1, alpha_2, beta_2, gamma_2 = self.make_scale_shift(t_conditioning)

        res = x.clone()

        x = self.layernorm1(x)
        x = apply_scale_shift(x, gamma_1, beta_1)
        x = self.attention(x, attn_mask)
        x = apply_scale_shift(x, alpha_1)

        x = x + res
        res = x.clone()

        x = self.layernorm2(x)
        x = apply_scale_shift(x, gamma_2, beta_2)
        x = self.feedforward(x)
        x = apply_scale_shift(x, alpha_2)

        return x + res


class MakeScaleShift(nn.Module):
    def __init__(self, cond_dim, embed_dim):
        super().__init__()

        self.linear = nn.Linear(cond_dim, embed_dim * 6)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, conditioning: Tensor):
        assert conditioning.dim() == 2, "all of the cells must have the same conditioning"
        return self.linear(conditioning).chunk(6, dim=-1)


def apply_scale_shift(x, scale, shift: Tensor = None):

    scale = scale + 1
    x = einops.einsum(x, scale, 'b ... c, b c -> b ... c')

    if shift is not None:
        x = x + shift.unsqueeze(1)

    return F.layer_norm(x, normalized_shape=(x.shape[-1],))
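A quick shape-check of the block above, with hypothetical sizes that are not taken from the repository's configs; the padding mask follows the `(b, 1, l, l)` convention mentioned in the docstring and here simply allows all positions:

```python
import torch
from continous_diffusion.RoPe import RotaryEmbedding
from continous_diffusion.DiT_block import DiTBlock

# hypothetical hyperparameters, chosen only to illustrate the expected shapes
embed_dim, qkv_dim, num_heads, cond_dim = 64, 64, 4, 32
rope = RotaryEmbedding(dim=qkv_dim // num_heads)
block = DiTBlock(embed_dim, qkv_dim, num_heads, cond_dim, rope)

b, l = 2, 16
x = torch.randn(b, l, embed_dim)       # (batch, length, channels)
t_cond = torch.randn(b, cond_dim)      # per-sequence conditioning on the noise level

# boolean mask over keys, broadcast over heads: (b, 1, l, l); True means "may attend"
not_pad = torch.ones(b, l, dtype=torch.bool)
attn_mask = not_pad[:, None, None, :].expand(b, 1, l, l)

out = block(x, t_cond, attn_mask)
print(out.shape)                       # torch.Size([2, 16, 64])
```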

continous_diffusion/RoPe.py (+53 lines)

import torch
from torch import nn, Tensor
import einops

class RotaryEmbedding(nn.Module):
    # adapted from
    # https://github.com/lucidrains/PaLM-rlhf-pytorch/blob/6b02ee329106baff78e293afa7d1d2e6dd4e5ca2/palm_rlhf_pytorch/palm.py#L69-L92
    def __init__(self, dim, scale_base=512):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # xPos-style decay factors, used to scale q and k for better length extrapolation
        self.scale_base = scale_base
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.register_buffer('scale', scale)

        # cached embeddings, rebuilt lazily when a longer sequence is seen
        self.register_buffer("pos_emb", None, persistent=False)
        self.register_buffer("pos_emb_scale", None, persistent=False)

    def make_rotary_embedding(self, seq_len):
        t = torch.arange(seq_len).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim=-1)

        power = (t - (seq_len // 2)) / self.scale_base
        scale = self.scale ** einops.rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim=-1)

        return freqs, scale

    def get_rotary_embedding(self, n):
        if (self.pos_emb is not None) and self.pos_emb.shape[-2] >= n:
            return self.pos_emb[:n], self.pos_emb_scale[:n]

        pos_emb, scale = self.make_rotary_embedding(n)
        self.register_buffer("pos_emb", pos_emb, persistent=False)
        self.register_buffer("pos_emb_scale", scale, persistent=False)
        return pos_emb, scale

    def forward(self, q: Tensor, k: Tensor):

        pos, scale = self.get_rotary_embedding(q.shape[-2])
        q = (q * pos.cos() + rotate_half(q) * pos.sin()) * scale
        k = (k * pos.cos() + rotate_half(k) * pos.sin()) / scale

        return q, k

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

continous_diffusion/__init__.py (+6 lines)

from .diffusion import Diffusion
from .model import DiffusionTransformer
from .loss import Loss
from .embedding import Embedder
from .scheduling import AdaptiveSchedule

continous_diffusion/attention.py (+41 lines)

import torch
from torch import nn
from torch.nn import functional as F
import einops

from .RoPe import RotaryEmbedding

class SelfAttention(nn.Module):
    def __init__(self, embed_dim: int, qkv_dim: int, num_heads: int, rope: RotaryEmbedding):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_heads = num_heads

        self.W_qkv = nn.Linear(embed_dim, 3 * qkv_dim, bias=False)
        self.W_out = nn.Linear(qkv_dim, embed_dim, bias=False)

        #nn.init.zeros_(self.feedforward.weight)
        #nn.init.zeros_(self.feedforward.bias)

        self.rope = rope

        # learnable temperature used together with QK-norm
        self.scale = nn.Parameter(torch.tensor(0.))

    def forward(self, x, attn_mask=None):

        x = self.W_qkv(x)
        x = einops.rearrange(x, '... l (h c) -> ... h l c', h=self.num_heads)
        q, k, v = x.chunk(3, dim=-1)

        # Using QK-norm
        q = F.normalize(q, p=2, dim=-1) * torch.exp(self.scale)
        k = F.normalize(k, p=2, dim=-1)

        q, k = self.rope(q, k)

        x = F.scaled_dot_product_attention(q, k, v, scale=1, attn_mask=attn_mask)

        x = einops.rearrange(x, '... h l c -> ... l (h c)')

        return self.W_out(x)
