@@ -1,11 +1,11 @@
 /// HTTP Server logic
 use crate::http::types::{
     DecodeRequest, DecodeResponse, EmbedAllRequest, EmbedAllResponse, EmbedRequest, EmbedResponse,
-    EmbedSparseRequest, EmbedSparseResponse, Input, InputIds, InputType, OpenAICompatEmbedding,
-    OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse, OpenAICompatUsage,
-    PredictInput, PredictRequest, PredictResponse, Prediction, Rank, RerankRequest, RerankResponse,
-    Sequence, SimpleToken, SparseValue, TokenizeInput, TokenizeRequest, TokenizeResponse,
-    VertexPrediction, VertexRequest, VertexResponse,
+    EmbedSparseRequest, EmbedSparseResponse, Embedding, EncodingFormat, Input, InputIds, InputType,
+    OpenAICompatEmbedding, OpenAICompatErrorResponse, OpenAICompatRequest, OpenAICompatResponse,
+    OpenAICompatUsage, PredictInput, PredictRequest, PredictResponse, Prediction, Rank,
+    RerankRequest, RerankResponse, Sequence, SimpleToken, SparseValue, TokenizeInput,
+    TokenizeRequest, TokenizeResponse, VertexPrediction, VertexRequest, VertexResponse,
 };
 use crate::{
     shutdown, ClassifierModel, EmbeddingModel, ErrorResponse, ErrorType, Info, ModelType,
@@ -19,6 +19,8 @@ use axum::http::{Method, StatusCode};
 use axum::routing::{get, post};
 use axum::{http, Json, Router};
 use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
+use base64::prelude::BASE64_STANDARD;
+use base64::Engine;
 use futures::future::join_all;
 use futures::FutureExt;
 use http::header::AUTHORIZATION;
@@ -938,6 +940,21 @@ async fn openai_embed(
     Json(req): Json<OpenAICompatRequest>,
 ) -> Result<(HeaderMap, Json<OpenAICompatResponse>), (StatusCode, Json<OpenAICompatErrorResponse>)>
 {
+    let encode_embedding = |array: Vec<f32>| {
+        match req.encoding_format {
+            EncodingFormat::Float => Embedding::Float(array),
+            EncodingFormat::Base64 => {
+                // Unsafe is fine here since we do not violate memory ownership: `bytes`
+                // is only used in this scope and we return an owned string.
+                let bytes = unsafe {
+                    std::slice::from_raw_parts(array.as_ptr() as *const u8, array.len() * 4)
+                };
+
+                Embedding::Base64(BASE64_STANDARD.encode(bytes))
+            }
+        }
+    };
+
     let span = tracing::Span::current();
     let start_time = Instant::now();
 
@@ -957,10 +974,11 @@ async fn openai_embed(
 
             metrics::increment_counter!("te_request_success", "method" => "single");
 
+            let embedding = encode_embedding(response.results);
             (
                 vec![OpenAICompatEmbedding {
                     object: "embedding",
-                    embedding: response.results,
+                    embedding,
                     index: 0,
                 }],
                 ResponseMetadata::new(
@@ -1033,9 +1051,10 @@ async fn openai_embed(
                 total_queue_time += r.metadata.queue.as_nanos() as u64;
                 total_inference_time += r.metadata.inference.as_nanos() as u64;
                 total_compute_tokens += r.metadata.prompt_tokens;
+                let embedding = encode_embedding(r.results);
                 embeddings.push(OpenAICompatEmbedding {
                     object: "embedding",
-                    embedding: r.results,
+                    embedding,
                     index: i,
                 });
             }
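
For reference, the base64 branch above serializes the raw in-memory bytes of the `f32` buffer (native byte order, i.e. little-endian on the usual x86/ARM deployment targets) and then base64-encodes them. A minimal sketch of the client-side reverse step, assuming that little-endian layout and the same `base64` crate; `decode_embedding` is a hypothetical helper, not part of this change:

use base64::prelude::BASE64_STANDARD;
use base64::Engine;

// Hypothetical client-side helper: reverses the server's f32 -> bytes -> base64 step.
fn decode_embedding(encoded: &str) -> Result<Vec<f32>, base64::DecodeError> {
    let bytes = BASE64_STANDARD.decode(encoded)?;
    // Each f32 occupies 4 bytes; chunk the buffer and reassemble the floats,
    // assuming the little-endian layout the server wrote.
    Ok(bytes
        .chunks_exact(4)
        .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap()))
        .collect())
}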