
Commit abf7d42

Add Gemma3 architecture (text-only) (#711)

1 parent b22dd5c · commit abf7d42
File tree: 8 files changed, +1162 −230 lines

backends/candle/src/layers/linear.rs (1 addition, 0 deletions)

```diff
@@ -5,6 +5,7 @@ use serde::Deserialize;
 #[derive(Debug, Deserialize, PartialEq, Clone)]
 #[serde(rename_all = "lowercase")]
 pub enum HiddenAct {
+    #[serde(alias = "gelu_pytorch_tanh")]
     Gelu,
     Relu,
     Silu,
```
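Gemma3 checkpoints spell their activation as `gelu_pytorch_tanh` (the tanh-approximated GELU) in `config.json`, so the alias lets that spelling deserialize onto the existing `Gelu` variant instead of failing. A minimal standalone sketch of the behavior (assuming `serde` and `serde_json` as dependencies; the enum is trimmed to what the diff shows):

```rust
use serde::Deserialize;

#[derive(Debug, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
enum HiddenAct {
    // The alias accepts the spelling Gemma-family configs use for
    // tanh-approximated GELU, mapping it onto the same variant.
    #[serde(alias = "gelu_pytorch_tanh")]
    Gelu,
    Relu,
    Silu,
}

fn main() {
    let plain: HiddenAct = serde_json::from_str(r#""gelu""#).unwrap();
    let tanh: HiddenAct = serde_json::from_str(r#""gelu_pytorch_tanh""#).unwrap();
    assert_eq!(plain, HiddenAct::Gelu);
    assert_eq!(tanh, HiddenAct::Gelu);
}
```

Both spellings land on the same variant, so downstream code can keep matching on `HiddenAct::Gelu`.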

backends/candle/src/lib.rs (26 additions, 3 deletions)

```diff
@@ -23,9 +23,9 @@ use crate::compute_cap::{
 };
 use crate::models::{
     BertConfig, BertModel, Dense, DenseConfig, DenseLayer, DistilBertConfig, DistilBertModel,
-    GTEConfig, GTEModel, JinaBertModel, JinaCodeBertModel, MPNetConfig, MPNetModel, MistralConfig,
-    Model, ModernBertConfig, ModernBertModel, NomicBertModel, NomicConfig, Qwen2Config,
-    Qwen3Config, Qwen3Model,
+    GTEConfig, GTEModel, Gemma3Config, Gemma3Model, JinaBertModel, JinaCodeBertModel, MPNetConfig,
+    MPNetModel, MistralConfig, Model, ModernBertConfig, ModernBertModel, NomicBertModel,
+    NomicConfig, Qwen2Config, Qwen3Config, Qwen3Model,
 };
 #[cfg(feature = "cuda")]
 use crate::models::{
@@ -95,6 +95,8 @@ enum Config {
     Camembert(BertConfig),
     #[serde(rename(deserialize = "distilbert"))]
     DistilBert(DistilBertConfig),
+    #[serde(rename(deserialize = "gemma3_text"))]
+    Gemma3(Gemma3Config),
     #[serde(alias = "new")]
     Gte(GTEConfig),
     #[serde(rename = "mpnet")]
```
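`Config` is deserialized from the checkpoint's `config.json` and keyed on its `model_type` field, so `rename(deserialize = "gemma3_text")` is what routes Gemma3 text configs to the new variant. A hedged, self-contained sketch of that routing, assuming serde's internally tagged representation (which the variants' rename attributes suggest); `Gemma3Config` is reduced to a single illustrative field:

```rust
use serde::Deserialize;

// Illustrative stand-in: the real Gemma3Config carries many more fields.
#[derive(Debug, Deserialize)]
struct Gemma3Config {
    vocab_size: usize,
}

// Internally tagged enum: serde reads `model_type` from config.json
// and picks the matching variant, then fills it from the other fields.
#[derive(Debug, Deserialize)]
#[serde(tag = "model_type", rename_all = "lowercase")]
enum Config {
    #[serde(rename(deserialize = "gemma3_text"))]
    Gemma3(Gemma3Config),
}

fn main() {
    let json = r#"{ "model_type": "gemma3_text", "vocab_size": 262144 }"#;
    let config: Config = serde_json::from_str(json).unwrap();
    println!("{config:?}"); // Gemma3(Gemma3Config { vocab_size: 262144 })
}
```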
```diff
@@ -263,6 +265,16 @@ impl CandleBackend {
                     DistilBertModel::load(vb, &config, model_type).s()?,
                 ))
             }
+            (Config::Gemma3(config), Device::Cpu | Device::Metal(_)) => {
+                if dtype != DType::F32 {
+                    Err(BackendError::Start(
+                        "Gemma3 is only supported in fp32 precision".to_string(),
+                    ))
+                } else {
+                    tracing::info!("Starting Gemma3 model on {:?}", device);
+                    Ok(Box::new(Gemma3Model::load(vb, &config, model_type).s()?))
+                }
+            }
             (Config::Gte(config), Device::Cpu | Device::Metal(_)) => {
                 tracing::info!("Starting GTE model on {:?}", device);
                 Ok(Box::new(GTEModel::load(vb, &config, model_type).s()?))
@@ -381,6 +393,17 @@ impl CandleBackend {
                 }
             }
             #[cfg(feature = "cuda")]
+            (Config::Gemma3(config), Device::Cuda(_)) => {
+                if dtype != DType::F32 {
+                    Err(BackendError::Start(
+                        "Gemma3 is only supported in fp32 precision".to_string(),
+                    ))
+                } else {
+                    tracing::info!("Starting Gemma3 model on {:?}", device);
+                    Ok(Box::new(Gemma3Model::load(vb, &config, model_type).s()?))
+                }
+            }
+            #[cfg(feature = "cuda")]
             (Config::Gte(config), Device::Cuda(_)) => {
                 if dtype != DType::F16
                     || !cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
```
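Both new arms apply the same guard: regardless of device, Gemma3 refuses to start unless the requested dtype is `F32`. A reduced sketch of that dispatch shape (the `DType` and `Device` enums here are simplified stand-ins for the candle types, and the string error stands in for `BackendError::Start`):

```rust
// Simplified stand-ins for candle's DType/Device and TEI's error type.
#[derive(Debug, PartialEq, Clone, Copy)]
enum DType { F16, F32 }

#[derive(Debug, Clone, Copy)]
enum Device { Cpu, Metal, Cuda }

fn start_gemma3(device: Device, dtype: DType) -> Result<(), String> {
    // Any device is accepted, but only fp32 precision is allowed.
    if dtype != DType::F32 {
        Err("Gemma3 is only supported in fp32 precision".to_string())
    } else {
        println!("Starting Gemma3 model on {device:?}");
        Ok(())
    }
}

fn main() {
    assert!(start_gemma3(Device::Cuda, DType::F16).is_err());
    assert!(start_gemma3(Device::Metal, DType::F32).is_ok());
    assert!(start_gemma3(Device::Cpu, DType::F32).is_ok());
}
```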

backends/candle/src/models/flash_gte.rs (3 additions, 1 deletion)

```diff
@@ -1,6 +1,8 @@
 use crate::flash_attn::flash_attn_varlen;
 use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear};
-use crate::models::{GTEClassificationHead, GTEConfig, Model, PositionEmbeddingType, GTEMLP};
+use crate::models::gte::{GTEClassificationHead, GTEConfig, GTEMLP};
+use crate::models::{Model, PositionEmbeddingType};
+
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::{Embedding, Module, VarBuilder};
 use candle_rotary::apply_rotary_inplace;
```
