
Support MRL (Matryoshka Representation Learning) #676


Merged · 8 commits · Jul 23, 2025
Changes from 6 commits
57 changes: 33 additions & 24 deletions core/src/infer.rs
@@ -151,20 +151,16 @@ impl Infer {
panic!("unexpected enum variant")
};

// Timings
let total_time = start_time.elapsed();

// Metrics
let counter = metrics::counter!("te_embed_success");
counter.increment(1);
let histogram = metrics::histogram!("te_embed_duration");
histogram.record(total_time.as_secs_f64());
let histogram = metrics::histogram!("te_embed_tokenization_duration");
histogram.record(response.metadata.tokenization.as_secs_f64());
let histogram = metrics::histogram!("te_embed_queue_duration");
histogram.record(response.metadata.queue.as_secs_f64());
let histogram = metrics::histogram!("te_embed_inference_duration");
histogram.record(response.metadata.inference.as_secs_f64());
metrics::counter!("te_embed_success").increment(1);
metrics::histogram!("te_embed_duration").record(total_time.as_secs_f64());
metrics::histogram!("te_embed_tokenization_duration")
.record(response.metadata.tokenization.as_secs_f64());
metrics::histogram!("te_embed_queue_duration")
.record(response.metadata.queue.as_secs_f64());
metrics::histogram!("te_embed_inference_duration")
.record(response.metadata.inference.as_secs_f64());

Ok(response)
}
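
Aside: the recurring change in this file drops the intermediate handle bindings and chains directly on what the `metrics` macros return. A minimal sketch of the pattern, assuming the `metrics` crate as a dependency (the calls are no-ops unless a recorder is installed):

```rust
fn main() {
    // Before the refactor: bind the handle, then call it.
    let counter = metrics::counter!("te_embed_success");
    counter.increment(1);

    // After the refactor: chain directly on the returned handle.
    // Identical behavior, less noise.
    metrics::counter!("te_embed_success").increment(1);
}
```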
@@ -224,6 +220,7 @@ impl Infer {
Ok(response)
}

#[allow(clippy::too_many_arguments)]
#[instrument(skip(self, inputs, permit))]
pub async fn embed_pooled<I: Into<EncodingInput> + std::fmt::Debug>(
&self,
@@ -232,20 +229,29 @@
truncation_direction: TruncationDirection,
prompt_name: Option<String>,
normalize: bool,
dimensions: Option<usize>,
permit: OwnedSemaphorePermit,
) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> {
let start_time = Instant::now();

if self.is_splade() && normalize {
let counter = metrics::counter!("te_request_failure", "err" => "model_type");
counter.increment(1);
metrics::counter!("te_request_failure", "err" => "model_type").increment(1);

Collaborator: Same here.

Contributor Author: reverted 9131071


let message = "`normalize` is not available for SPLADE models".to_string();
tracing::error!("{message}");
return Err(TextEmbeddingsError::Backend(BackendError::Inference(
message,
)));
}

if let Some(dimensions) = dimensions {
if dimensions == 0 {
metrics::counter!("te_request_failure", "err" => "validation").increment(1);

Collaborator: Shouldn't they also be smaller than the maximum embedding dimension?

Contributor Author: At that time, I was considering silently returning the embedding at its original size when the given dimension is larger than the expected size, as on line 274.

On second thought, as you mentioned, it would be better to raise an explicit error when the requested size is larger than expected.

I'll add extra validation logic to check whether the given size is larger than the embedding size. Thanks for catching this!

Contributor Author: added a1b1a26


let message = "`dimensions` should be positive".to_string();
tracing::error!("{message}");
return Err(TextEmbeddingsError::Validation(message));
}
}
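
A minimal sketch of the upper-bound check discussed in the review thread above. This is hypothetical: commit a1b1a26 is not among the 6 commits shown here, and the `validate_dimensions` helper, its name, and its signature are assumed for illustration:

```rust
// Hypothetical sketch, not TEI's actual code: validate a requested MRL
// dimension against the length of the embedding the model produces.
fn validate_dimensions(dimensions: usize, embedding_len: usize) -> Result<(), String> {
    if dimensions == 0 {
        return Err("`dimensions` should be positive".to_string());
    }
    if dimensions > embedding_len {
        return Err(format!(
            "`dimensions` ({dimensions}) cannot be larger than the embedding size ({embedding_len})"
        ));
    }
    Ok(())
}
```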

let results = self
.embed(
inputs,
@@ -262,6 +268,11 @@
panic!("unexpected enum variant")
};

if let Some(mrl_dimensions) = dimensions {
let mrl_dimensions = mrl_dimensions.min(response.results.len());
response.results.truncate(mrl_dimensions);
}

if normalize {
// Normalize embedding
let scale = (1.0
@@ -283,16 +294,14 @@
let total_time = start_time.elapsed();

// Metrics
let counter = metrics::counter!("te_embed_success");
counter.increment(1);
let histogram = metrics::histogram!("te_embed_duration");
histogram.record(total_time.as_secs_f64());
let histogram = metrics::histogram!("te_embed_tokenization_duration");
histogram.record(response.metadata.tokenization.as_secs_f64());
let histogram = metrics::histogram!("te_embed_queue_duration");
histogram.record(response.metadata.queue.as_secs_f64());
let histogram = metrics::histogram!("te_embed_inference_duration");
histogram.record(response.metadata.inference.as_secs_f64());
metrics::counter!("te_embed_success").increment(1);
metrics::histogram!("te_embed_duration").record(total_time.as_secs_f64());
metrics::histogram!("te_embed_tokenization_duration")
.record(response.metadata.tokenization.as_secs_f64());
metrics::histogram!("te_embed_queue_duration")
.record(response.metadata.queue.as_secs_f64());
metrics::histogram!("te_embed_inference_duration")
.record(response.metadata.inference.as_secs_f64());

Collaborator: Can you revert those changes? They don't seem linked to the PR.

Contributor Author: reverted 9131071



Ok(response)
}
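To make the MRL path concrete, here is a self-contained sketch (toy values, not TEI code) of what the diff above does: clamp the requested dimension with `min`, truncate, then L2-normalize with the same `1.0 / norm` scaling begun in the truncated `normalize` branch:

```rust
// Standalone sketch of truncate-then-normalize for an MRL embedding.
fn truncate_and_normalize(mut embedding: Vec<f32>, dimensions: Option<usize>) -> Vec<f32> {
    if let Some(dims) = dimensions {
        // Clamp like `mrl_dimensions.min(response.results.len())` above.
        embedding.truncate(dims.min(embedding.len()));
    }
    // Re-normalize so cosine similarity stays meaningful after truncation.
    let norm = embedding.iter().map(|v| v * v).sum::<f32>().sqrt();
    if norm > 0.0 {
        let scale = 1.0 / norm;
        for v in &mut embedding {
            *v *= scale;
        }
    }
    embedding
}

fn main() {
    // A toy 4-dimensional embedding truncated to 2 dimensions.
    let truncated = truncate_and_normalize(vec![0.6, 0.8, 0.1, 0.2], Some(2));
    println!("{truncated:?}"); // [0.6, 0.8]: already unit norm, so unchanged
}
```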
1 change: 1 addition & 0 deletions proto/tei.proto
@@ -80,6 +80,7 @@ message EmbedRequest {
bool normalize = 3;
TruncationDirection truncation_direction = 4;
optional string prompt_name = 5;
optional uint32 dimensions = 6;
}

message EmbedResponse {
1 change: 1 addition & 0 deletions router/src/grpc/server.rs
@@ -91,6 +91,7 @@ impl TextEmbeddingsService {
truncation_direction,
request.prompt_name,
request.normalize,
request.dimensions.map(|v| v as usize),
permit,
)
.await
5 changes: 5 additions & 0 deletions router/src/http/server.rs
@@ -544,6 +544,7 @@ async fn similarity(
truncation_direction: parameters.truncation_direction,
prompt_name: parameters.prompt_name,
normalize: false,
dimensions: None,
};

// Get embeddings
@@ -611,6 +612,7 @@ async fn embed(
req.truncation_direction.into(),
req.prompt_name,
req.normalize,
req.dimensions,
permit,
)
.await
@@ -679,6 +681,7 @@ async fn embed(
req.truncation_direction.into(),
prompt_name,
req.normalize,
req.dimensions,
permit,
)
.await
@@ -1156,6 +1159,7 @@ async fn openai_embed(
tokenizers::TruncationDirection::Right,
None,
true,
req.dimensions,
permit,
)
.await
@@ -1228,6 +1232,7 @@ async fn openai_embed(
tokenizers::TruncationDirection::Right,
None,
true,
req.dimensions,
permit,
)
.await
11 changes: 11 additions & 0 deletions router/src/http/types.rs
@@ -323,6 +323,8 @@ pub(crate) struct OpenAICompatRequest {
#[schema(default = "float", example = "float")]
#[serde(default)]
pub encoding_format: EncodingFormat,
#[schema(default = "null", example = "null", nullable = true)]
pub dimensions: Option<usize>,
}

#[derive(Serialize, ToSchema)]
@@ -406,12 +408,15 @@ pub(crate) struct SimilarityResponse(pub Vec<f32>);
#[derive(Deserialize, ToSchema)]
pub(crate) struct EmbedRequest {
pub inputs: Input,

#[serde(default)]
#[schema(default = "false", example = "false", nullable = true)]
pub truncate: Option<bool>,

#[serde(default)]
#[schema(default = "right", example = "right")]
pub truncation_direction: TruncationDirection,

/// The name of the prompt that should be used for encoding. If not set, no prompt
/// will be applied.
///
@@ -423,9 +428,15 @@ pub(crate) struct EmbedRequest {
/// any text to encode.
#[schema(default = "null", example = "null", nullable = true)]
pub prompt_name: Option<String>,

#[serde(default = "default_normalize")]
#[schema(default = "true", example = "true")]
pub normalize: bool,

/// The number of dimensions that the output embeddings should have. If not set, the original
/// shape of the representation will be returned instead.
#[schema(default = "null", example = "null", nullable = true)]
pub dimensions: Option<usize>,
}

fn default_normalize() -> bool {
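Finally, a hypothetical request body for the `/embed` route, with field names taken from the `EmbedRequest` struct above and `serde_json` assumed as a dependency for illustration:

```rust
use serde_json::json; // assumed dependency, illustration only

fn main() {
    // Omitting `dimensions` (or sending null) returns the embedding at the
    // model's original size; a positive value requests an MRL truncation.
    let body = json!({
        "inputs": "What is Matryoshka Representation Learning?",
        "normalize": true,
        "dimensions": 256
    });
    println!("{body}");
}
```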