-
Notifications
You must be signed in to change notification settings - Fork 288
Support MRL (Matryoshka Representation Learning) #676
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
96cb307
37439f5
b5c9024
dd87f5c
3336a63
d56c296
9131071
a1b1a26
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -151,20 +151,16 @@ impl Infer { | |
panic!("unexpected enum variant") | ||
}; | ||
|
||
// Timings | ||
let total_time = start_time.elapsed(); | ||
|
||
// Metrics | ||
let counter = metrics::counter!("te_embed_success"); | ||
counter.increment(1); | ||
let histogram = metrics::histogram!("te_embed_duration"); | ||
histogram.record(total_time.as_secs_f64()); | ||
let histogram = metrics::histogram!("te_embed_tokenization_duration"); | ||
histogram.record(response.metadata.tokenization.as_secs_f64()); | ||
let histogram = metrics::histogram!("te_embed_queue_duration"); | ||
histogram.record(response.metadata.queue.as_secs_f64()); | ||
let histogram = metrics::histogram!("te_embed_inference_duration"); | ||
histogram.record(response.metadata.inference.as_secs_f64()); | ||
metrics::counter!("te_embed_success").increment(1); | ||
metrics::histogram!("te_embed_duration").record(total_time.as_secs_f64()); | ||
metrics::histogram!("te_embed_tokenization_duration") | ||
.record(response.metadata.tokenization.as_secs_f64()); | ||
metrics::histogram!("te_embed_queue_duration") | ||
.record(response.metadata.queue.as_secs_f64()); | ||
metrics::histogram!("te_embed_inference_duration") | ||
.record(response.metadata.inference.as_secs_f64()); | ||
|
||
Ok(response) | ||
} | ||
|
@@ -224,6 +220,7 @@ impl Infer { | |
Ok(response) | ||
} | ||
|
||
#[allow(clippy::too_many_arguments)] | ||
#[instrument(skip(self, inputs, permit))] | ||
pub async fn embed_pooled<I: Into<EncodingInput> + std::fmt::Debug>( | ||
&self, | ||
|
@@ -232,20 +229,29 @@ impl Infer { | |
truncation_direction: TruncationDirection, | ||
prompt_name: Option<String>, | ||
normalize: bool, | ||
dimensions: Option<usize>, | ||
permit: OwnedSemaphorePermit, | ||
) -> Result<PooledEmbeddingsInferResponse, TextEmbeddingsError> { | ||
let start_time = Instant::now(); | ||
|
||
if self.is_splade() && normalize { | ||
let counter = metrics::counter!("te_request_failure", "err" => "model_type"); | ||
counter.increment(1); | ||
metrics::counter!("te_request_failure", "err" => "model_type").increment(1); | ||
let message = "`normalize` is not available for SPLADE models".to_string(); | ||
tracing::error!("{message}"); | ||
return Err(TextEmbeddingsError::Backend(BackendError::Inference( | ||
message, | ||
))); | ||
} | ||
|
||
if let Some(dimensions) = dimensions { | ||
if dimensions == 0 { | ||
metrics::counter!("te_request_failure", "err" => "validation").increment(1); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't they also be smaller than the maximum embedding dimension ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. At that time, I was considering silently returning the embedding with the original size when the given dimension is larger than the expected size, like line 274. On second thought, like you mentioned, it'd be better to raise an error explicitly when the size is larger than expected in terms of validity. I'll add an extra validation logic to check whether the given size is larger than the size of the embedding. thanks for catching this! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added a1b1a26 |
||
let message = "`dimensions` should be positive".to_string(); | ||
tracing::error!("{message}"); | ||
return Err(TextEmbeddingsError::Validation(message)); | ||
} | ||
} | ||
|
||
let results = self | ||
.embed( | ||
inputs, | ||
|
@@ -262,6 +268,11 @@ impl Infer { | |
panic!("unexpected enum variant") | ||
}; | ||
|
||
if let Some(mrl_dimensions) = dimensions { | ||
let mrl_dimensions = mrl_dimensions.min(response.results.len()); | ||
response.results.truncate(mrl_dimensions); | ||
} | ||
|
||
if normalize { | ||
// Normalize embedding | ||
let scale = (1.0 | ||
|
@@ -283,16 +294,14 @@ impl Infer { | |
let total_time = start_time.elapsed(); | ||
|
||
// Metrics | ||
let counter = metrics::counter!("te_embed_success"); | ||
counter.increment(1); | ||
let histogram = metrics::histogram!("te_embed_duration"); | ||
histogram.record(total_time.as_secs_f64()); | ||
let histogram = metrics::histogram!("te_embed_tokenization_duration"); | ||
histogram.record(response.metadata.tokenization.as_secs_f64()); | ||
let histogram = metrics::histogram!("te_embed_queue_duration"); | ||
histogram.record(response.metadata.queue.as_secs_f64()); | ||
let histogram = metrics::histogram!("te_embed_inference_duration"); | ||
histogram.record(response.metadata.inference.as_secs_f64()); | ||
metrics::counter!("te_embed_success").increment(1); | ||
metrics::histogram!("te_embed_duration").record(total_time.as_secs_f64()); | ||
metrics::histogram!("te_embed_tokenization_duration") | ||
.record(response.metadata.tokenization.as_secs_f64()); | ||
metrics::histogram!("te_embed_queue_duration") | ||
.record(response.metadata.queue.as_secs_f64()); | ||
metrics::histogram!("te_embed_inference_duration") | ||
.record(response.metadata.inference.as_secs_f64()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you revert those changes ? They don't seem linked to the PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. reverted 9131071 |
||
|
||
Ok(response) | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reverted 9131071