
torchprime

High-performance training for PyTorch on Cloud TPU



torchprime is a reference implementation for training PyTorch models on TPU. It showcases best practices for large-scale, high-performance model training with torch_xla, requiring minimal changes to model code. It aims to demystify training on XLA-based accelerators by providing clear patterns that help the PyTorch community unlock top performance and efficiency on Google Cloud TPUs.

torchprime is under active development, and we're eager for feedback and input from the PyTorch community.

Environment setup

For development and debugging purposes, it is useful to run torchprime locally on a TPU VM. You can create a single-host TPU VM by following this guide: https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm, or you can create a TPU VM from the Google Cloud Console for your cloud project. Spot quota is available for v5e and v6e chips in multiple regions: https://cloud.google.com/tpu/docs/regions-zones

Make sure that you are using the correct runtime when creating your VM: https://cloud.google.com/tpu/docs/runtimes#pytorch_and_jax

For example:

gcloud compute tpus tpu-vm create <tpu-name> \
  --zone=us-central1-a \
  --accelerator-type=v6e-4 \
  --version=v2-alpha-tpuv6e \
  --spot

Once the VM is created you can ssh into it: https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#tpu-connect

gcloud compute tpus tpu-vm ssh <tpu-name> --zone=<zone>

Install torch_xla

Before installing torchprime, first install torch_xla by following its project README. You need to install the nightly version of PyTorch/XLA.

Test that the environment is correctly installed and configured. Start a Python interpreter and run the following:

import torch_xla.core.xla_model as xm

# Both calls should report the attached TPU devices if the runtime is set up correctly.
print("XLA devices:", xm.get_xla_supported_devices())
print("Default XLA device type:", xm.xla_device_hw(xm.xla_device()))

Install torchprime

Make sure that pip and setuptools are up-to-date:

python -m pip install --upgrade pip
python -m pip install --upgrade setuptools==69.5.1
git clone https://github.com/AI-Hypercomputer/torchprime.git
cd torchprime
pip install -e '.[dev]'

Examples

Local training

Here is a simple example of training on a single TPU VM. Train Llama 3 8B using torch_xla:

export HF_TOKEN='...your huggingface token...'
XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 python3 torchprime/torch_xla_models/train.py

The first two training steps will take a while to compile. After that, the graphs will hit the compilation cache and you should see something like this:

...
[2025-04-29 06:58:36,445][__main__][INFO] - Num replicas: 1
[2025-04-29 06:58:36,447][__main__][INFO] - Starting training
[2025-04-29 06:58:36,448][__main__][INFO] -     Max step: 15
[2025-04-29 06:58:36,448][__main__][INFO] -     Global batch size: 4
[2025-04-29 07:01:16,240][__main__][INFO] - Epoch: 0, step: 0, loss: 12.5574, trace time: 155003.85 ms
[2025-04-29 07:04:24,182][__main__][INFO] - Epoch: 0, step: 10, loss: 9.7555, trace time: 1564.54 ms
...

Refer to README.md in torchprime/torch_xla_models for more details.

Configuring training

torchprime uses hydra to read configurations (e.g. model name, batch size) from the command line and .yaml files.

In the torch_xla_models directory, you'll find configs/default.yaml, which specifies the default configuration for the trainer. You may override configs on the command line with key=value syntax. For example, the following command trains Mixtral 8x7B with a global batch size of 256 and sets the FSDP SPMD ICI mesh axis length to 64:

python3 torchprime/torch_xla_models/train.py \
    model=mixtral-8x7b \
    global_batch_size=256 \
    ici_mesh.fsdp=64

You may refer to the hydra docs for other ways to specify configs.
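
To see how these key=value overrides compose with the YAML defaults, here is a minimal, self-contained sketch using OmegaConf, the configuration library that hydra is built on. The default values below are hypothetical stand-ins for configs/default.yaml, not torchprime's actual defaults:

# Minimal sketch of hydra-style overrides using OmegaConf (which hydra builds on).
# The default values here are hypothetical stand-ins for configs/default.yaml.
from omegaconf import OmegaConf

defaults = OmegaConf.create({
    "model": "mixtral-8x7b",
    "global_batch_size": 4,
    "ici_mesh": {"fsdp": 4},
})

# Equivalent of passing `global_batch_size=256 ici_mesh.fsdp=64` on the command line.
overrides = OmegaConf.from_dotlist(["global_batch_size=256", "ici_mesh.fsdp=64"])
config = OmegaConf.merge(defaults, overrides)
print(OmegaConf.to_yaml(config))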

Distributed training

torchprime uses xpk as the standard path for iterating on distributed training code.

First teach torchprime about the XPK cluster it is using, the artifact storage location, etc. You only need to do this on first clone or when switching to a different topology or cluster. Example:

tp use \
    --cluster <XPK CLUSTER NAME> \
    --project my-gcp-project \
    --zone us-east5-b \
    --num-slices 1 \
    --tpu-type v6e-256 \
    --artifact-dir gs://bucket/dir

torchprime natively supports multi-slice or multi-pod training. --num-slices specifies the number of slices used by the workload. --tpu-type specifies the accelerator type in each slice. To do multi-pod training, simply specify a --tpu-type that is as big as a pod.

After configuring the cluster, prepend tp run to the Python command you would like to run remotely, including its arguments, e.g.

# Train Llama 3.0 8B on 256 chips
tp run torchprime/torch_xla_models/train.py \
    model=llama-3-8b \
    global_batch_size=256 \
    ici_mesh.fsdp=256

tp run will broadcast the specified command to all VMs in the XPK cluster, which is the convention for running SPMD distributed workloads. See tp run --help for more advanced features.
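
As background, the ici_mesh.fsdp axis above maps onto torch_xla's SPMD sharding API. The snippet below is a hedged, standalone illustration of that idea, not torchprime trainer code: every host runs the same program, and tensors are sharded across a named device mesh.

# A hedged illustration of SPMD sharding with torch_xla (not torchprime code):
# a single mesh axis named "fsdp" spans all devices, mirroring ici_mesh.fsdp=N.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()  # enable the SPMD execution mode

num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ("fsdp",))

weights = torch.randn(4096, 4096).to(xm.xla_device())
# Shard the first dimension of the tensor along the "fsdp" mesh axis.
xs.mark_sharding(weights, mesh, ("fsdp", None))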

The version of torch_xla used by tp run is specified in pyproject.toml.

Env vars passed to the workload

tp run will pick up certain environment variables set locally and forward them to the distributed workload, if found, such as HF_TOKEN and the XLA debugging flags (XLA_IR_DEBUG, XLA_HLO_DEBUG) used in the local training example above.

Additional CLI arguments passed to the workload

Besides forwarding your command-line arguments, tp run will add the following (see the sketch after this list for how a workload might read them):

  • profile_dir=[...]: path to a profile directory accessible by the workload
  • output_dir=[...]: path to a directory where the workload may store output artifacts such as metrics and checkpoints
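
The sketch below is a hypothetical hydra-based workload, not torchprime's actual train.py, showing how these extra arguments and the forwarded environment variables could be consumed; the config path, config name, and key names are illustrative assumptions:

# Hypothetical sketch of a workload reading the extra arguments appended by
# `tp run` (profile_dir, output_dir) and a forwarded environment variable.
# The config_path/config_name below are assumptions, not torchprime's layout.
import os

import hydra
from omegaconf import DictConfig


@hydra.main(version_base=None, config_path="configs", config_name="default")
def main(config: DictConfig):
    # profile_dir=... and output_dir=... arrive as ordinary hydra overrides,
    # assuming the config schema declares those keys.
    print("Profiles will be written to:", config.profile_dir)
    print("Output artifacts (metrics, checkpoints) go to:", config.output_dir)
    # Forwarded environment variables show up in the process environment.
    print("HF_TOKEN set:", "HF_TOKEN" in os.environ)


if __name__ == "__main__":
    main()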

Supported Models

Below is the status of various models. There are five stages for each model:

  1. TODO: We need to implement the model.
  2. Implemented: The model runs either a training or an inference step.
  3. Optimized: We found the best scaling configuration for the model on one or more hardware configurations. One-off performance data is available.
  4. Convergence: We tested that the training loss converges to a reasonable value, or that the loss curve tracks an existing reference if one exists.
  5. Production: The model is not only optimized and convergent; its performance is also continuously monitored. This is a good state for using the model in production.

All implemented models will at least have unit tests to verify basic numerical correctness, and the convergence verification stage serves as an additional correctness guarantee.

If a model is implemented, you'll also find a training recipe linked from the checkmark emoji in the table. If a model is optimized, you'll also find MFU numbers linked from the table. Note that a model may continue to receive optimizations thereafter.

Model                 Implemented  Optimized  Converges
Llama 3.0 8B          ✅           ✅         TODO
Llama 3.1 8B          ✅           ✅         TODO
Llama 3.1 70B         ✅           ✅         TODO
Llama 3.1 405B        ✅           ✅         TODO
Llama 4 Scout         TODO         TODO       TODO
Llama 4 Maverick      TODO         TODO       TODO
Mixtral 8x7B          ✅           TODO       TODO
Mixtral 8x22B         TODO         TODO       TODO
DeepSeek V3/R1        TODO         TODO       TODO
Stable Diffusion 2.0  TODO         TODO       TODO
Stable Diffusion 2.1  TODO         TODO       TODO

Structure

This repo will contain a set of reference models that we have optimized and that run well on TPU. The best-performing scaling configuration (parallelism techniques, checkpointing, etc.) for each model on various hardware will be provided for ease of reproducibility.

docs contains guides for optimizing performance and debugging issues.

torchprime/launcher contains scripts to train a model on a large TPU cluster.

torchprime/data contains dataset and data loading utilities.

torchprime/torch_xla_models contains model implementations using torch_xla.

torchprime/experimental/torchax_models contains model implementations using torchax.

Finally, each model may also provide a GPU "original" version that illustrates and attributes where the model code came from, if applicable. This also helps showcase the changes we made to make it performant on TPU. The original version is not expected to be run.

Contributing

Contributions are welcome! Please feel free to submit a pull request.

Refer to the contributor guide to get started.

License

This project is licensed under the New BSD License - see the LICENSE file for details.

For more information on PyTorch/XLA, visit the official documentation.
