Gemma is a family of open-weights large language models (LLMs) by Google DeepMind, based on Gemini research and technology.
This repository contains the implementation of the gemma PyPI package, a JAX library to use and fine-tune Gemma.
For examples and use cases, see our documentation. Please report issues and feedback on our GitHub.
- Install JAX for CPU, GPU, or TPU by following the instructions on the JAX website.
- Run pip install gemma.
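Once both steps are done, a quick smoke test (not part of the official instructions, just a sanity check) is to confirm the imports succeed:

# Smoke test: both imports should succeed after installation.
import jax
from gemma import gm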
Here is a minimal example of a multi-turn, multimodal conversation with Gemma:
from gemma import gm

# Model and parameters
model = gm.nn.Gemma3_4B()
params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)

# Example of multi-turn conversation
sampler = gm.text.ChatSampler(
    model=model,
    params=params,
    multi_turn=True,
)

# `image1` and `image2` are user-provided images matching the two
# <start_of_image> placeholders in the prompt.
prompt = """Which of the two images do you prefer?

Image 1: <start_of_image>
Image 2: <start_of_image>

Write your answer as a poem."""
out0 = sampler.chat(prompt, images=[image1, image2])

out1 = sampler.chat('What about the other image?')
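The ChatSampler keeps the conversation state between chat calls. For one-off, stateless completions, a plain text sampler can be used instead; the sketch below assumes a gm.text.Sampler class with a sample(prompt, max_new_tokens=...) method, so double-check the exact signature in the documentation:

from gemma import gm

model = gm.nn.Gemma3_4B()
params = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT)

# Single-turn sampling: no conversation state is kept between calls.
# `gm.text.Sampler` and the `sample` signature are assumptions; see the docs.
sampler = gm.text.Sampler(
    model=model,
    params=params,
)
out = sampler.sample('One word to describe Paris:', max_new_tokens=30)
print(out)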
Our documentation contains various Colabs and tutorials. Additionally, our examples/ folder contains further scripts to fine-tune and sample with Gemma. Other resources:
- To use this library: the Gemma documentation
- Technical reports for metrics and model capabilities
- Other Gemma implementations and documentation on the Gemma ecosystem
To download the model weights, see our documentation.
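If you have already downloaded a checkpoint locally, the same load_params helper shown above can point at it. The path below is a placeholder, and passing a raw filesystem path is an assumption; consult the documentation for the supported checkpoint sources:

from gemma import gm

# Hypothetical local directory containing downloaded Gemma weights;
# replace with your own path. Accepting a plain path here is an assumption.
params = gm.ckpts.load_params('/path/to/gemma3_4b_it_checkpoint')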
Gemma can run on CPU, GPU, and TPU. For GPU, we recommend 8GB+ of GPU RAM for the 2B checkpoint and 24GB+ of GPU RAM for the 7B checkpoint.
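Before loading a large checkpoint, you can confirm which backend JAX picked up; this small sketch uses standard JAX calls:

import jax

# Report the platform JAX selected ('cpu', 'gpu', or 'tpu') and its devices.
print(jax.default_backend())
print(jax.devices())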
This is not an official Google product.