Commit d4e6e54

Store sparse vector and decoded token weights
Why these changes are being introduced:

Formerly, our 'Embedding' class only had an 'embedding' property for the output. However, the first model in our pipeline, opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte, produces two representations of the embedding that are useful to store: a sparse vector and decoded token weights.

How this addresses that need:

Updates the 'Embedding' class to explicitly store both representations of the embedding. We may decide later not to store both, and some future models may not produce decoded token weights of any kind, but this matches our first proposed model and pipeline. Better to be explicit and opinionated in these early days, then adjust later if needed.

Side effects of this change:
* None

Relevant ticket(s):
* https://mitlibraries.atlassian.net/browse/USE-112
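To illustrate how the two representations relate, here is a minimal sketch of decoding a sparse vector into token:weight pairs. It is not the pipeline's implementation: the function name `decode_token_weights` and the tiny `id_to_token` vocabulary are hypothetical, and the sketch assumes each vector position corresponds to one vocabulary token ID.

```python
def decode_token_weights(
    sparse_vector: list[float],
    id_to_token: dict[int, str],
) -> dict[str, float]:
    """Return token:weight pairs for the nonzero entries of a sparse vector.

    Assumption: position i in the vector holds the weight for vocabulary ID i.
    """
    return {
        id_to_token[token_id]: weight
        for token_id, weight in enumerate(sparse_vector)
        if weight != 0.0
    }


# Hypothetical three-token vocabulary, for illustration only.
id_to_token = {0: "coffee", 1: "tea", 2: "seattle"}
sparse_vector = [0.9, 0.0, 0.5]

token_weights = decode_token_weights(sparse_vector, id_to_token)
```

The vector is compact to index, while the decoded pairs are easier to read and debug, which is one plausible reason to store both.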
1 parent de351a1 commit d4e6e54

4 files changed: 11 additions, 5 deletions

embeddings/cli.py

Lines changed: 2 additions & 1 deletion
@@ -258,7 +258,8 @@ def create_embeddings(
                 run_record_offset=input_record.run_record_offset,
                 embedding_strategy=input_record.embedding_strategy,
                 model_uri=model.model_uri,
-                embedding={"coffee": 0.9, "seattle": 0.5},
+                embedding_vector=[0.1, 0.2, 0.3],
+                embedding_token_weights={"coffee": 0.9, "seattle": 0.5},
             )
             for input_record in input_records
         )

embeddings/embedding.py

Lines changed: 5 additions & 2 deletions
@@ -32,15 +32,18 @@ class Embedding:
         (timdex_record_id, run_id, run_record_offset): composite key for TIMDEX record
         model_uri: model URI used to create the embedding
         embedding_strategy: strategy used to create text for embedding
-        embedding: model embedding created from text
+        embedding_vector: vector representation of embedding
+        embedding_token_weights: decoded token:weight pairs from sparse vector
+            - only applicable to models that produce this output
     """

     timdex_record_id: str
     run_id: str
     run_record_offset: int
     model_uri: str
     embedding_strategy: str
-    embedding: dict | list[float]
+    embedding_vector: list[float]
+    embedding_token_weights: dict

     timestamp: datetime.datetime = field(
         default_factory=lambda: datetime.datetime.now(datetime.UTC)
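For reference, the class after this change might look roughly like the sketch below. The hunk shows only part of the file, so this assumes `Embedding` is a standard-library dataclass (suggested by the use of `field`) and omits anything outside the diff. It uses `datetime.timezone.utc`, which is the same object as the `datetime.UTC` alias in the diff but also works on Python versions before 3.11.

```python
import datetime
from dataclasses import dataclass, field


@dataclass
class Embedding:
    """Sketch of the updated class; assumes a plain dataclass."""

    timdex_record_id: str
    run_id: str
    run_record_offset: int
    model_uri: str
    embedding_strategy: str
    embedding_vector: list[float]
    embedding_token_weights: dict

    # datetime.timezone.utc is equivalent to datetime.UTC (Python 3.11+).
    timestamp: datetime.datetime = field(
        default_factory=lambda: datetime.datetime.now(datetime.timezone.utc)
    )


# Placeholder values, mirroring the stub data used elsewhere in this commit.
example = Embedding(
    timdex_record_id="record-1",
    run_id="run-1",
    run_record_offset=42,
    model_uri="test/mock-model",
    embedding_strategy="full_record",
    embedding_vector=[0.1, 0.2, 0.3],
    embedding_token_weights={"coffee": 0.9, "seattle": 0.5},
)
```

Both new fields are required here, as in the diff. A model that produces no decoded token weights would need to pass an empty dict, or the field would need a default, which this commit leaves for later.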

tests/conftest.py

Lines changed: 2 additions & 1 deletion
@@ -52,7 +52,8 @@ def create_embedding(self, input_record: EmbeddingInput) -> Embedding:
             run_record_offset=input_record.run_record_offset,
             embedding_strategy=input_record.embedding_strategy,
             model_uri=self.model_uri,
-            embedding={"coffee": 0.9, "seattle": 0.5},
+            embedding_vector=[0.1, 0.2, 0.3],
+            embedding_token_weights={"coffee": 0.9, "seattle": 0.5},
         )

tests/test_models.py

Lines changed: 2 additions & 1 deletion
@@ -49,7 +49,8 @@ def test_mock_model_create_embedding(mock_model):
     assert embedding.run_record_offset == 42
     assert embedding.embedding_strategy == "full_record"
     assert embedding.model_uri == "test/mock-model"
-    assert embedding.embedding == {"coffee": 0.9, "seattle": 0.5}
+    assert embedding.embedding_vector == [0.1, 0.2, 0.3]
+    assert embedding.embedding_token_weights == {"coffee": 0.9, "seattle": 0.5}


 def test_registry_contains_opensearch_model():