
Commit 2c68ecc

Add training script for ICL experiment.
Add an example training script to reproduce the in-context learning experiment from the paper (see issue #4). An important detail is to set `--unembed_mask 0`; otherwise the model is prevented from predicting the `unk` token, which is used for this task. You may need to run the script with several seeds (e.g. 10) to find an initialization that learns to solve the task.
1 parent 0e375a5 commit 2c68ecc


1 file changed: +42 -0 lines changed

scripts/induction.sh

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+VOCAB_SIZE=10
+MIN_LENGTH=9
+MAX_LENGTH=9
+SEED=6
+
+echo "SEED=${SEED}";
+
+python src/run.py \
+    --dataset "induction" \
+    --vocab_size "${VOCAB_SIZE}" \
+    --dataset_size 20000 \
+    --min_length "${MIN_LENGTH}" \
+    --max_length "${MAX_LENGTH}" \
+    --n_epochs 500 \
+    --batch_size 512 \
+    --lr "5e-2" \
+    --gumbel_samples 1 \
+    --sample_fn "gumbel_soft" \
+    --tau_init 3.0 \
+    --tau_end 0.01 \
+    --tau_schedule "geomspace" \
+    --n_vars_cat 1 \
+    --n_vars_num 1 \
+    --n_layers 2 \
+    --n_heads_cat 1 \
+    --n_heads_num 0 \
+    --n_cat_mlps 0 \
+    --n_num_mlps 0 \
+    --attention_type "cat" \
+    --rel_pos_bias "fixed" \
+    --one_hot_embed \
+    --count_only \
+    --selector_width 0 \
+    --seed "${SEED}" \
+    --unique 1 \
+    --unembed_mask 0 \
+    --autoregressive \
+    --save \
+    --save_code \
+    --output_dir "output/induction/s${SEED}";
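
The commit message notes that several seeds (e.g. 10) may be needed before an initialization learns the task, while the script hard-codes `SEED=6`. A minimal sketch of a seed sweep follows. It is not part of this commit: the names `N_SEEDS`, `DRY_RUN`, `SCRIPT`, and `run_seed` are hypothetical, as is the choice of seeds 0 through 9. It assumes it is run from the repository root and that the script still contains a line starting with `SEED=`.

```shell
#!/bin/sh
# Hypothetical seed sweep around scripts/induction.sh (not part of the commit).
# Assumption: run from the repository root, where src/run.py is reachable.

N_SEEDS="${N_SEEDS:-10}"   # "e.g. 10" seeds, per the commit message
DRY_RUN="${DRY_RUN:-1}"    # 1: only print each command; 0: execute it
SCRIPT="scripts/induction.sh"

run_seed() {
  # $1 is the seed to substitute for the hard-coded SEED=6 line.
  if [ "${DRY_RUN}" = "1" ]; then
    echo "sed 's/^SEED=.*/SEED=$1/' ${SCRIPT} | bash"
  else
    # Rewrite the SEED line on the fly; the script file itself is not modified.
    sed "s/^SEED=.*/SEED=$1/" "${SCRIPT}" | bash
  fi
}

seed=0
while [ "${seed}" -lt "${N_SEEDS}" ]; do
  run_seed "${seed}"
  seed=$((seed + 1))
done
```

By default the sketch only prints the commands it would run; setting `DRY_RUN=0` executes them one after another. Because the script derives `--output_dir` from `SEED`, each run writes to its own `output/induction/s<seed>` directory.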
