Skip to content

Commit 1718a7a

Browse files
feat(router): add base64 encoding_format for OpenAI API (#301)
1 parent b8f6c78 commit 1718a7a

File tree

5 files changed

+48
-10
lines changed

5 files changed

+48
-10
lines changed

Cargo.lock

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

-1
Original file line numberDiff line numberDiff line change
@@ -38,4 +38,3 @@ debug = 1
3838
lto = "thin"
3939
codegen-units = 16
4040
strip = "none"
41-
incremental = true

router/Cargo.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ mimalloc = { version = "*", default-features = false }
4343
# HTTP dependencies
4444
axum = { version = "0.7.4", features = ["json"], optional = true }
4545
axum-tracing-opentelemetry = { version = "0.17.0", optional = true }
46+
base64 = { version = "0.21.4", optional = true }
4647
tower-http = { version = "0.5.1", features = ["cors"], optional = true }
4748
utoipa = { version = "4.2", features = ["axum_extras"], optional = true }
4849
utoipa-swagger-ui = { version = "6.0", features = ["axum"], optional = true }
@@ -66,7 +67,7 @@ tonic-build = { version = "0.10.2", optional = true }
6667

6768
[features]
6869
default = ["candle", "http"]
69-
http = ["dep:axum", "dep:axum-tracing-opentelemetry", "dep:tower-http", "dep:utoipa", "dep:utoipa-swagger-ui"]
70+
http = ["dep:axum", "dep:axum-tracing-opentelemetry", "dep:base64", "dep:tower-http", "dep:utoipa", "dep:utoipa-swagger-ui"]
7071
grpc = ["metrics-exporter-prometheus/http-listener", "dep:prost", "dep:tonic", "dep:tonic-health", "dep:tonic-reflection", "dep:tonic-build", "dep:async-stream", "dep:tokio-stream"]
7172
metal = ["text-embeddings-backend/metal"]
7273
mkl = ["text-embeddings-backend/mkl"]

router/src/http/server.rs

+26-7
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
/// HTTP Server logic
22
use crate::http::types::{
33
DecodeRequest, DecodeResponse, EmbedAllRequest, EmbedAllResponse, EmbedRequest, EmbedResponse,
4-
EmbedSparseRequest, EmbedSparseResponse, Input, InputIds, InputType, OpenAICompatEmbedding,
5-
OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage,
6-
PredictInput, PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse,
7-
Sequence, SimpleToken, SparseValue, TokenizeInput, TokenizeRequest, TokenizeResponse,
8-
VertexPrediction, VertexRequest, VertexResponse,
4+
EmbedSparseRequest, EmbedSparseResponse, Embedding, EncodingFormat, Input, InputIds, InputType,
5+
OpenAICompatEmbedding, OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse,
6+
OpenAICompatUsage, PredictInput, PredictRequest, PredictResponse, Prediction, Rank,
7+
RerankRequest, RerankResponse, Sequence, SimpleToken, SparseValue, TokenizeInput,
8+
TokenizeRequest, TokenizeResponse, VertexPrediction, VertexRequest, VertexResponse,
99
};
1010
use crate::{
1111
shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType,
@@ -19,6 +19,8 @@ use axum::http::{Method, StatusCode};
1919
use axum::routing::{get, post};
2020
use axum::{http, Json, Router};
2121
use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
22+
use base64::prelude::BASE64_STANDARD;
23+
use base64::Engine;
2224
use futures::future::join_all;
2325
use futures::FutureExt;
2426
use http::header::AUTHORIZATION;
@@ -938,6 +940,21 @@ async fn openai_embed(
938940
Json(req): Json<OpenAICompatRequest>,
939941
) -> Result<(HeaderMap, Json<OpenAICompatResponse>), (StatusCode, Json<OpenAICompatErrorResponse>)>
940942
{
943+
let encode_embedding = |array: Vec<f32>| {
944+
match req.encoding_format {
945+
EncodingFormat::Float => Embedding::Float(array),
946+
EncodingFormat::Base64 => {
947+
// Unsafe is fine here since we do not violate memory ownership: bytes
948+
// is only used in this scope and we return an owned string
949+
let bytes = unsafe {
950+
std::slice::from_raw_parts(array.as_ptr() as *const u8, array.len() * 4)
951+
};
952+
953+
Embedding::Base64(BASE64_STANDARD.encode(bytes))
954+
}
955+
}
956+
};
957+
941958
let span = tracing::Span::current();
942959
let start_time = Instant::now();
943960

@@ -957,10 +974,11 @@ async fn openai_embed(
957974

958975
metrics::increment_counter!("te_request_success", "method" => "single");
959976

977+
let embedding = encode_embedding(response.results);
960978
(
961979
vec![OpenAICompatEmbedding {
962980
object: "embedding",
963-
embedding: response.results,
981+
embedding,
964982
index: 0,
965983
}],
966984
ResponseMetadata::new(
@@ -1033,9 +1051,10 @@ async fn openai_embed(
10331051
total_queue_time += r.metadata.queue.as_nanos() as u64;
10341052
total_inference_time += r.metadata.inference.as_nanos() as u64;
10351053
total_compute_tokens += r.metadata.prompt_tokens;
1054+
let embedding = encode_embedding(r.results);
10361055
embeddings.push(OpenAICompatEmbedding {
10371056
object: "embedding",
1038-
embedding: r.results,
1057+
embedding,
10391058
index: i,
10401059
});
10411060
}

router/src/http/types.rs

+19-1
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,14 @@ pub(crate) enum Input {
285285
Batch(Vec<InputType>),
286286
}
287287

288+
#[derive(Deserialize, ToSchema, Default)]
289+
#[serde(rename_all = "snake_case")]
290+
pub(crate) enum EncodingFormat {
291+
#[default]
292+
Float,
293+
Base64,
294+
}
295+
288296
#[derive(Deserialize, ToSchema)]
289297
pub(crate) struct OpenAICompatRequest {
290298
pub input: Input,
@@ -294,14 +302,24 @@ pub(crate) struct OpenAICompatRequest {
294302
#[allow(dead_code)]
295303
#[schema(nullable = true, example = "null")]
296304
pub user: Option<String>,
305+
#[schema(default = "float", example = "float")]
306+
#[serde(default)]
307+
pub encoding_format: EncodingFormat,
308+
}
309+
310+
#[derive(Serialize, ToSchema)]
311+
#[serde(untagged)]
312+
pub(crate) enum Embedding {
313+
Float(Vec<f32>),
314+
Base64(String),
297315
}
298316

299317
#[derive(Serialize, ToSchema)]
300318
pub(crate) struct OpenAICompatEmbedding {
301319
#[schema(example = "embedding")]
302320
pub object: &'static str,
303321
#[schema(example = json!([0.0, 1.0, 2.0]))]
304-
pub embedding: Vec<f32>,
322+
pub embedding: Embedding,
305323
#[schema(example = "0")]
306324
pub index: usize,
307325
}

0 commit comments

Comments
 (0)