Skip to content

Commit bfd0ffd

Browse files
committed
update
1 parent 8bfafcb commit bfd0ffd

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ Arguments explanation:
5757
- ```--resume_checkpoint```: if not none, restore this checkpoint and continue training
5858
- ```--vocab```: the tokenizer is initialized using bert or load your own preprocessed vocab dictionary (e.g. using BPE)
5959

60-
It will take around 2 days to train a __*DiffuSeq*__ model on 2 NVIDIA A100 GPUs for QG and QQP, and the training steps should be increased accordingly along with the size of the training set. To reproduce the results of Table 1 in our paper, we suggest the following configuration for each dataset when training.
60+
It will take 2 more days to train a __*DiffuSeq*__ model on 4 NVIDIA A100 80G GPUs for QG and QQP, and the training steps should be increased accordingly along with the size of the training set. To reproduce the results of Table 1 in our paper, we suggest the following configuration for each dataset when training.
6161

6262
```
6363
python -m torch.distributed.launch --nproc_per_node=4 --master_port=12233 --use_env run_train.py --diff_steps 2000 --lr 0.0001 --learning_steps 50000 --save_interval 10000 --seed 102 --noise_schedule sqrt --hidden_dim 128 --bsz 2048 --dataset qqp --data_dir {datasets/QQP} --vocab bert --seq_len 128 --schedule_sampler lossaware --notes qqp
@@ -86,6 +86,7 @@ python eval_seq2seq.py --folder ../{your-path-to-outputs} --mbr
8686
```
8787
Note: if you want to use this evaluation script for output files from other models, please make sure the same line from these output files refers to the same piece of data. Otherwise the diversity score could be incorrect.
8888

89+
> Update 28 Nov 2022: We prepare the checkpoint and sampling results of 10 seeds for QQP dataset in this [link](https://drive.google.com/drive/folders/1vnhJIUqPQva_x_sH2h5a0moCc1NYmEpr?usp=sharing).
8990
9091
Welcome to discuss if you have any questions.
9192

sample_seq2seq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def main():
143143
diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
144144
)
145145

146-
sample_shape = (batch.shape[0], args.seq_len, args.hidden_dim)
146+
sample_shape = (x_start.shape[0], args.seq_len, args.hidden_dim)
147147

148148
samples = sample_fn(
149149
model,

0 commit comments

Comments
 (0)