
Commit c6c5e45

Merge branch 'rocm-support' of github.com:huggingface/text-embeddings-inference into rocm-support
2 parents 839a445 + 09b8b22 commit c6c5e45

15 files changed: +6,380 −3,410

Makefile

+4 −4

@@ -1,11 +1,11 @@
 integration-tests:
-	cargo test --release
+	cargo test
 
 cuda-integration-tests:
-	cargo test -F text-embeddings-backend-candle/cuda -F text-embeddings-backend-candle/flash-attn -F text-embeddings-router/candle-cuda --release
+	cargo test -F text-embeddings-backend-candle/cuda -F text-embeddings-backend-candle/flash-attn -F text-embeddings-router/candle-cuda --profile release-debug
 
 integration-tests-review:
-	cargo insta test --review --release
+	cargo insta test --review
 
 cuda-integration-tests-review:
-	cargo insta test --review --features "text-embeddings-backend-candle/cuda text-embeddings-backend-candle/flash-attn text-embeddings-router/candle-cuda" --release
+	cargo insta test --review --features "text-embeddings-backend-candle/cuda text-embeddings-backend-candle/flash-attn text-embeddings-router/candle-cuda" --profile release-debug
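The cuda targets above now build with a custom `release-debug` Cargo profile instead of plain `--release`. The profile's definition is not part of this diff; a minimal sketch of what it might look like in the workspace Cargo.toml, assuming the intent is release-level optimizations with debug info kept for usable test backtraces:

# Hypothetical `release-debug` profile; the repository's actual settings may differ.
[profile.release-debug]
inherits = "release"  # start from the release profile's optimizations
debug = true          # retain debug symbols for readable backtraces in tests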
File renamed without changes.

backends/candle/src/lib.rs

+67 −66

@@ -12,7 +12,7 @@ use crate::compute_cap::{
 };
 use crate::models::{
     BertConfig, BertModel, DistilBertConfig, DistilBertModel, JinaBertModel, JinaCodeBertModel,
-    JinaCodeConfig, JinaConfig, Model, NomicBertModel, NomicConfig,
+    Model, NomicBertModel, NomicConfig,
 };
 #[cfg(feature = "cuda")]
 use crate::models::{
@@ -30,17 +30,28 @@ use text_embeddings_backend_core::{
     Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
 };
 
+/// This enum is needed to be able to differentiate between jina models that also use
+/// the `bert` model type and valid Bert models.
+/// We use the `_name_or_path` field in the config to do so. This might not be robust in the long
+/// run but is still better than the other options...
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+#[serde(tag = "_name_or_path")]
+pub enum BertConfigWrapper {
+    #[serde(rename = "jinaai/jina-bert-implementation")]
+    JinaBert(BertConfig),
+    #[serde(rename = "jinaai/jina-bert-v2-qk-post-norm")]
+    JinaCodeBert(BertConfig),
+    #[serde(untagged)]
+    Bert(BertConfig),
+}
+
 #[derive(Deserialize)]
 #[serde(tag = "model_type", rename_all = "kebab-case")]
 enum Config {
-    Bert(BertConfig),
+    Bert(BertConfigWrapper),
     XlmRoberta(BertConfig),
     Camembert(BertConfig),
     Roberta(BertConfig),
-    #[serde(rename(deserialize = "jina_bert"))]
-    JinaBert(JinaConfig),
-    #[serde(rename(deserialize = "jina_code_bert"))]
-    JinaCodeBert(JinaCodeConfig),
     #[serde(rename(deserialize = "distilbert"))]
     DistilBert(DistilBertConfig),
     #[serde(rename(deserialize = "nomic_bert"))]
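For reference, a self-contained sketch of how the `BertConfigWrapper` dispatch above behaves, with `BertConfig` reduced to a single field. It mirrors the serde attributes in the diff; note that an `#[serde(untagged)]` fallback variant inside an internally tagged enum requires a recent serde version:

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct BertConfig {
    hidden_size: usize,
}

#[derive(Debug, Deserialize)]
#[serde(tag = "_name_or_path")]
enum BertConfigWrapper {
    #[serde(rename = "jinaai/jina-bert-implementation")]
    JinaBert(BertConfig),
    #[serde(rename = "jinaai/jina-bert-v2-qk-post-norm")]
    JinaCodeBert(BertConfig),
    #[serde(untagged)]
    Bert(BertConfig),
}

fn main() {
    // `_name_or_path` matches the Jina tag -> JinaBert variant.
    let jina: BertConfigWrapper = serde_json::from_str(
        r#"{"_name_or_path": "jinaai/jina-bert-implementation", "hidden_size": 768}"#,
    )
    .unwrap();
    // A missing or unrecognized `_name_or_path` falls through to the
    // untagged `Bert` variant.
    let plain: BertConfigWrapper = serde_json::from_str(r#"{"hidden_size": 768}"#).unwrap();
    println!("{jina:?}\n{plain:?}");
}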
@@ -76,7 +87,7 @@ impl CandleBackend {
                     "Runtime compute cap {} is not compatible with compile time compute cap {}",
                     get_runtime_compute_cap().unwrap(),
                     get_compile_compute_cap().unwrap()
-                )))
+                )));
             }
             Err(err) => {
                 tracing::warn!("Could not find a compatible CUDA device on host: {err:?}");
@@ -123,20 +134,22 @@ impl CandleBackend {
             (_, Device::Cuda(_)) => Err(BackendError::Start(
                 "`cuda` feature is not enabled".to_string(),
             )),
-            (Config::Bert(config), Device::Cpu | Device::Metal(_)) => {
-                tracing::info!("Starting Bert model on {:?}", device);
-                Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
-            }
-            (Config::JinaBert(config), Device::Cpu | Device::Metal(_)) => {
-                tracing::info!("Starting JinaBertModel model on {:?}", device);
-                Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
-            }
-            (Config::JinaCodeBert(config), Device::Cpu | Device::Metal(_)) => {
-                tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
-                Ok(Box::new(
-                    JinaCodeBertModel::load(vb, &config, model_type).s()?,
-                ))
-            }
+            (Config::Bert(config), Device::Cpu | Device::Metal(_)) => match config {
+                BertConfigWrapper::JinaBert(config) => {
+                    tracing::info!("Starting JinaBertModel model on {:?}", device);
+                    Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
+                }
+                BertConfigWrapper::JinaCodeBert(config) => {
+                    tracing::info!("Starting JinaCodeBert model on {:?}", device);
+                    Ok(Box::new(
+                        JinaCodeBertModel::load(vb, &config, model_type).s()?,
+                    ))
+                }
+                BertConfigWrapper::Bert(config) => {
+                    tracing::info!("Starting Bert model on {:?}", device);
+                    Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
+                }
+            },
             (
                 Config::XlmRoberta(config) | Config::Camembert(config) | Config::Roberta(config),
                 Device::Cpu | Device::Metal(_),
@@ -160,56 +173,45 @@ impl CandleBackend {
             (Config::Bert(config), Device::Cuda(_)) => {
                 if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
                     && dtype == DType::F16
-                    && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
                     // Allow disabling because of flash attention v1 precision problems
                     // See: https://github.com/huggingface/text-embeddings-inference/issues/37
                     && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
                 {
-                    if config.position_embedding_type == PositionEmbeddingType::Alibi {
-                        tracing::info!("Starting FlashBert model on {:?}", device);
-                        Ok(Box::new(FlashBertModel::load(vb, &config, model_type).s()?))
-                    } else {
-                        tracing::info!("Starting Bert model on {:?}", device);
-                        Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
+                    match config {
+                        BertConfigWrapper::JinaBert(config) => {
+                            tracing::info!("Starting FlashJinaBert model on {:?}", device);
+                            Ok(Box::new(
+                                FlashJinaBertModel::load(vb, &config, model_type).s()?,
+                            ))
+                        }
+                        BertConfigWrapper::JinaCodeBert(config) => {
+                            tracing::info!("Starting FlashJinaCodeBert model on {:?}", device);
+                            Ok(Box::new(
+                                FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,
+                            ))
+                        }
+                        BertConfigWrapper::Bert(config) => {
+                            tracing::info!("Starting FlashBert model on {:?}", device);
+                            Ok(Box::new(FlashBertModel::load(vb, &config, model_type).s()?))
+                        }
                     }
-                }
-            }
-            #[cfg(feature = "cuda")]
-            (Config::JinaBert(config), Device::Cuda(_)) => {
-                if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    && dtype == DType::F16
-                    && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
-                    // Allow disabling because of flash attention v1 precision problems
-                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
-                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
-                {
-                    tracing::info!("Starting FlashJinaBertModel model on {:?}", device);
-                    Ok(Box::new(
-                        FlashJinaBertModel::load(vb, &config, model_type).s()?,
-                    ))
-                } else {
-                    tracing::info!("Starting JinaBertModel model on {:?}", device);
-                    Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
-                }
-            }
-            #[cfg(feature = "cuda")]
-            (Config::JinaCodeBert(config), Device::Cuda(_)) => {
-                if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
-                    && dtype == DType::F16
-                    && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
-                    // Allow disabling because of flash attention v1 precision problems
-                    // See: https://github.com/huggingface/text-embeddings-inference/issues/37
-                    && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
-                {
-                    tracing::info!("Starting FlashJinaCodeBertModel model on {:?}", device);
-                    Ok(Box::new(
-                        FlashJinaCodeBertModel::load(vb, &config, model_type).s()?,
-                    ))
                 } else {
-                    tracing::info!("Starting JinaCodeBertModel model on {:?}", device);
-                    Ok(Box::new(
-                        JinaCodeBertModel::load(vb, &config, model_type).s()?,
-                    ))
+                    match config {
+                        BertConfigWrapper::JinaBert(config) => {
+                            tracing::info!("Starting JinaBertModel model on {:?}", device);
+                            Ok(Box::new(JinaBertModel::load(vb, &config, model_type).s()?))
+                        }
+                        BertConfigWrapper::JinaCodeBert(config) => {
+                            tracing::info!("Starting JinaCodeBert model on {:?}", device);
+                            Ok(Box::new(
+                                JinaCodeBertModel::load(vb, &config, model_type).s()?,
+                            ))
+                        }
+                        BertConfigWrapper::Bert(config) => {
+                            tracing::info!("Starting Bert model on {:?}", device);
+                            Ok(Box::new(BertModel::load(vb, &config, model_type).s()?))
+                        }
+                    }
                 }
             }
             #[cfg(feature = "cuda")]
@@ -219,7 +221,6 @@ impl CandleBackend {
             ) => {
                 if cfg!(any(feature = "flash-attn", feature = "flash-attn-v1"))
                     && dtype == DType::F16
-                    && ((config.position_embedding_type == PositionEmbeddingType::Absolute) | (config.position_embedding_type == PositionEmbeddingType::Alibi))
                     // Allow disabling because of flash attention v1 precision problems
                     // See: https://github.com/huggingface/text-embeddings-inference/issues/37
                     && &std::env::var("USE_FLASH_ATTENTION").unwrap_or("True".to_string()).to_lowercase() == "true"
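Every CUDA branch above repeats the same `USE_FLASH_ATTENTION` gate. As a standalone sketch, the predicate extracted from that condition looks like this (the codebase inlines the expression; this helper is illustrative):

use std::env;

/// Flash attention stays enabled unless USE_FLASH_ATTENTION is set to
/// anything other than "true" (case-insensitive); unset means enabled.
fn use_flash_attention() -> bool {
    env::var("USE_FLASH_ATTENTION")
        .unwrap_or_else(|_| "True".to_string())
        .to_lowercase()
        == "true"
}

fn main() {
    // e.g. `USE_FLASH_ATTENTION=false text-embeddings-router ...`
    // disables the flash kernels and selects the non-flash models.
    println!("flash attention enabled: {}", use_flash_attention());
}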

backends/candle/src/models/flash_jina.rs

+17 −18

@@ -2,14 +2,13 @@ use crate::alibi::alibi_head_slopes
 use crate::flash_attn::flash_attn_varlen;
 use crate::layers::{HiddenAct, LayerNorm, Linear};
 use crate::models::bert::PositionEmbeddingType;
-use crate::models::jina::BertEmbeddings;
-use crate::models::jina::{BertEmbeddings, JinaConfig};
-use crate::models::Model;
+use crate::models::jina::JinaEmbeddings;
+use crate::models::{BertConfig, Model};
 use candle::{DType, Device, IndexOp, Result, Tensor};
 use candle_nn::VarBuilder;
 use text_embeddings_backend_core::{Batch, ModelType, Pool};
 
-struct AlibiBertAttention {
+struct JinaAttention {
     qkv_linear: Linear,
     dense: Linear,
     layer_norm: LayerNorm,
@@ -23,7 +22,7 @@ struct AlibiBertAttention {
     span: tracing::Span,
 }
 
-impl AlibiBertAttention {
+impl JinaAttention {
     pub fn load(vb: VarBuilder, config: &BertConfig, alibi_slopes: Option<Tensor>) -> Result<Self> {
         let attention_head_size = config.hidden_size / config.num_attention_heads;
         let all_head_size = config.num_attention_heads * attention_head_size;
@@ -117,7 +116,7 @@ impl AlibiBertAttention {
 }
 
 struct JinaBertLayer {
-    attention: AlibiBertAttention,
+    attention: JinaAttention,
     gated: Linear,
     output: Linear,
     layer_norm: LayerNorm,
@@ -130,7 +129,7 @@ struct JinaBertLayer {
 
 impl JinaBertLayer {
     pub fn load(vb: VarBuilder, config: &BertConfig, alibi: Option<Tensor>) -> Result<Self> {
-        let attention = AlibiBertAttention::load(vb.pp("attention"), config, alibi)?;
+        let attention = JinaAttention::load(vb.pp("attention"), config, alibi)?;
 
         let gated_weight = vb
             .pp("mlp")
@@ -174,14 +173,14 @@ impl JinaBertLayer {
         let residual = hidden_states.clone();
 
         let hidden_states = self.gated.forward(&hidden_states)?;
-        let gated = hidden_states.i((.., 0..self.intermediate_size))?;
+        let gated = hidden_states.narrow(1, 0, self.intermediate_size)?;
         let gated = match self.act {
             HiddenAct::Gelu => gated.gelu(),
             HiddenAct::Relu => gated.relu(),
             HiddenAct::Swiglu => gated.silu(),
         }?;
 
-        let non_gated = hidden_states.i((.., self.intermediate_size..))?;
+        let non_gated = hidden_states.narrow(1, self.intermediate_size, self.intermediate_size)?;
         let hidden_states = (gated * non_gated)?;
 
         let hidden_states = self.output.forward(&hidden_states)?;
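The switch from range indexing (`i((.., a..b))`) to `narrow(1, start, len)` above makes the split of the gated projection explicit about the dimension and length. A self-contained sketch of the same gated-MLP split with candle's `Tensor::narrow`; the helper and shapes are illustrative, not taken from the repository:

use candle::{Device, Result, Tensor};

/// Split the output of the gated projection [tokens, 2 * intermediate]
/// into its gated and non-gated halves along dimension 1.
fn split_gated(hidden: &Tensor, intermediate_size: usize) -> Result<(Tensor, Tensor)> {
    let gated = hidden.narrow(1, 0, intermediate_size)?;
    let non_gated = hidden.narrow(1, intermediate_size, intermediate_size)?;
    Ok((gated, non_gated))
}

fn main() -> Result<()> {
    let hidden = Tensor::randn(0f32, 1f32, (4, 8), &Device::Cpu)?;
    let (gated, non_gated) = split_gated(&hidden, 4)?;
    // SwiGLU-style combination, as in the layer above.
    let out = (gated.silu()? * non_gated)?;
    println!("{:?}", out.shape()); // [4, 4]
    Ok(())
}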
@@ -191,12 +190,12 @@ impl JinaBertLayer {
     }
 }
 
-struct BertEncoder {
+struct JinaBertEncoder {
     layers: Vec<JinaBertLayer>,
     span: tracing::Span,
 }
 
-impl BertEncoder {
+impl JinaBertEncoder {
     pub fn load(vb: VarBuilder, config: &BertConfig, alibi: Option<Tensor>) -> Result<Self> {
         let layers = (0..config.num_hidden_layers)
             .map(|index| {
@@ -205,7 +204,7 @@ impl BertEncoder {
             .collect::<Result<Vec<_>>>()?;
         let span = tracing::span!(tracing::Level::TRACE, "encoder");
 
-        Ok(BertEncoder { layers, span })
+        Ok(JinaBertEncoder { layers, span })
     }
 
     fn forward(&self, hidden_states: &Tensor, cu_seqlens: &Tensor, max_s: usize) -> Result<Tensor> {
@@ -223,8 +222,8 @@ impl BertEncoder {
 }
 
 pub struct FlashJinaBertModel {
-    embeddings: BertEmbeddings,
-    encoder: BertEncoder,
+    embeddings: JinaEmbeddings,
+    encoder: JinaBertEncoder,
     pool: Pool,
     pub device: Device,
 
@@ -266,14 +265,14 @@ impl FlashJinaBertModel {
         };
 
         let (embeddings, encoder) = match (
-            BertEmbeddings::load(vb.pp("embeddings"), config),
-            BertEncoder::load(vb.pp("encoder"), config, alibi.clone()),
+            JinaEmbeddings::load(vb.pp("embeddings"), config),
+            JinaBertEncoder::load(vb.pp("encoder"), config, alibi.clone()),
         ) {
             (Ok(embeddings), Ok(encoder)) => (embeddings, encoder),
             (Err(err), _) | (_, Err(err)) => {
                 if let (Ok(embeddings), Ok(encoder)) = (
-                    BertEmbeddings::load(vb.pp("bert.embeddings"), config),
-                    BertEncoder::load(vb.pp("bert.encoder"), config, alibi.clone()),
+                    JinaEmbeddings::load(vb.pp("bert.embeddings"), config),
+                    JinaBertEncoder::load(vb.pp("bert.encoder"), config, alibi.clone()),
                 ) {
                     (embeddings, encoder)
                 } else {
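The load above first tries the unprefixed `embeddings`/`encoder` namespace and retries under a `bert.` prefix when that fails, since some checkpoints nest all tensors under `bert.`. A generic sketch of the pattern (this helper is hypothetical; the real code inlines the two attempts and reports the original error):

use candle::Result;
use candle_nn::VarBuilder;

/// Try loading from the root namespace, then retry under `bert.`.
fn load_with_prefix_fallback<T>(
    vb: &VarBuilder,
    load: impl Fn(&VarBuilder) -> Result<T>,
) -> Result<T> {
    // `pp` ("push prefix") scopes subsequent tensor name lookups.
    load(vb).or_else(|_| load(&vb.pp("bert")))
}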

Comments (0)