@@ -11,17 +11,17 @@ use crate::compute_cap::{
11
11
compatible_compute_cap, get_compile_compute_cap, get_runtime_compute_cap,
12
12
} ;
13
13
use crate :: models:: {
14
- BertConfig , BertModel , DistilBertConfig , DistilBertModel , JinaConfig , JinaBertModel , JinaCodeConfig , JinaCodeBertModel ,
14
+ BertConfig , BertModel , DistilBertConfig , DistilBertModel , JinaBertModel , JinaCodeBertModel ,
15
15
Model , NomicBertModel , NomicConfig ,
16
16
} ;
17
17
#[ cfg( feature = "cuda" ) ]
18
18
use crate :: models:: {
19
- FlashBertModel , FlashDistilBertModel , FlashJinaBertModel , FlashJinaCodeBertModel , FlashNomicBertModel ,
19
+ FlashBertModel , FlashDistilBertModel , FlashJinaBertModel , FlashJinaCodeBertModel ,
20
+ FlashNomicBertModel ,
20
21
} ;
21
22
use anyhow:: Context ;
22
23
use candle:: { DType , Device } ;
23
24
use candle_nn:: VarBuilder ;
24
- use models:: BertConfig ;
25
25
use nohash_hasher:: BuildNoHashHasher ;
26
26
use serde:: Deserialize ;
27
27
use std:: collections:: HashMap ;
@@ -30,17 +30,28 @@ use text_embeddings_backend_core::{
30
30
Backend , BackendError , Batch , Embedding , Embeddings , ModelType , Predictions ,
31
31
} ;
32
32
33
+ /// Wrapper used to tell Jina models, which also declare the `bert` model type, apart from
34
+ /// regular Bert models. Every variant carries the same `BertConfig` payload.
35
+ /// The variant is picked from the `_name_or_path` field in the config. This might not be robust in
36
+ /// the long run (it depends on exact repository names) but is still better than the other options.
37
+ #[ derive( Debug , Clone , PartialEq , Deserialize ) ]
38
+ #[ serde( tag = "_name_or_path" ) ]
39
+ pub enum BertConfigWrapper {
40
+ #[ serde( rename = "jinaai/jina-bert-implementation" ) ]
41
+ JinaBert ( BertConfig ) , // chosen when `_name_or_path` equals the rename string above
42
+ #[ serde( rename = "jinaai/jina-bert-v2-qk-post-norm" ) ]
43
+ JinaCodeBert ( BertConfig ) , // chosen when `_name_or_path` equals the rename string above
44
+ #[ serde( untagged) ]
45
+ Bert ( BertConfig ) , // untagged: presumably the fallback for every other config; keep it last
46
+ }
47
+
33
48
#[ derive( Deserialize ) ]
34
49
#[ serde( tag = "model_type" , rename_all = "kebab-case" ) ]
35
50
enum Config {
36
- Bert ( BertConfig ) ,
51
+ Bert ( BertConfigWrapper ) ,
37
52
XlmRoberta ( BertConfig ) ,
38
53
Camembert ( BertConfig ) ,
39
54
Roberta ( BertConfig ) ,
40
- #[ serde( rename( deserialize = "jina_bert" ) ) ]
41
- JinaBert ( JinaConfig ) ,
42
- #[ serde( rename( deserialize = "jina_code_bert" ) ) ]
43
- JinaCodeBert ( JinaCodeConfig ) ,
44
55
#[ serde( rename( deserialize = "distilbert" ) ) ]
45
56
DistilBert ( DistilBertConfig ) ,
46
57
#[ serde( rename( deserialize = "nomic_bert" ) ) ]
@@ -76,7 +87,7 @@ impl CandleBackend {
76
87
"Runtime compute cap {} is not compatible with compile time compute cap {}" ,
77
88
get_runtime_compute_cap( ) . unwrap( ) ,
78
89
get_compile_compute_cap( ) . unwrap( )
79
- ) ) )
90
+ ) ) ) ;
80
91
}
81
92
Err ( err) => {
82
93
tracing:: warn!( "Could not find a compatible CUDA device on host: {err:?}" ) ;
@@ -123,18 +134,22 @@ impl CandleBackend {
123
134
( _, Device :: Cuda ( _) ) => Err ( BackendError :: Start (
124
135
"`cuda` feature is not enabled" . to_string ( ) ,
125
136
) ) ,
126
- ( Config :: Bert ( config) , Device :: Cpu | Device :: Metal ( _) ) => {
127
- tracing:: info!( "Starting Bert model on {:?}" , device) ;
128
- Ok ( Box :: new ( BertModel :: load ( vb, & config, model_type) . s ( ) ?) )
129
- }
130
- ( Config :: JinaBert ( config) , Device :: Cpu | Device :: Metal ( _) ) => {
131
- tracing:: info!( "Starting JinaBertModel model on {:?}" , device) ;
132
- Ok ( Box :: new ( JinaBertModel :: load ( vb, & config, model_type) . s ( ) ?) )
133
- }
134
- ( Config :: JinaCodeBert ( config) , Device :: Cpu | Device :: Metal ( _) ) => {
135
- tracing:: info!( "Starting JinaCodeBertModel model on {:?}" , device) ;
136
- Ok ( Box :: new ( JinaCodeBertModel :: load ( vb, & config, model_type) . s ( ) ?) )
137
- }
137
+ ( Config :: Bert ( config) , Device :: Cpu | Device :: Metal ( _) ) => match config {
138
+ BertConfigWrapper :: JinaBert ( config) => {
139
+ tracing:: info!( "Starting JinaBertModel model on {:?}" , device) ;
140
+ Ok ( Box :: new ( JinaBertModel :: load ( vb, & config, model_type) . s ( ) ?) )
141
+ }
142
+ BertConfigWrapper :: JinaCodeBert ( config) => {
143
+ tracing:: info!( "Starting JinaCodeBert model on {:?}" , device) ;
144
+ Ok ( Box :: new (
145
+ JinaCodeBertModel :: load ( vb, & config, model_type) . s ( ) ?,
146
+ ) )
147
+ }
148
+ BertConfigWrapper :: Bert ( config) => {
149
+ tracing:: info!( "Starting Bert model on {:?}" , device) ;
150
+ Ok ( Box :: new ( BertModel :: load ( vb, & config, model_type) . s ( ) ?) )
151
+ }
152
+ } ,
138
153
(
139
154
Config :: XlmRoberta ( config) | Config :: Camembert ( config) | Config :: Roberta ( config) ,
140
155
Device :: Cpu | Device :: Metal ( _) ,
@@ -158,48 +173,45 @@ impl CandleBackend {
158
173
( Config :: Bert ( config) , Device :: Cuda ( _) ) => {
159
174
if cfg ! ( any( feature = "flash-attn" , feature = "flash-attn-v1" ) )
160
175
&& dtype == DType :: F16
161
- && ( ( config. position_embedding_type == PositionEmbeddingType :: Absolute ) | ( config. position_embedding_type == PositionEmbeddingType :: Alibi ) )
162
176
// Allow disabling because of flash attention v1 precision problems
163
177
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
164
178
&& & std:: env:: var ( "USE_FLASH_ATTENTION" ) . unwrap_or ( "True" . to_string ( ) ) . to_lowercase ( ) == "true"
165
179
{
166
- if config. position_embedding_type == PositionEmbeddingType :: Alibi {
167
- tracing:: info!( "Starting FlashBert model on {:?}" , device) ;
168
- Ok ( Box :: new ( FlashBertModel :: load ( vb, & config, model_type) . s ( ) ?) )
169
- } else {
170
- tracing:: info!( "Starting Bert model on {:?}" , device) ;
171
- Ok ( Box :: new ( BertModel :: load ( vb, & config, model_type) . s ( ) ?) )
180
+ match config {
181
+ BertConfigWrapper :: JinaBert ( config) => {
182
+ tracing:: info!( "Starting FlashJinaBert model on {:?}" , device) ;
183
+ Ok ( Box :: new (
184
+ FlashJinaBertModel :: load ( vb, & config, model_type) . s ( ) ?,
185
+ ) )
186
+ }
187
+ BertConfigWrapper :: JinaCodeBert ( config) => {
188
+ tracing:: info!( "Starting FlashJinaCodeBert model on {:?}" , device) ;
189
+ Ok ( Box :: new (
190
+ FlashJinaCodeBertModel :: load ( vb, & config, model_type) . s ( ) ?,
191
+ ) )
192
+ }
193
+ BertConfigWrapper :: Bert ( config) => {
194
+ tracing:: info!( "Starting FlashBert model on {:?}" , device) ;
195
+ Ok ( Box :: new ( FlashBertModel :: load ( vb, & config, model_type) . s ( ) ?) )
196
+ }
172
197
}
173
- }
174
- #[ cfg( feature = "cuda" ) ]
175
- ( Config :: JinaBert ( config) , Device :: Cuda ( _) ) => {
176
- if cfg ! ( any( feature = "flash-attn" , feature = "flash-attn-v1" ) )
177
- && dtype == DType :: F16
178
- && ( ( config. position_embedding_type == PositionEmbeddingType :: Absolute ) | ( config. position_embedding_type == PositionEmbeddingType :: Alibi ) )
179
- // Allow disabling because of flash attention v1 precision problems
180
- // See: https://github.com/huggingface/text-embeddings-inference/issues/37
181
- && & std:: env:: var ( "USE_FLASH_ATTENTION" ) . unwrap_or ( "True" . to_string ( ) ) . to_lowercase ( ) == "true"
182
- {
183
- tracing:: info!( "Starting FlashJinaBertModel model on {:?}" , device) ;
184
- Ok ( Box :: new ( FlashJinaBertModel :: load ( vb, & config, model_type) . s ( ) ?, ) )
185
198
} else {
186
- tracing:: info!( "Starting JinaBertModel model on {:?}" , device) ;
187
- Ok ( Box :: new ( JinaBertModel :: load ( vb, & config, model_type) . s ( ) ?) )
188
- }
189
- #[ cfg( feature = "cuda" ) ]
190
- ( Config :: JinaCodeBert ( config) , Device :: Cuda ( _) ) => {
191
- if cfg ! ( any( feature = "flash-attn" , feature = "flash-attn-v1" ) )
192
- && dtype == DType :: F16
193
- && ( ( config. position_embedding_type == PositionEmbeddingType :: Absolute ) | ( config. position_embedding_type == PositionEmbeddingType :: Alibi ) )
194
- // Allow disabling because of flash attention v1 precision problems
195
- // See: https://github.com/huggingface/text-embeddings-inference/issues/37
196
- && & std:: env:: var ( "USE_FLASH_ATTENTION" ) . unwrap_or ( "True" . to_string ( ) ) . to_lowercase ( ) == "true"
197
- {
198
- tracing:: info!( "Starting FlashJinaCodeBertModel model on {:?}" , device) ;
199
- Ok ( Box :: new ( FlashJinaCodeBertModel :: load ( vb, & config, model_type) . s ( ) ?, ) )
200
- } else {
201
- tracing:: info!( "Starting JinaCodeBertModel model on {:?}" , device) ;
202
- Ok ( Box :: new ( JinaCodeBertModel :: load ( vb, & config, model_type) . s ( ) ?) )
199
+ match config {
200
+ BertConfigWrapper :: JinaBert ( config) => {
201
+ tracing:: info!( "Starting JinaBertModel model on {:?}" , device) ;
202
+ Ok ( Box :: new ( JinaBertModel :: load ( vb, & config, model_type) . s ( ) ?) )
203
+ }
204
+ BertConfigWrapper :: JinaCodeBert ( config) => {
205
+ tracing:: info!( "Starting JinaCodeBert model on {:?}" , device) ;
206
+ Ok ( Box :: new (
207
+ JinaCodeBertModel :: load ( vb, & config, model_type) . s ( ) ?,
208
+ ) )
209
+ }
210
+ BertConfigWrapper :: Bert ( config) => {
211
+ tracing:: info!( "Starting Bert model on {:?}" , device) ;
212
+ Ok ( Box :: new ( BertModel :: load ( vb, & config, model_type) . s ( ) ?) )
213
+ }
214
+ }
203
215
}
204
216
}
205
217
#[ cfg( feature = "cuda" ) ]
@@ -209,7 +221,6 @@ impl CandleBackend {
209
221
) => {
210
222
if cfg ! ( any( feature = "flash-attn" , feature = "flash-attn-v1" ) )
211
223
&& dtype == DType :: F16
212
- && ( ( config. position_embedding_type == PositionEmbeddingType :: Absolute ) | ( config. position_embedding_type == PositionEmbeddingType :: Alibi ) )
213
224
// Allow disabling because of flash attention v1 precision problems
214
225
// See: https://github.com/huggingface/text-embeddings-inference/issues/37
215
226
&& & std:: env:: var ( "USE_FLASH_ATTENTION" ) . unwrap_or ( "True" . to_string ( ) ) . to_lowercase ( ) == "true"
0 commit comments