
Commit a261fba
Update README with commentary
1 parent fbaf2d3
8 files changed, +221 -65 lines

README.md
Lines changed: 196 additions & 47 deletions
@@ -1,59 +1,208 @@
# Towards Monosemanticity: Decomposing Language Models With Dictionary Learning _(on your laptop! in 10 minutes or less!)_

-# Towards Monosemanticity: Decomposing Language Models With Dictionary Learning

Hello! You are probably aware of this [very cool research from Anthropic](https://transformer-circuits.pub/2023/monosemantic-features/index.html) where they train an autoencoder to interpret the inner workings of a transformer.

-This repository reproduces results of [Anthropic's Sparse Dictionary Learning paper](https://transformer-circuits.pub/2023/monosemantic-features/). The codebase is quite rough, but the results are excellent. See the [feature interface](https://shehper.github.io/feature-interface/) to browse through the features learned by the sparse autoencoder. There are improvements to be made (see the [TODOs](#todos) section below), and I will work on them intermittently as I juggle things in life :)

(If you are not aware, feel free to read it, it's quite fascinating)

-I trained a 1-layer transformer model from scratch using [nanoGPT](https://github.com/karpathy/nanoGPT) with $d_{\text{model}} = 128$. Then, I trained a sparse autoencoder with $4096$ features on its MLP activations as in [Anthropic's paper](https://transformer-circuits.pub/2023/monosemantic-features/). 93% of the autoencoder neurons were alive, only 5% of which were of ultra-low density. There are several interesting features. For example, there is [a feature for French language](https://shehper.github.io/feature-interface/?page=2011),

I was curious how this technique scales when you don't use very much data or compute. Does it still work? At all?

-<p align="center">
-  <img src="./assets/french.png" width="700" />
-</p>
-
-a feature each for German, Japanese, and many other languages, as well many other interesting features:
-
-- [A feature for German](https://shehper.github.io/feature-interface/?page=156)
-- [A feature for Scandinavian languages](https://shehper.github.io/feature-interface/?page=1634)
-- [A feature for Japanese](https://shehper.github.io/feature-interface/?page=1989)
-- [A feature for Hebrew](https://shehper.github.io/feature-interface/?page=2026)
-- [A feature for Cyrilic vowels](https://shehper.github.io/feature-interface/?page=3987)
-- [A feature for token "at" in words like "Croatian", "Scat", "Hayat", etc](https://shehper.github.io/feature-interface/?page=1662)
-- [A single token feature for "much"](https://shehper.github.io/feature-interface/?page=2760)
-- [A feature for sports leagues: NHL, NBA, etc](https://shehper.github.io/feature-interface/?page=379)
-- [A feature for Gregorian calendar dates](https://shehper.github.io/feature-interface/?page=344)
-- [A feature for "when"](https://shehper.github.io/feature-interface/?page=2022):
-  - this feature particularly stands out because of the size of the mode around large activation values.
-- [A feature for "&"](https://shehper.github.io/feature-interface/?page=1916)
-- [A feature for ")"](https://shehper.github.io/feature-interface/?page=1917)
-- [A feature for "v" in URLs like "com/watch?v=SiN8](https://shehper.github.io/feature-interface/?page=27)
-- [A feature for programming code](https://shehper.github.io/feature-interface/?page=45)
-- [A feature for Donald Trump](https://shehper.github.io/feature-interface/?page=292)
-- [A feature for LaTeX](https://shehper.github.io/feature-interface/?page=538)
-
-<!-- - [Bigram feature 1?](https://shehper.github.io/feature-interface/?page=446)
-  [Bigram feature 2?](https://shehper.github.io/feature-interface/?page=482) -->

So, I forked this lovely repo from @shehper (thank you shehper!), and made some modifications. As shehper notes in their original README, their results are excellent, but the code is a bit rough.

I cleaned up the code a bit, but since I don't have any GPUs, my results are worse. I still had a good time and found it interesting. If you'd like to reproduce my results in a few minutes on your (hopefully new-ish Apple Silicon) laptop, read on!

## Reproduction Instructions (+ commentary)

-<!-- - [A feature for some negative words/news](https://shehper.github.io/feature-interface/?page=218) -->

First, clone the repo. Then, run:

-### Training Details

```
cd transformer/data/shakespeare_char
python prepare.py
```

This saves the encoded version of the dataset to disk, and trains a tokenizer.

_What? Training a tokenizer?_

That's right, reader -- we're adding additional complexity with dubious benefit, right out of the gate.

I didn't want to use the full GPT-2 vocabulary (trying to keep the model small), so I just used the famous shakespeare_char dataset, which tokenizes per character and has a vocab size of less than 100.

Initial results with this were poor, and I wasn't able to find any sort of clear link between specific neuron activations and semantic content. I hypothesized that this was partly due to the tokenization, so I added a custom tokenizer with a vocabulary of 1024.

This tokenizer was trained on a custom dataset of 1M characters of shakespeare_char and 1M characters of a Python dataset I found on HuggingFace. This is a notable detail: because of the low compute + low data regime I'm working in, my idea was to train a model on a relatively bi-modal dataset, and see if the neurons would nicely fall into two categories, either a "Shakespeare Neuron" or a "Python Neuron". (Spoiler: they kinda did!)
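
For reference, training a small BPE tokenizer like this takes only a few lines with the HuggingFace `tokenizers` library. This is a minimal sketch, not the repo's actual tokenizer code; the file names and the byte-level BPE setup are my assumptions:

```
from tokenizers import Tokenizer, models, pre_tokenizers, trainers

# Byte-level BPE with a deliberately tiny vocabulary (1024), trained on the mixed corpus.
tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

trainer = trainers.BpeTrainer(vocab_size=1024, special_tokens=["[UNK]"])
tokenizer.train(files=["shakespeare_1M.txt", "python_1M.txt"], trainer=trainer)  # hypothetical file names

tokenizer.save("custom_tok_1024.json")
print(tokenizer.encode("def run(): pass").ids)  # token ids under the learned 1024-entry vocab
```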
For context, the Anthropic paper has some great details on their training regime -- they train their tiny transformer on 100B tokens, and their autoencoder on 8B activation vectors. I am obviously quite a few orders of magnitude below that, sadly.

Okay, now that your dataset and tokenizer are ready to go, let's train some models!

```
# need MPS_FALLBACK because nn.Dropout isn't supported :(

# in the transformer/ directory
PYTORCH_ENABLE_MPS_FALLBACK=1 python train_transformer.py \
    config/train_shakespeare_char.py \
    --device=mps \
    --max_iters=7500 \
    --lr_decay_iters=7500 \
    --n_embd=192 \
    --out_dir=0730-shakespeare-python-custom-tok \
    --batch_size=24 \
    --compile=True
```
Here we train our tiny transformer. This is all ripped out of nanoGPT, thanks Karpathy! This takes 5 minutes or so on my MacBook. There are a lot of hyperparams there that are ripe for tuning.

Once it's trained, it's good to check if it actually produces sensible outputs.

```
# in the autoencoder/ directory
# (try --prompt 'oh romeo' too)
python generate_tokens.py \
    --prompt 'def run' \
    --gpt_ckpt_dir=0730-shakespeare-python-custom-tok
```
When I run this, I get:

```
oh romeo!

First Servingman:
Yecond Servingman:
What, had, are in Pomprecy thyself.

First Servingman:
Good lady?
First Servingman:
Spirrairint sir:
Thour, sir, thou art 'tis
```

and:

```
def run_path( artist):
app.0
sum = 0
sum += generate_cart = 0
sum(input("Facters)
print(result)
###
"""
def num_sum(ates, b):
sum = 0
a = 0
while b < 0:
sum += sum1
while 1:
min = b
for print(sum + min(num) + b += 1
else
```
Which, hey, not too bad for a few minutes of local training! Vaguely reasonable Python and Shakespeare.

Now, it's time to train the autoencoder. First, prepare the dataset:

```
python prepare_autoencoder_dataset.py \
    --num_contexts=7500 \
    --num_sampled_tokens=200 \
    --dataset=shakespeare_char \
    --gpt_ckpt_dir=0730-shakespeare-python-custom-tok
```
This runs a bunch of forward passes on the trained transformer, and saves the resulting MLP activations to disk as `.pt` files for later training. This is done as a memory optimization, and arguably isn't needed at small-scale training. The 7500 contexts number is tuned so that it covers the entirety of the dataset we saved to disk with the previous `prepare.py` script. The Anthropic paper mentions that they see best results training the autoencoder without data re-sampling (aka, not repeating data), so we follow that here.
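
(For scale: 7500 contexts x 200 sampled tokens per context works out to about 1.5M activation vectors, versus the 8B activation vectors Anthropic trained on.)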

Let's train the autoencoder:

```
python train_autoencoder.py \
    --device=mps \
    --l1_coeff=3e-7 \
    --learning_rate=3e-4 \
    --gpt_ckpt_dir=0730-shakespeare-python-custom-tok \
    --dataset=shakespeare_char \
    --batch_size=2048 \
    --resampling_interval=500 \
    --resampling_data_size=100 \
    --save_interval=1000 \
    --n_features=1536
```
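For intuition on what `l1_coeff` controls: the autoencoder is trained to reconstruct the MLP activations while paying an L1 penalty on its feature activations, and `l1_coeff` sets how expensive activation is. Larger values push more features toward zero on any given input (sparser, more interpretable, worse reconstruction); smaller values do the opposite. A schematic sketch, my own simplification rather than the repo's exact architecture or loss:

```
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseAutoencoderSketch(nn.Module):
    """Schematic SAE: encode MLP activations into n_features (hopefully sparse) features, then decode."""
    def __init__(self, n_input: int, n_features: int):
        super().__init__()
        self.encoder = nn.Linear(n_input, n_features)
        self.decoder = nn.Linear(n_features, n_input)

    def forward(self, x):
        f = F.relu(self.encoder(x))  # feature activations
        x_hat = self.decoder(f)      # reconstruction of the MLP activations
        return x_hat, f

def sae_loss(x, x_hat, f, l1_coeff=3e-7):
    mse = F.mse_loss(x_hat, x)       # reconstruction term
    l1 = f.abs().sum(dim=-1).mean()  # sparsity term: total feature activation per example
    return mse + l1_coeff * l1
```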
Anthropic's paper discusses a concept of "neuron resampling", where they revive dead neurons during training. Among other factors, a cause of dead neurons is training too long. Since I'm training for a very short period of time, I don't see any dead neurons, and that functionality wasn't needed -- but if you'd like to train longer, you should be aware of this.
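
The mechanics are simple to picture even though this run never needed them. A rough sketch of the idea (my own illustration, simpler than Anthropic's actual resampling scheme, and not this repo's code):

```
import torch
import torch.nn as nn

@torch.no_grad()
def resample_dead_features(encoder: nn.Linear, decoder: nn.Linear, feature_act_counts: torch.Tensor):
    """Reinitialize any feature that never fired (activation count 0) over the last window of batches."""
    dead = (feature_act_counts == 0).nonzero(as_tuple=True)[0]
    for k in dead:
        encoder.weight[k] = torch.randn_like(encoder.weight[k]) * 0.01        # new encoder row for feature k
        decoder.weight[:, k] = torch.randn_like(decoder.weight[:, k]) * 0.01  # new decoder column for feature k
    return dead
```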
Also, notice the `n_features=1536` -- this is a 2x multiple autoencoder. Note that the embedding dimension of the transformer we trained is 192. We're plucking the activations from inside the MLP, which projects into 4x the embedding dimension (aka 768), and 1536 is twice that.
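
To make the arithmetic explicit (all numbers come from the commands above):

```
n_embd = 192              # --n_embd passed to train_transformer.py
d_mlp = 4 * n_embd        # width of the MLP layer whose activations we autoencode -> 768
n_features = 2 * d_mlp    # 2x expansion factor -> 1536, the --n_features value
print(d_mlp, n_features)  # 768 1536
```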
Anthropic mentions they train a family of models with multiples from 1x to 256x. For a low-compute reproduction, 2x seemed to work well. I tried a few other multiples and didn't see a significant difference.

Note that training the autoencoder is quite fast. This is probably the weakest point of the pipeline, and more data would probably help. An obvious TODO would be to scale this up to ~1 hour of training or so, and see how things change.

Once this is done, it's time to actually inspect the features by running repeated forward passes with certain neurons suppressed.

```
python build_website.py \
    --device=mps \
    --dataset=shakespeare_char \
    --gpt_ckpt_dir=0730-shakespeare-python-custom-tok \
    --sae_ckpt_dir=2024-07-31-0936 \
    --num_contexts=100 \
    --num_phases=96
```
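As a concrete picture of what "suppressing" a feature means here, this is roughly the kind of intervention involved: encode the MLP activations with the trained autoencoder, zero out one feature, and decode without it. An illustrative sketch reusing the schematic `SparseAutoencoderSketch` above, not what `build_website.py` literally does:

```
import torch

@torch.no_grad()
def reconstruct_without_feature(sae, mlp_acts: torch.Tensor, k: int) -> torch.Tensor:
    _, f = sae(mlp_acts)   # f: (batch, n_features) feature activations
    f[:, k] = 0.0          # suppress feature k
    return sae.decoder(f)  # reconstruction with that feature removed
```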
## Results

We do see some differentiation between neurons! But not a tremendous amount. The activations don't follow a nice power law as one might hope. I hypothesize this is due to lack of data.
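
For reference, the "density" behind those plots is just the fraction of sampled tokens on which each feature fires. A sketch with stand-in data (my illustration, not repo code):

```
import torch

f = torch.relu(torch.randn(10_000, 1536))            # stand-in for real feature activations (tokens x features)
density = (f > 0).float().mean(dim=0)                # per-feature firing fraction
log_density = torch.log10(density.clamp(min=1e-10))  # a histogram of this is the usual feature-density plot
```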
My hope was that there would be clearly delineated "Python" vs "Shakespeare" neurons. This happened a bit, but not as much as I would have liked. One obvious thing to try would be to train on multiple character sets, for example English and Japanese. This would create differences at the token level that might allow for a clearer fragmentation of the internals.

Plenty of neurons seem to be "majority Python", but few neurons seem to be "majority Shakespeare".

Here's an example of a Python neuron:

-I used the "OpenWebText" dataset to train the transformer model, to generate the MLP activations dataset for the autoencoder, and to generate the feature interface visualizations. The transformer model had $d_{\text{model}}= 128$, $d_{\text{MLP}} = 512$, and $n_{\text{head}}= 4$. I trained this model for $2 \times 10^5$ iterations to roughly match the number of epochs with [Anthropic's training procedure](https://transformer-circuits.pub/2023/monosemantic-features#appendix-transformer).
-
-I collected the dataset of 4B MLP activations by performing forward pass on 20M prompts (each of length 1024), keeping 200 activation vectors from each prompt. Next, I trained the autoencoder for approximately $5 \times 10^5$ training steps at batch size 8192 and learning rate $3 \times 10^{-4}$. I performed neuron resampling 4 times during training at training steps $2.5 \times i \times 10^4$ for $i=1, 2, 3, 4$. See a complete log of the training run on the [W&B page](https://wandb.ai/shehper/sparse-autoencoder-openwebtext-public/runs/vjbcwjsf?nw=nwusershehper). The L1-coefficient for this training run is $10^{-3}$. I selected the L1-coefficient and the learning rate by performing a grid search.

<p align="center">
  <img src="./assets/python.png" width="700" />
</p>

-For the most part, I followed the training procedure described in the [appendix](https://transformer-circuits.pub/2023/monosemantic-features#appendix-autoencoder) of Anthropic's original paper. I did not follow the improvements they suggested in their [January](https://transformer-circuits.pub/2024/jan-update/index.html) and [February](https://transformer-circuits.pub/2024/feb-update/index.html) updates.

And here's an example of a (rare) Shakespeare neuron:

-### TODOs
-- Incorporate the effects of feature ablations in the feature interface.
-- Implement an interface to see "Feature Activations on Example Texts" as done by Anthropic [here](https://transformer-circuits.pub/2023/monosemantic-features/vis/a1-math.html).
-- Modify the code so that one can train a sparse autoencoder on activations of any MLP / attention layer.

<p align="center">
  <img src="./assets/shakespeare.png" width="700" />
</p>

-### Related Work
-There are several other very interesting works on the web exploring sparse dictionary learning. Here is a small subset of them.

## Future Work

-- [Sparse Autoencoders Find Highly Interpretable Features in Language Models by Cunningham, et al.](https://arxiv.org/abs/2309.08600)
-- [Sparse Autoencoders Work on Attention Layer Outputs by Kissane, et al.](https://www.lesswrong.com/posts/DtdzGwFh9dCfsekZZ/sparse-autoencoders-work-on-attention-layer-outputs)
-- [Joseph Bloom's SAE codebase](https://github.com/jbloomAus/mats_sae_training) along with a blogpost on [trained SAEs for all residual stream layers of GPT-2 small](https://www.alignmentforum.org/posts/f9EgfLSurAiqRJySD/open-source-sparse-autoencoders-for-all-residual-stream)
-- [Neel Nanda's SAE codebase](https://github.com/neelnanda-io/1L-Sparse-Autoencoder) along with a [blogpost](https://www.lesswrong.com/posts/fKuugaxt2XLTkASkk/open-source-replication-and-commentary-on-anthropic-s)
-- [Callum McDougall's exercises on SAEs](https://github.com/callummcdougall/sae-exercises-mats/tree/main)
-- [SAE library by AI Safey Foundation](https://github.com/ai-safety-foundation/sparse_autoencoder)

I'm really interested in the intersection of "low compute / efficient training" + "model interpretability". Anthropic mentions the challenges of scaling this approach up to large models in their paper, and they discuss it further in their follow-up paper on extracting features from Claude 3 Sonnet.

As models continue to get bigger, and open source continues to try to catch up, it seems valuable to have well-established, off-the-shelf techniques to decompose and interpret the inner workings of a transformer. Among other things, the steerability and application benefits are significant. The Golden Gate Bridge feature demo was a fun example, but the tactical benefits and possibilities make the concept of a system prompt seem somewhat antiquated.

assets/french.png: -191 KB (binary file not shown)

assets/german.png: -184 KB (binary file not shown)

assets/python.png: 281 KB

assets/scandinavian.png: -184 KB (binary file not shown)

assets/shakespeare.png: 351 KB

autoencoder/prepare_autoencoder_dataset.py

Lines changed: 11 additions & 11 deletions

@@ -1,5 +1,5 @@
 """"
-Prepares training dataset for our autoencoder.
+Prepares training dataset for our autoencoder.
 Run on Macbook as
 python -u prepare.py --num_contexts=5000 --num_sampled_tokens=16 --dataset=shakespeare_char --gpt_ckpt_dir=out_sc_1_2_32
 """
@@ -16,9 +16,9 @@
 gpt_ckpt_dir = 'out' # Model checkpoint directory
 # autoencoder data size
 num_contexts = int(2e6) # Number of context windows
-num_sampled_tokens = 200 # Tokens per context window
+num_sampled_tokens = 200 # Tokens per context window -- TODO: why isn't this just the same as block size?
 # system
-device = 'cpu'
+device = 'mps'
 num_partitions = 20 # Number of output files
 # reproducibility
 seed = 0
@@ -46,24 +46,24 @@

 def compute_activations():
     start_time = time.time()
-    gpt_batch_size = 500
+    gpt_batch_size = 500 # ?
     n_batches = num_contexts // gpt_batch_size

     for batch in range(n_batches):
         # Load batch and compute activations
-        x, _ = resource_loader.get_text_batch(gpt_batch_size)
+        x, _ = resource_loader.get_text_batch(gpt_batch_size) # (gpt_batch_size, block_size)
         _, _ = gpt(x) # Forward pass
-        activations = gpt.mlp_activation_hooks[0] # Retrieve activations
+        activations = gpt.mlp_activation_hooks[0] # (gpt_batch_size, block_size, 4 * n_embd) -- 4x because the inner layer is 4x the size of the input

         # Clean up to save memory
         gpt.clear_mlp_activation_hooks()

         # Process and store activations
-        token_locs = torch.stack([torch.randperm(block_size)[:num_sampled_tokens] for _ in range(gpt_batch_size)])
-        data = torch.gather(activations, 1, token_locs.unsqueeze(2).expand(-1, -1, activations.size(2))).view(-1, n_ffwd)
-        data_storage[shuffled_indices[batch * gpt_batch_size * num_sampled_tokens : (batch + 1) * gpt_batch_size * num_sampled_tokens]] = (
-            data
-        )
+        token_locs = torch.stack([torch.randperm(block_size)[:num_sampled_tokens] for _ in range(gpt_batch_size)]) # (gpt_batch_size, num_sampled_tokens)
+        data = torch.gather(activations, 1, token_locs.unsqueeze(2).expand(-1, -1, activations.size(2))).view(-1, n_ffwd) # (gpt_batch_size * num_sampled_tokens, n_ffwd)
+        data_storage[
+            shuffled_indices[batch * gpt_batch_size * num_sampled_tokens : (batch + 1) * gpt_batch_size * num_sampled_tokens]
+        ] = data

         print(
             f"Batch {batch}/{n_batches} processed in {(time.time() - start_time) / (batch + 1):.2f} seconds; "

autoencoder/train_autoencoder.py

Lines changed: 14 additions & 7 deletions

@@ -5,6 +5,7 @@
 python train.py --dataset=shakespeare_char --gpt_ckpt_dir=out_sc_1_2_32 --eval_iters=1 --eval_batch_size=16 --batch_size=128 --device=cpu --eval_interval=100 --n_features=1024 --resampling_interval=150
 """

+from tqdm import trange
 import os
 import torch
 import numpy as np
@@ -74,8 +75,7 @@
 start_time = time.time()
 num_steps = resourceloader.autoencoder_data_info["total_examples"] // batch_size

-for step in range(min(num_steps, 2500)):
-
+for step in trange(num_steps, desc="Training Autoencoder"):
     batch = resourceloader.get_autoencoder_data_batch(step, batch_size=batch_size)
     optimizer.zero_grad(set_to_none=True)
     autoencoder_output = autoencoder(batch) # f has shape (batch_size, n_features)
@@ -89,7 +89,7 @@
     if step % 1000 == 0:
         autoencoder.normalize_decoder_columns()

-    ## ------------ perform neuron resampling ----------- ######
+    ###### ------------ perform neuron resampling ----------- ######
     # check if we should start investigating dead/alive neurons at this step
     # This is done at an odd multiple of resampling_interval // 2 in Anthropic's paper.
     if autoencoder.is_dead_neuron_investigation_step(step, resampling_interval, num_resamples):
@@ -110,7 +110,7 @@
             data=resourceloader.select_resampling_data(size=resampling_data_size), optimizer=optimizer, batch_size=batch_size
         )

-    ### ------------ log info ----------- ######
+    ###### ------------ log info ----------- ######
     if (step % eval_interval == 0) or step == num_steps - 1:
         print(f'Entering evaluation mode at step = {step}')
         autoencoder.eval()
@@ -128,9 +128,7 @@
         feat_acts_count = torch.zeros(n_features, dtype=torch.float32)

         # get batches of text data and evaluate the autoencoder on MLP activations
-        for iter in range(eval_iters):
-            if iter % 20 == 0:
-                print(f'Performing evaluation at iterations # ({iter} - {min(iter+19, eval_iters)})/{eval_iters}')
+        for iter in trange(eval_iters, desc="Evaluating Autoencoder"):
             x, y = resourceloader.get_text_batch(num_contexts=eval_batch_size)

             _, nll_loss = gpt(x, y)
@@ -175,6 +173,15 @@
                 'feature_density/num_alive_neurons': len(log_feat_acts_density),
             }
         )
+        # Print log_dict in a readable format
+        print("Evaluation Results:")
+        print("-" * 40)
+        for key, value in log_dict.items():
+            if isinstance(value, float):
+                print(f"{key:<35} {value:.6f}")
+            else:
+                print(f"{key:<35} {value}")
+        print("-" * 40)

         autoencoder.train()
         print(f'Exiting evaluation mode at step = {step}')
