@@ -23,9 +23,9 @@ use crate::compute_cap::{
23
23
} ;
24
24
use crate :: models:: {
25
25
BertConfig , BertModel , Dense , DenseConfig , DenseLayer , DistilBertConfig , DistilBertModel ,
26
- GTEConfig , GTEModel , JinaBertModel , JinaCodeBertModel , MPNetConfig , MPNetModel , MistralConfig ,
27
- Model , ModernBertConfig , ModernBertModel , NomicBertModel , NomicConfig , Qwen2Config ,
28
- Qwen3Config , Qwen3Model ,
26
+ GTEConfig , GTEModel , Gemma3Config , Gemma3Model , JinaBertModel , JinaCodeBertModel , MPNetConfig ,
27
+ MPNetModel , MistralConfig , Model , ModernBertConfig , ModernBertModel , NomicBertModel ,
28
+ NomicConfig , Qwen2Config , Qwen3Config , Qwen3Model ,
29
29
} ;
30
30
#[ cfg( feature = "cuda" ) ]
31
31
use crate :: models:: {
@@ -95,6 +95,8 @@ enum Config {
95
95
Camembert ( BertConfig ) ,
96
96
#[ serde( rename( deserialize = "distilbert" ) ) ]
97
97
DistilBert ( DistilBertConfig ) ,
98
+ #[ serde( rename( deserialize = "gemma3_text" ) ) ]
99
+ Gemma3 ( Gemma3Config ) ,
98
100
#[ serde( alias = "new" ) ]
99
101
Gte ( GTEConfig ) ,
100
102
#[ serde( rename = "mpnet" ) ]
@@ -263,6 +265,16 @@ impl CandleBackend {
263
265
DistilBertModel :: load ( vb, & config, model_type) . s ( ) ?,
264
266
) )
265
267
}
268
+ ( Config :: Gemma3 ( config) , Device :: Cpu | Device :: Metal ( _) ) => {
269
+ if dtype != DType :: F32 {
270
+ Err ( BackendError :: Start (
271
+ "Gemma3 is only supported in fp32 precision" . to_string ( ) ,
272
+ ) )
273
+ } else {
274
+ tracing:: info!( "Starting Gemma3 model on {:?}" , device) ;
275
+ Ok ( Box :: new ( Gemma3Model :: load ( vb, & config, model_type) . s ( ) ?) )
276
+ }
277
+ }
266
278
( Config :: Gte ( config) , Device :: Cpu | Device :: Metal ( _) ) => {
267
279
tracing:: info!( "Starting GTE model on {:?}" , device) ;
268
280
Ok ( Box :: new ( GTEModel :: load ( vb, & config, model_type) . s ( ) ?) )
@@ -381,6 +393,17 @@ impl CandleBackend {
381
393
}
382
394
}
383
395
#[ cfg( feature = "cuda" ) ]
396
+ ( Config :: Gemma3 ( config) , Device :: Cuda ( _) ) => {
397
+ if dtype != DType :: F32 {
398
+ Err ( BackendError :: Start (
399
+ "Gemma3 is only supported in fp32 precision" . to_string ( ) ,
400
+ ) )
401
+ } else {
402
+ tracing:: info!( "Starting Gemma3 model on {:?}" , device) ;
403
+ Ok ( Box :: new ( Gemma3Model :: load ( vb, & config, model_type) . s ( ) ?) )
404
+ }
405
+ }
406
+ #[ cfg( feature = "cuda" ) ]
384
407
( Config :: Gte ( config) , Device :: Cuda ( _) ) => {
385
408
if dtype != DType :: F16
386
409
|| !cfg ! ( any( feature = "flash-attn" , feature = "flash-attn-v1" ) )
0 commit comments