This is the official repo for our paper Generative Language Reconstruction from Brain Recordings, accepted by Communications Biology (a Nature Portfolio journal). Language generation from brain recordings is a novel approach that supports direct language generation with BCIs (brain-computer interfaces) without pre-defining or pre-generating language candidates to select from.
We have provided an example dataset to facilitate the replication of experiments. To run the example dataset, go into the sub-directory language_generation/src and use the following commands:
cd language_generation/src
# model training and evaluation (running BrainLLM)
python main.py -task_name Pereira_example -cuda 0 -load_check_point False -model_name llama-7b -checkpoint_path example -batch_size 8 -lr 1e-4 -pos False -pretrain_lr 1e-3 -pretrain_epochs 10 -wandb none -mode all
# control evaluation (running PerBrainLLM)
python main.py -task_name Pereira_example -cuda 0 -load_check_point False -model_name llama-7b -checkpoint_path example -batch_size 8 -lr 1e-4 -pos False -pretrain_lr 1e-3 -pretrain_epochs 10 -wandb none -input_method permutated -mode evaluate -output test_permutated
# control evaluation (running LLM)
python main.py -task_name Pereira_example -cuda 0 -load_check_point False -model_name llama-7b -checkpoint_path example -batch_size 8 -lr 1e-4 -pos False -pretrain_lr 1e-3 -pretrain_epochs 10 -wandb none -input_method mask_input -mode evaluate -output test_nobrain
To run with slurm, you can also use the provided scripts in the sub-directory language_generation/scripts (remember to replace the name of the conda environment and the path of the sub-directory language_generation/scripts according to your settings).
sh example.sh
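For reference, here is a minimal sketch of what such a slurm script could look like; the job settings, conda environment name, and repo path are placeholders to adapt, and the actual example.sh in the repo may differ:
#!/bin/bash
#SBATCH --job-name=brainllm_example
#SBATCH --gres=gpu:1
#SBATCH --time=24:00:00
source activate <your_conda_env>            # replace with your conda environment name
cd <path_to_repo>/language_generation/src   # replace with your local path
python main.py -task_name Pereira_example -cuda 0 -load_check_point False -model_name llama-7b -checkpoint_path example -batch_size 8 -lr 1e-4 -pos False -pretrain_lr 1e-3 -pretrain_epochs 10 -wandb none -mode all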
To run with the datasets utilized in our paper, please download the dataset from Tsinghua Cloud and unzip it. Use the parameter -dataset_path to specify the path of your unzipped dataset. For example, if you unzip the dataset into your home directory as ~/released/, then you can run the training and evaluation of BrainLLM for participant 1 in the Huth dataset using the following command:
python main.py -task_name Huth_1 -cuda 0 -load_check_point False -model_name llama-7b -checkpoint_path Huth_1 -batch_size 8 -lr 1e-4 -pretrain_lr 1e-3 -pretrain_epochs 10 -wandb none -mode all -dataset_path ~/released/ -pos True
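As a quick sanity check that -dataset_path points to the right place, listing the unzipped directory should show one sub-directory per dataset (see the Datasets section below):
ls ~/released/
# expected sub-directories: Huth  Narratives  Pereira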
To evaluate the model performance, you can refer to the code in language_generation/src/post_hoc_evaluate.py
In addition to the language completion task, our method also supports generating a complete piece of text from brain signals spanning a few minutes. The relevant code can be found in the directory end2end_generation/. The implementation of full story construction is based on Tang et al. (thanks for their code). To run this code, you also need to download some helpful files from their repository, i.e., the data_lm directory, and transform the vocabulary of their implementation into the vocabulary of the Llama-2 or GPT-2 series models. Here is an example that generates language from human brain recordings while participants perceive the story "where there's smoke":
cd language_generation/src
# train BrainLLM with the splitting strategy that leaves out the story "where there's smoke"
python main.py -task_name Huth_1 -cuda 0 -load_check_point False -model_name gpt2-xl -checkpoint_path Huth_1_gpt2-xl -batch_size 8 -lr 1e-4 -pretrain_lr 1e-3 -pretrain_epochs 0 -wandb none -mode all -dataset_path ../../dataset/ -pos True -data_spliting end2end
cd ../end2end_generation/src
# run inference for full story construction
python main.py -task_name Huth_1 -cuda 0 -load_check_point False -model_name gpt2-xl -checkpoint_path Huth_1_gpt2-xl -wandb none -pos True -data_spliting end2end -mode end2end -use_bad_words_ids False -ncontext 10 -gcontext 10 -length_penalty 0.3 -beam_width 3 -extensions 3
# run evaluation with Huth's metrics
python evaluate.py -dir Huth_1
This repo is developed with PyTorch, which should be installed manually according to your platform-specific configuration. The recommended commands for installation are:
# XX.X is a placeholder for cudatoolkit version. It should be specified according to your environment
conda install pytorch torchvision torchaudio cudatoolkit=XX.X -c pytorch
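For example, for the torch 2.0.1 / CUDA 11.7 combination we used (see below), the command listed on the official PyTorch previous-versions page specifies pytorch-cuda instead of cudatoolkit; please verify against your own environment:
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.7 -c pytorch -c nvidia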
In our experiment, we use torch version 2.0.1 and cuda version 11.7. In addition to PyTorch, we adopt several publicly available packages, which can be installed by
pip install -r requirements.txt
Note: Llama-7b may produce NaNs during half-precision training. If you encounter this issue, you can refer to this: huggingface/transformers#25065.
To train the model, you need to specify the parameter -mode as train (only training) or all (training and evaluation). You can specify several hyperparameters according to your requirements; the default parameters for Pereira's dataset, Huth's dataset, and the Narratives dataset are provided in language_generation/scripts/example.sh, language_generation/scripts/huth.sh, and language_generation/scripts/narratives.sh, respectively. The meanings of the hyperparameters are listed below (see the example command after the table):
Parameter | Meaning |
---|---|
model_name | the selected LLM, choose from {gpt2,gpt2-medium,gpt2-large,gpt2-xl,llama-2} |
method | only supported decoding in the released version |
task_name | {dataset_name}_{participant_name}, dataset_name selected from {Pereira,Huth,Narratives} |
test_trail_ids | specify the range of the test dataset, view the dict dataset2agrs in language_generation/src/config.py for the default setting |
valid_trail_ids | specify the range of the validation dataset, view the dict dataset2agrs in language_generation/src/config.py for the default setting |
random_number | for cross-validation evaluation, used together with test_trail_ids and valid_trail_ids |
batch_size | set as 8 in our experiment |
fmri_pca | whether to apply PCA-based dimensionality reduction to the fMRI input, default is True |
cuda | specify the device number |
layer | not used in the released version |
num_epochs | specify the maximum number of training epochs |
lr | learning rate, set as 1e-4 in our experiment |
dropout | dropout rate for brain decoder |
checkpoint_path | path of the training checkpoint for saving and loading |
load_check_point | whether to load an existing checkpoint |
enable_grad | whether to allow the parameters of the LLM to be updated |
mode | train: training only, with evaluation on the validation set; evaluate: evaluate on the test set; all: train and evaluate |
additional_loss | training with additional loss, not used in the released version |
fake_input | training with fake input, not used in the released version |
add_end | not used in the released version |
context | whether to discard data samples without any text prompt |
roi_selected | roi-based experiment, not used in the released version |
project_name | specify the project name for wandb |
noise_ratio | not used in the released version |
wandb | specify how to sync the experimental results to wandb, selected from {online, offline, none} |
generation_method | generation method for the LLM, selected from {greedy, beam} |
pos | specify whether to use position embedding in the brain decoder |
output | specify the name of the output file for the generated results |
data_spliting | specify how to split the dataset, selected from {random, cross_story}, default is random |
brain_model | the based model for the brain decoder, selected from {mlp,rnn,linear,big_mlp,multi_mlp} |
weight_decay | weight decay |
l2 | weight for l2 regularized loss |
num_layers | number of layers in the brain decoder |
evaluate_log | whether to evaluate on the test set after each training epoch |
normalized | whether to normalize the input |
activation | activation function, selected from {relu,sigmoid,tanh,relu6} |
pretrain_epochs | number of epochs in the warm-up step |
pretrain_lr | learning rate in the warm-up step |
data_size | maximum training data samples |
results_path | path to save model results |
dataset_path | path to the downloaded dataset |
shuffle_times | permutation times for PerBrainLLM |
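For example, a training-only run on the example dataset that sets a few of these hyperparameters explicitly might look like the following; the checkpoint name and hyperparameter values here are illustrative, not tuned defaults:
python main.py -task_name Pereira_example -cuda 0 -load_check_point False -model_name gpt2-xl -checkpoint_path example_mlp -mode train -brain_model mlp -num_layers 2 -dropout 0.3 -activation relu -batch_size 8 -lr 1e-4 -pretrain_lr 1e-3 -pretrain_epochs 10 -generation_method beam -pos True -wandb none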
To evaluate the model with different prompt inputs, i.e., BrainLLM, PerBrainLLM, and LLM, you can specify the parameter -input_method as normal, permutated, or without_brain, respectively. To test the model performance without any text prompt, you should train and evaluate the model with -input_method set to without_text.
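For example, to train and evaluate on the example dataset without any text prompt (mirroring the BrainLLM command above; the checkpoint name is illustrative):
python main.py -task_name Pereira_example -cuda 0 -load_check_point False -model_name llama-7b -checkpoint_path example_without_text -batch_size 8 -lr 1e-4 -pos False -pretrain_lr 1e-3 -pretrain_epochs 10 -wandb none -input_method without_text -mode all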
After that, you can get output files for the different prompt inputs. Then, you can evaluate their performance by running the python script language_generation/src/post_hoc_evaluate.py with the paths of the output files specified. Refer to language_generation/src/post_hoc_evaluate.py for example usage:
python language_generation/src/post_hoc_evaluate.py
We test our approach on three public fMRI datasets: Pereira's dataset, Huth's dataset, and the Narratives dataset. A brief introduction, ethical information, statistics, and usage details of these datasets are provided in our paper. A preprocessed version of the dataset is released on Tsinghua Cloud, where the sub-directories Pereira, Huth, and Narratives contain the preprocessed data for each participant and story in Pereira's dataset, Huth's dataset, and the Narratives dataset, respectively.
These are the overall experimental results in terms of language similarity metrics. Refer to our paper for the explanation of the metrics and more analyses.
Dataset | Model | Bleu-1(↑) | ROUGE-1(↑) | ROUGE-L(↑) | WER(↓) |
---|---|---|---|---|---|
Pereira’s | BrainLLM | 0.3432 | 0.2987 | 0.2878 | 0.7576 |
 | PerBrainLLM | 0.3269 | 0.2815 | 0.2751 | 0.7783 |
 | StdLLM | 0.2415 | 0.2133 | 0.2096 | 0.8349 |
Huth’s | BrainLLM | 0.1899 | 0.1780 | 0.1709 | 0.8946 |
 | PerBrainLLM | 0.1668 | 0.1536 | 0.1474 | 0.9109 |
 | StdLLM | 0.1500 | 0.1360 | 0.1310 | 0.9200 |
Narratives | BrainLLM | 0.1375 | 0.1301 | 0.1209 | 0.9239 |
 | PerBrainLLM | 0.1269 | 0.1211 | 0.1105 | 0.9311 |
 | StdLLM | 0.0953 | 0.0858 | 0.0829 | 0.9485 |
If you find our work helpful, please consider citing us:
@article{ye2023language,
title={Language Generation from Brain Recordings},
author={Ye, Ziyi and Ai, Qingyao and Liu, Yiqun and Zhang, Min and Lioma, Christina and Ruotsalo, Tuukka},
journal={arXiv preprint arXiv:2311.09889},
year={2023}
}