
Commit 13dddbd

add reranker model support for python backend (#386)
Signed-off-by: kaixuanliu <[email protected]>
1 parent d3a8098 commit 13dddbd

10 files changed: +149 −26

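With this change, a sequence-classification (reranker) model served through the Python backend can answer scoring requests end to end instead of being rejected at startup. As a quick smoke test, a request like the one below should return per-text relevance scores; the port, the /rerank route, and the payload shape follow the text-embeddings-inference HTTP API and are assumptions, not part of this diff.

```python
# Hypothetical smoke test once a reranker model is deployed with the Python backend.
# The URL and payload follow the text-embeddings-inference /rerank route and are
# assumptions; they are not introduced by this commit.
import requests

response = requests.post(
    "http://localhost:8080/rerank",  # assumed local deployment
    json={
        "query": "What is Deep Learning?",
        "texts": [
            "Deep Learning is a subfield of machine learning.",
            "Paris is the capital of France.",
        ],
    },
)
print(response.json())  # expected: a list of {"index": ..., "score": ...} entries
```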

Dockerfile-intel

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ RUN python -m pip install torch==2.4.0 torchvision torchaudio==2.4.0 --index-url
 RUN cd backends/python/server && \
     make install
 
-FROM vault.habana.ai/gaudi-docker/1.17.1/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest AS hpu
+FROM vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest AS hpu
 ENV HUGGINGFACE_HUB_CACHE=/data \
     PORT=80
 

backends/grpc-client/src/client.rs

Lines changed: 21 additions & 0 deletions
@@ -64,4 +64,25 @@ impl Client {
         let response = self.stub.embed(request).await?.into_inner();
         Ok(response.embeddings)
     }
+
+    #[instrument(skip_all)]
+    pub async fn predict(
+        &mut self,
+        input_ids: Vec<u32>,
+        token_type_ids: Vec<u32>,
+        position_ids: Vec<u32>,
+        cu_seq_lengths: Vec<u32>,
+        max_length: u32,
+    ) -> Result<Vec<Score>> {
+        let request = tonic::Request::new(EmbedRequest {
+            input_ids,
+            token_type_ids,
+            position_ids,
+            max_length,
+            cu_seq_lengths,
+        })
+        .inject_context();
+        let response = self.stub.predict(request).await?.into_inner();
+        Ok(response.scores)
+    }
 }

backends/proto/embed.proto

Lines changed: 10 additions & 0 deletions
@@ -7,6 +7,8 @@ service EmbeddingService {
     rpc Embed (EmbedRequest) returns (EmbedResponse);
     /// Health check
     rpc Health (HealthRequest) returns (HealthResponse);
+    /// Predict
+    rpc Predict (EmbedRequest) returns (PredictResponse);
 }
 
 message HealthRequest {}
@@ -28,3 +30,11 @@ message Embedding {
 message EmbedResponse {
     repeated Embedding embeddings = 1;
 }
+
+message Score {
+    repeated float values = 1;
+}
+
+message PredictResponse {
+    repeated Score scores = 1;
+}
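
The new Predict RPC reuses EmbedRequest, so the Python server receives the same flattened, pre-tokenized batch as for embeddings and replies with one Score per sequence. A minimal sketch of calling it directly through the generated stubs is shown below; the channel target and token ids are placeholders (in the real system the Rust PythonBackend issues this call, see backends/python/src/lib.rs further down).

```python
# Illustrative client for the new Predict RPC, using the stubs generated from
# embed.proto. The socket address and token ids are placeholders; this is a
# sketch, not the path used in production.
import grpc

from text_embeddings_server.pb import embed_pb2, embed_pb2_grpc

channel = grpc.insecure_channel("unix:///tmp/text-embeddings-server.sock")  # hypothetical address
stub = embed_pb2_grpc.EmbeddingServiceStub(channel)

# One query/passage pair, already tokenized upstream (placeholder ids).
input_ids = [0, 2264, 83, 10, 22759, 2, 2, 62, 22759, 2]
request = embed_pb2.EmbedRequest(
    input_ids=input_ids,
    token_type_ids=[0] * len(input_ids),
    position_ids=list(range(len(input_ids))),
    cu_seq_lengths=[0, len(input_ids)],  # cumulative sequence lengths: a batch of one
    max_length=len(input_ids),
)

response = stub.Predict(request)
print(response.scores[0].values)  # raw classification logits for the pair
```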

backends/python/server/requirements-hpu.txt

Lines changed: 5 additions & 4 deletions
@@ -1,3 +1,4 @@
+accelerate==0.33.0 ; python_version >= "3.9" and python_version < "3.13"
 backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
 certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
 charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -31,8 +32,8 @@ opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
 opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
-optimum-habana==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
-optimum==1.21.4 ; python_version >= "3.9" and python_version < "3.13"
+optimum-habana==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+optimum==1.23.3 ; python_version >= "3.9" and python_version < "3.13"
 packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
 pandas==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
 pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
@@ -46,8 +47,8 @@ six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
 sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
 tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
 tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
-transformers==4.43.4 ; python_version >= "3.9" and python_version < "3.13"
-transformers[sentencepiece]==4.43.4 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
+transformers[sentencepiece]==4.45.2 ; python_version >= "3.9" and python_version < "3.13"
 typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
 typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
 tzdata==2024.1 ; python_version >= "3.9" and python_version < "3.13"

backends/python/server/text_embeddings_server/models/__init__.py

Lines changed: 18 additions & 9 deletions
@@ -8,6 +8,7 @@
 
 from text_embeddings_server.models.model import Model
 from text_embeddings_server.models.default_model import DefaultModel
+from text_embeddings_server.models.classification_model import ClassificationModel
 from text_embeddings_server.utils.device import get_device, use_ipex
 
 __all__ = ["Model"]
@@ -43,18 +44,19 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
     if config.model_type == "bert":
         config: BertConfig
         if (
-            device.type == "cuda"
+            use_ipex()
+            or device.type in ["cuda", "hpu"]
            and config.position_embedding_type == "absolute"
            and datatype in [torch.float16, torch.bfloat16]
            and FLASH_ATTENTION
         ):
             if pool != "cls":
-                raise ValueError("FlashBert only supports cls pooling")
-            return FlashBert(model_path, device, datatype)  # type: ignore
-        if use_ipex() or device.type == "hpu":
-            return FlashBert(model_path, device, datatype)  # type: ignore
-
-        return DefaultModel(model_path, device, datatype)
+                return DefaultModel(model_path, device, datatype, pool)
+            return FlashBert(model_path, device, datatype)
+        if config.architectures[0].endswith("Classification"):
+            return ClassificationModel(model_path, device, datatype)
+        else:
+            return DefaultModel(model_path, device, datatype, pool)
     else:
         if device.type == "hpu":
             from habana_frameworks.torch.hpu import wrap_in_hpu_graph
@@ -63,7 +65,14 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
             )
 
             adapt_transformers_to_gaudi()
-            model_handle = DefaultModel(model_path, device, datatype)
+            if config.architectures[0].endswith("Classification"):
+                model_handle = ClassificationModel(model_path, device, datatype)
+            else:
+                model_handle = DefaultModel(model_path, device, datatype, pool)
             model_handle.model = wrap_in_hpu_graph(model_handle.model)
             return model_handle
-        return DefaultModel(model_path, device, datatype)
+        elif use_ipex():
+            if config.architectures[0].endswith("Classification"):
+                return ClassificationModel(model_path, device, datatype)
+            else:
+                return DefaultModel(model_path, device, datatype, pool)
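
The dispatch rule above is purely name-based: if the first entry in config.architectures ends with "Classification", the checkpoint is treated as a reranker and handled by ClassificationModel; otherwise it falls back to DefaultModel with the requested pooling. A quick way to see which path a given checkpoint would take is sketched below; the model id is only an example and is not taken from this commit.

```python
# Check which branch of get_model a checkpoint would hit, based only on its config.
# The model id is an arbitrary example of a cross-encoder reranker.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("BAAI/bge-reranker-base")  # hypothetical choice
architecture = config.architectures[0]
handler = "ClassificationModel" if architecture.endswith("Classification") else "DefaultModel"
print(f"{architecture} -> {handler}")
```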

backends/python/server/text_embeddings_server/models/classification_model.py (new file)

Lines changed: 52 additions & 0 deletions

@@ -0,0 +1,52 @@
+import inspect
+import torch
+
+from pathlib import Path
+from typing import Type, List
+from transformers import AutoModelForSequenceClassification
+from opentelemetry import trace
+
+from text_embeddings_server.models import Model
+from text_embeddings_server.models.types import PaddedBatch, Embedding, Score
+
+tracer = trace.get_tracer(__name__)
+
+
+class ClassificationModel(Model):
+    def __init__(self, model_path: Path, device: torch.device, dtype: torch.dtype):
+        model = AutoModelForSequenceClassification.from_pretrained(model_path)
+        model = model.to(dtype).to(device)
+
+        self.hidden_size = model.config.hidden_size
+        self.has_position_ids = (
+            inspect.signature(model.forward).parameters.get("position_ids", None)
+            is not None
+        )
+        self.has_token_type_ids = (
+            inspect.signature(model.forward).parameters.get("token_type_ids", None)
+            is not None
+        )
+
+        super(ClassificationModel, self).__init__(
+            model=model, dtype=dtype, device=device
+        )
+
+    @property
+    def batch_type(self) -> Type[PaddedBatch]:
+        return PaddedBatch
+
+    @tracer.start_as_current_span("embed")
+    def embed(self, batch: PaddedBatch) -> List[Embedding]:
+        pass
+
+    @tracer.start_as_current_span("predict")
+    def predict(self, batch: PaddedBatch) -> List[Score]:
+        kwargs = {"input_ids": batch.input_ids, "attention_mask": batch.attention_mask}
+        if self.has_token_type_ids:
+            kwargs["token_type_ids"] = batch.token_type_ids
+        if self.has_position_ids:
+            kwargs["position_ids"] = batch.position_ids
+
+        output = self.model(**kwargs, return_dict=True)
+        all_scores = output.logits.tolist()
+        return [Score(values=scores) for scores in all_scores]
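
Outside the server, ClassificationModel.predict amounts to a forward pass of a sequence-classification model over query/passage pairs, with the raw logits returned as Score values. A rough standalone equivalent is sketched below; the checkpoint is an example reranker and is not referenced by this commit.

```python
# Roughly what ClassificationModel.predict computes, shown outside the gRPC server.
# The checkpoint is an example cross-encoder reranker, not one named in this commit.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_id = "BAAI/bge-reranker-base"  # hypothetical choice
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id).eval()

pairs = [["what is a panda?", "The giant panda is a bear species endemic to China."]]
batch = tokenizer(pairs, padding=True, truncation=True, return_tensors="pt")

with torch.no_grad():
    logits = model(**batch, return_dict=True).logits

# One row of logits per pair; the backend wraps each row in a Score message.
print(logits.tolist())
```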

backends/python/server/text_embeddings_server/models/default_model.py

Lines changed: 5 additions & 1 deletion
@@ -8,7 +8,7 @@
 from sentence_transformers.models import Pooling
 
 from text_embeddings_server.models import Model
-from text_embeddings_server.models.types import PaddedBatch, Embedding
+from text_embeddings_server.models.types import PaddedBatch, Embedding, Score
 
 tracer = trace.get_tracer(__name__)
 
@@ -59,3 +59,7 @@ def embed(self, batch: PaddedBatch) -> List[Embedding]:
             )
             for i in range(len(batch))
         ]
+
+    @tracer.start_as_current_span("predict")
+    def predict(self, batch: PaddedBatch) -> List[Score]:
+        pass

backends/python/server/text_embeddings_server/models/types.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 from opentelemetry import trace
 
 from text_embeddings_server.pb import embed_pb2
-from text_embeddings_server.pb.embed_pb2 import Embedding
+from text_embeddings_server.pb.embed_pb2 import Embedding, Score
 
 tracer = trace.get_tracer(__name__)
 

backends/python/server/text_embeddings_server/server.py

Lines changed: 7 additions & 0 deletions
@@ -31,6 +31,13 @@ async def Embed(self, request, context):
 
         return embed_pb2.EmbedResponse(embeddings=embeddings)
 
+    async def Predict(self, request, context):
+        batch = self.model.batch_type.from_pb(request, self.model.device)
+
+        scores = self.model.predict(batch)
+
+        return embed_pb2.PredictResponse(scores=scores)
+
 
 def serve(
     model_path: Path,

backends/python/src/lib.rs

Lines changed: 29 additions & 10 deletions
@@ -5,7 +5,7 @@ use backend_grpc_client::Client;
 use nohash_hasher::BuildNoHashHasher;
 use std::collections::HashMap;
 use text_embeddings_backend_core::{
-    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Predictions,
+    Backend, BackendError, Batch, Embedding, Embeddings, ModelType, Pool, Predictions,
 };
 use tokio::runtime::Runtime;
 
@@ -25,11 +25,7 @@ impl PythonBackend {
         otlp_service_name: String,
     ) -> Result<Self, BackendError> {
         let pool = match model_type {
-            ModelType::Classifier => {
-                return Err(BackendError::Start(
-                    "`classifier` model type is not supported".to_string(),
-                ))
-            }
+            ModelType::Classifier => Pool::Cls,
             ModelType::Embedding(pool) => pool,
         };
 
@@ -105,9 +101,32 @@ impl Backend for PythonBackend {
         Ok(embeddings)
     }
 
-    fn predict(&self, _batch: Batch) -> Result<Predictions, BackendError> {
-        Err(BackendError::Inference(
-            "`predict` is not implemented".to_string(),
-        ))
+    fn predict(&self, batch: Batch) -> Result<Predictions, BackendError> {
+        if !batch.raw_indices.is_empty() {
+            return Err(BackendError::Inference(
+                "raw embeddings are not supported for the Python backend.".to_string(),
+            ));
+        }
+        let batch_size = batch.len();
+        let results = self
+            .tokio_runtime
+            .block_on(self.backend_client.clone().predict(
+                batch.input_ids,
+                batch.token_type_ids,
+                batch.position_ids,
+                batch.cumulative_seq_lengths,
+                batch.max_length,
+            ))
+            .map_err(|err| BackendError::Inference(err.to_string()))?;
+        let raw_results: Vec<Vec<f32>> = results.into_iter().map(|r| r.values).collect();
+
+        let mut predictions =
+            HashMap::with_capacity_and_hasher(batch_size, BuildNoHashHasher::default());
+
+        for (i, r) in raw_results.into_iter().enumerate() {
+            predictions.insert(i, r);
+        }
+
+        Ok(predictions)
     }
 }
