@@ -12,7 +12,7 @@ use crate::compute_cap::{
};
use crate::models::{
    BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, JinaCodeBertModel,
-    JinaCodeConfig, JinaConfig, Model, NomicBertModel, NomicConfig,
+    Model, NomicBertModel, NomicConfig,
};
#[cfg(feature = "cuda")]
use crate::models::{
@@ -30,17 +30,28 @@ use text_embeddings_backend_core::{
    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
};

+/// This enum is needed to be able to differentiate between jina models that also use
+/// the `bert` model type and valid Bert models.
+/// We use the `_name_or_path` field in the config to do so. This might not be robust in the long
+/// run but is still better than the other options...
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+#[serde(tag = "_name_or_path")]
+pub enum BertConfigWrapper {
+    #[serde(rename = "jinaai/jina-bert-implementation")]
+    JinaBert(BertConfig),
+    #[serde(rename = "jinaai/jina-bert-v2-qk-post-norm")]
+    JinaCodeBert(BertConfig),
+    #[serde(untagged)]
+    Bert(BertConfig),
+}
+
#[derive(Deserialize)]
#[serde(tag = "model_type", rename_all = "kebab-case")]
enum Config {
-    Bert(BertConfig),
+    Bert(BertConfigWrapper),
    XlmRoberta(BertConfig),
    Camembert(BertConfig),
    Roberta(BertConfig),
-    #[serde(rename(deserialize = "jina_bert"))]
-    JinaBert(JinaConfig),
-    #[serde(rename(deserialize = "jina_code_bert"))]
-    JinaCodeBert(JinaCodeConfig),
    #[serde(rename(deserialize = "distilbert"))]
    DistilBert(DistilBertConfig),
    #[serde(rename(deserialize = "nomic_bert"))]
@@ -76,7 +87,7 @@ impl CandleBackend {
                        "Runtime compute cap {} is not compatible with compile time compute cap {}",
                        get_runtime_compute_cap().unwrap(),
                        get_compile_compute_cap().unwrap()
-                    )))
+                    )));
                }
                Err(err) => {
                    tracing::warn!("Could not find a compatible CUDA device on host: {err:?}");
@@ -123,20 +134,22 @@ impl CandleBackend {
            (_, Device::Cuda(_)) => Err(BackendError::Start(
                "`cuda` feature is not enabled".to_string(),
            )),
-            (Config::Bert(config), Device::Cpu | Device::Metal(_)) => {
-                tracing::info!("Starting Bert model on {:?}", device);
-                Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
-            }
-            (Config::JinaBert(config), Device::Cpu | Device::Metal(_)) => {
-                tracing::info!("Starting JinaBertModel model on {:?}", device);
-                Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
-            }
-            (Config::JinaCodeBert(config), Device::Cpu | Device::Metal(_)) => {
-                tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
-                Ok(Box::new(
-                    JinaCodeBertModel::load(vb, &config, model_type).s()?,
-                ))
-            }
+            (Config::Bert(config), Device::Cpu | Device::Metal(_)) => match config {
+                BertConfigWrapper::JinaBert(config) => {
+                    tracing::info!("Starting JinaBertModel model on {:?}", device);
+                    Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
+                }
+                BertConfigWrapper::JinaCodeBert(config) => {
+                    tracing::info!("Starting JinaCodeBert model on {:?}", device);
+                    Ok(Box::new(
+                        JinaCodeBertModel::load(vb, &config, model_type).s()?,
+                    ))
+                }
+                BertConfigWrapper::Bert(config) => {
+                    tracing::info!("Starting Bert model on {:?}", device);
+                    Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
+                }
+            },
            (
                Config::XlmRoberta(config) | Config::Camembert(config) | Config::Roberta(config),
                Device::Cpu | Device::Metal(_),
@@ -160,56 +173,45 @@ impl CandleBackend {
            (Config::Bert(config), Device::Cuda(_)) => {
                if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
                    && dtype == DType::F16
-                    && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
                    // Allow disabling because of flash attention v1 precision problems
                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
                {
-                    if config.position_embedding_type == PositionEmbeddingType::Alibi {
-                        tracing::info!("Starting FlashBert model on {:?}", device);
-                        Ok(Box::new(FlashBertModel::load(vb, &config, model_type).s()?))
-                    } else {
-                        tracing::info!("Starting Bert model on {:?}", device);
-                        Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
+                    match config {
+                        BertConfigWrapper::JinaBert(config) => {
+                            tracing::info!("Starting FlashJinaBert model on {:?}", device);
+                            Ok(Box::new(
+                                FlashJinaBertModel::load(vb, &config, model_type).s()?,
+                            ))
+                        }
+                        BertConfigWrapper::JinaCodeBert(config) => {
+                            tracing::info!("Starting FlashJinaCodeBert model on {:?}", device);
+                            Ok(Box::new(
+                                FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,
+                            ))
+                        }
+                        BertConfigWrapper::Bert(config) => {
+                            tracing::info!("Starting FlashBert model on {:?}", device);
+                            Ok(Box::new(FlashBertModel::load(vb, &config, model_type).s()?))
+                        }
                    }
-                }
-            }
-            #[cfg(feature = "cuda")]
-            (Config::JinaBert(config), Device::Cuda(_)) => {
-                if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    && dtype == DType::F16
-                    && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
-                    // Allow disabling because of flash attention v1 precision problems
-                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
-                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
-                {
-                    tracing::info!("Starting FlashJinaBertModel model on {:?}", device);
-                    Ok(Box::new(
-                        FlashJinaBertModel::load(vb, &config, model_type).s()?,
-                    ))
-                } else {
-                    tracing::info!("Starting JinaBertModel model on {:?}", device);
-                    Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
-                }
-            }
-            #[cfg(feature = "cuda")]
-            (Config::JinaCodeBert(config), Device::Cuda(_)) => {
-                if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    && dtype == DType::F16
-                    && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
-                    // Allow disabling because of flash attention v1 precision problems
-                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
-                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
-                {
-                    tracing::info!("Starting FlashJinaCodeBertModel model on {:?}", device);
-                    Ok(Box::new(
-                        FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,
-                    ))
                } else {
-                    tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
-                    Ok(Box::new(
-                        JinaCodeBertModel::load(vb, &config, model_type).s()?,
-                    ))
+                    match config {
+                        BertConfigWrapper::JinaBert(config) => {
+                            tracing::info!("Starting JinaBertModel model on {:?}", device);
+                            Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
+                        }
+                        BertConfigWrapper::JinaCodeBert(config) => {
+                            tracing::info!("Starting JinaCodeBert model on {:?}", device);
+                            Ok(Box::new(
+                                JinaCodeBertModel::load(vb, &config, model_type).s()?,
+                            ))
+                        }
+                        BertConfigWrapper::Bert(config) => {
+                            tracing::info!("Starting Bert model on {:?}", device);
+                            Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
+                        }
+                    }
                }
            }
            #[cfg(feature = "cuda")]
@@ -219,7 +221,6 @@ impl CandleBackend {
            ) => {
                if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
                    && dtype == DType::F16
-                    && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
                    // Allow disabling because of flash attention v1 precision problems
                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
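
For reference, here is a small, self-contained sketch of the serde pattern the new `BertConfigWrapper` relies on: an internally tagged enum keyed on the `_name_or_path` field, with an `#[serde(untagged)]` fallback for plain BERT configs. `MiniConfig`, `MiniWrapper`, and the JSON snippets are invented for illustration (the real `BertConfig` has many more fields), and the example assumes a reasonably recent `serde` with the `derive` feature plus `serde_json` as dependencies.

use serde::Deserialize;

// Stand-in for the real `BertConfig`; only the dispatch behaviour matters here.
#[derive(Debug, Deserialize)]
struct MiniConfig {
    hidden_size: usize,
}

// Mirrors `BertConfigWrapper`: serde reads `_name_or_path`, picks a Jina variant on an
// exact match, and otherwise falls through to the untagged `Bert` variant.
#[derive(Debug, Deserialize)]
#[serde(tag = "_name_or_path")]
enum MiniWrapper {
    #[serde(rename = "jinaai/jina-bert-implementation")]
    JinaBert(MiniConfig),
    #[serde(rename = "jinaai/jina-bert-v2-qk-post-norm")]
    JinaCodeBert(MiniConfig),
    #[serde(untagged)]
    Bert(MiniConfig),
}

fn main() {
    // A config advertising the Jina implementation is routed to `JinaBert`.
    let jina: MiniWrapper = serde_json::from_str(
        r#"{"_name_or_path": "jinaai/jina-bert-implementation", "hidden_size": 768}"#,
    )
    .unwrap();
    assert!(matches!(jina, MiniWrapper::JinaBert(_)));

    // A config without a recognised `_name_or_path` hits the untagged fallback.
    let plain: MiniWrapper = serde_json::from_str(r#"{"hidden_size": 384}"#).unwrap();
    assert!(matches!(plain, MiniWrapper::Bert(_)));
}

Note that the untagged fallback is declared last: serde tries the named tag values first and only then the catch-all, which is what lets ordinary BERT configs keep deserializing unchanged.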