Mario Brain trains an AI agent to play Super Mario Bros using deep reinforcement learning.
Clone the repository, create a virtual environment, and install the required packages.
git clone --recurse-submodules https://github.com/akanto/mario-brain.git
cd mario-brain
python3 -m venv .venv
source .venv/bin/activate
Install the requirements if you are on macOS with MPS:
pip install -r requirements.txt
Install the requirements if you have an NVIDIA GPU (e.g. with CUDA 12.6):
pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu126
The models directory is a Git submodule linked to the mario-rl-model repository on Hugging Face. Cloning with git clone --recurse-submodules, as shown above, automatically clones the model repository as well.
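If the repository was cloned without --recurse-submodules, Git normally leaves the submodule directory in place but empty. The following is a minimal sketch for detecting that case. The helper name submodule_is_populated is hypothetical and not part of the project; the suggested recovery command is the standard Git one.

```python
from pathlib import Path


def submodule_is_populated(path: str = "models") -> bool:
    """Return True if `path` is a directory containing at least one entry."""
    directory = Path(path)
    return directory.is_dir() and any(directory.iterdir())


if __name__ == "__main__":
    if submodule_is_populated("models"):
        print("models/ is populated")
    else:
        # Standard Git command for fetching submodules after a plain clone.
        print("models/ is empty; run: git submodule update --init --recursive")
```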
Train the agent from scratch with the following command. The trained model is saved in the models/ directory:
python -m mario_brain.train
Or, to run the training in the background, use the following command:
nohup python -m mario_brain.train --parallel 4 --timesteps 50_000_000 > train.log 2>&1 &
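The value 50_000_000 uses underscore digit separators, which Python's int() accepts when converting strings. As an illustration only, the sketch below shows how a parser could read the two flags from the command above. It is not the project's actual parser: the function name, the defaults, and the help texts are assumptions.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the two flags used in the command above."""
    parser = argparse.ArgumentParser(prog="train-sketch")
    # Defaults are placeholders, not the project's real defaults.
    parser.add_argument("--parallel", type=int, default=1,
                        help="number of parallel environments (assumed meaning)")
    parser.add_argument("--timesteps", type=int, default=1_000_000,
                        help="total training timesteps (assumed meaning)")
    return parser


if __name__ == "__main__":
    # type=int calls int("50_000_000"), which accepts the underscores.
    args = build_parser().parse_args(["--parallel", "4", "--timesteps", "50_000_000"])
    print(args.parallel, args.timesteps)
```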
Evaluate the trained model to see how well it performs. The script loads the model from the models/ directory and renders the gameplay, so you can watch the AI play:
python -m mario_brain.evaluate
To see some gameplay without the AI, run the random play script:
python -m mario_brain.random_play
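For readers new to the idea, random play means choosing an action uniformly at random at every step, which gives a rough baseline to compare a trained agent against. The sketch below shows such a loop against a Gymnasium-style interface (reset returns an observation and an info dict; step returns observation, reward, terminated, truncated, info). ToyEnv is a made-up stand-in so the example runs without the game; it is not the project's environment, and random_play here is not the project's script.

```python
import random


class ToyEnv:
    """Stand-in environment with a Gymnasium-style reset/step interface."""

    def __init__(self, n_actions: int = 7, max_steps: int = 50):
        self.n_actions = n_actions
        self.max_steps = max_steps
        self.steps = 0

    def reset(self, seed=None):
        self.steps = 0
        return 0, {}  # (observation, info)

    def step(self, action: int):
        self.steps += 1
        reward = 1.0 if action == 1 else 0.0  # arbitrary toy reward
        truncated = self.steps >= self.max_steps  # episode ends on a step limit
        # (observation, reward, terminated, truncated, info)
        return self.steps, reward, False, truncated, {}


def random_play(env, seed: int = 0) -> float:
    """Run one episode with uniformly random actions; return the total reward."""
    rng = random.Random(seed)
    env.reset(seed=seed)
    total, done = 0.0, False
    while not done:
        action = rng.randrange(env.n_actions)
        _, reward, terminated, truncated, _ = env.step(action)
        total += reward
        done = terminated or truncated
    return total


if __name__ == "__main__":
    print("Total reward:", random_play(ToyEnv()))
```

With a real Gymnasium environment, the random action would typically come from env.action_space.sample() instead of a separate random generator.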
To play the game yourself, use the human play script, which lets you control the game with your keyboard:
python -m mario_brain.human_play
The benchmark contains a few scripts that test PyTorch and Gymnasium performance on your machine. It does not provide any information about the training process, but it can be used to measure the performance of your machine or to check whether CUDA or MPS is working properly.
python -m mario_brain.benchmark
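For a quicker check than the full benchmark, PyTorch can be asked directly which accelerator it sees. This is a minimal sketch, independent of the project's benchmark code; the function name pick_device is hypothetical, and the fallback to "cpu" when PyTorch is missing is a choice made for this example.

```python
def pick_device() -> str:
    """Return "cuda", "mps", or "cpu" depending on what PyTorch reports."""
    try:
        import torch
    except ImportError:
        return "cpu"  # PyTorch is not installed, so no accelerator can be used
    if torch.cuda.is_available():
        return "cuda"
    # Older PyTorch builds have no torch.backends.mps, hence the getattr guard.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


if __name__ == "__main__":
    print(f"Selected device: {pick_device()}")
```

If this prints cpu on a machine with an NVIDIA GPU or Apple Silicon, the installed PyTorch build is probably the wrong one for that hardware.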
Training logs and metrics are stored in the ../logs/ directory. Launch TensorBoard to monitor progress:
tensorboard --logdir logs/PPO_1
Open the link in your browser (http://localhost:6006/) to view real-time metrics.
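The command above points at one run, PPO_1. Assuming later runs are written to sibling directories with increasing numbers (PPO_2, PPO_3, and so on, which is not confirmed here), a small helper can find the most recent one. The function name latest_run and the naming pattern are assumptions for this sketch.

```python
import re
from pathlib import Path
from typing import Optional


def latest_run(log_root: str = "logs", prefix: str = "PPO") -> Optional[Path]:
    """Return the run directory with the highest numeric suffix, or None."""
    root = Path(log_root)
    if not root.is_dir():
        return None
    pattern = re.compile(rf"^{re.escape(prefix)}_(\d+)$")
    runs = []
    for child in root.iterdir():
        match = pattern.match(child.name)
        if child.is_dir() and match:
            runs.append((int(match.group(1)), child))
    # Compare numerically so that PPO_10 sorts after PPO_2.
    return max(runs, key=lambda run: run[0])[1] if runs else None


if __name__ == "__main__":
    run = latest_run("logs")
    print(f"tensorboard --logdir {run}" if run else "No runs found under logs/")
```

Alternatively, pointing --logdir at the parent logs directory makes TensorBoard list every run it finds underneath.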
This was inspired by the following resources:
Some of the libraries used in this project are not compatible with the latest versions of NumPy and Gymnasium. Those libraries have therefore been forked, and the forked Git repositories were added to requirements.txt.
This project is licensed under the MIT License - see the LICENSE file for details.