Skip to content

Commit ce181df

Browse files
committed
Stub CLI command methods to create embeddings
How this addresses that need: * CLI command create-embeddings created * args and some functionality in place * WIP comments and DEBUG code temporarily added to demonstrate how it will work * class RecordText added to encapsulate text that is ready for an embedding * this will support future functionality of pre-embedding "strategies" applied to records * class Embedding created to encapsulate the embedding result * this captures the TIMDEX record the embedding was assocaited with, and the model + strategy used to prepare the text Side effects of this change: * None Relevant ticket(s): * https://mitlibraries.atlassian.net/browse/USE-112
1 parent 39ef93e commit ce181df

File tree

12 files changed

+362
-33
lines changed

12 files changed

+362
-33
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,5 @@ cython_debug/
155155
.DS_Store
156156
output/
157157
.vscode/
158+
159+
CLAUDE.md

embeddings/cli.py

Lines changed: 135 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from typing import TYPE_CHECKING
88

99
import click
10+
import jsonlines
11+
from timdex_dataset_api import TIMDEXDataset
1012

1113
from embeddings.config import configure_logger, configure_sentry
1214
from embeddings.models.registry import get_model_class
@@ -150,8 +152,139 @@ def test_model_load(ctx: click.Context) -> None:
150152
@main.command()
151153
@click.pass_context
152154
@model_required
153-
def create_embedding(ctx: click.Context) -> None:
154-
"""Create a single embedding for a single input text."""
155+
@click.option(
156+
"-d",
157+
"--dataset-location",
158+
required=True,
159+
type=click.Path(),
160+
help="TIMDEX dataset location, e.g. 's3://timdex/dataset', to read records from.",
161+
)
162+
@click.option(
163+
"--run-id",
164+
required=True,
165+
type=str,
166+
help="TIMDEX ETL run id.",
167+
)
168+
@click.option(
169+
"--run-record-offset",
170+
required=True,
171+
type=int,
172+
default=0,
173+
help="TIMDEX ETL run record offset to start from, default = 0.",
174+
)
175+
@click.option(
176+
"--record-limit",
177+
required=True,
178+
type=int,
179+
default=None,
180+
help="Limit number of records after --run-record-offset, default = None (unlimited).",
181+
)
182+
@click.option(
183+
"--strategy",
184+
type=str, # WIP: establish an enum of supported strategies
185+
required=True,
186+
multiple=True,
187+
help="Pre-embedding record transformation strategy to use. Repeatable.",
188+
)
189+
@click.option(
190+
"--output-jsonl",
191+
required=False,
192+
type=str,
193+
default=None,
194+
help="Optionally write embeddings to local JSONLines file (primarily for testing).",
195+
)
196+
def create_embeddings(
197+
ctx: click.Context,
198+
dataset_location: str,
199+
run_id: str,
200+
run_record_offset: int,
201+
record_limit: int,
202+
strategy: list[str],
203+
output_jsonl: str,
204+
) -> None:
205+
"""Create embeddings for TIMDEX records."""
206+
model: BaseEmbeddingModel = ctx.obj["model"]
207+
208+
# init TIMDEXDataset
209+
timdex_dataset = TIMDEXDataset(dataset_location)
210+
211+
# query TIMDEX dataset for an iterator of records
212+
timdex_records = timdex_dataset.read_dicts_iter(
213+
columns=[
214+
"timdex_record_id",
215+
"run_id",
216+
"run_record_offset",
217+
"transformed_record",
218+
],
219+
run_id=run_id,
220+
where=f"""run_record_offset >= {run_record_offset}""",
221+
limit=record_limit,
222+
action="index",
223+
)
224+
225+
# create an iterator of InputTexts applying all requested strategies to all records
226+
# WIP NOTE: this will leverage some kind of pre-embedding transformer class(es) that
227+
# create texts based on the requested strategies (e.g. "full record"), which are
228+
# captured in --strategy CLI args
229+
# WIP NOTE: the following simulates that...
230+
# DEBUG ------------------------------------------------------------------------------
231+
import json # noqa: PLC0415
232+
233+
from embeddings.embedding import RecordText # noqa: PLC0415
234+
235+
input_records = (
236+
RecordText(
237+
timdex_record_id=timdex_record["timdex_record_id"],
238+
run_id=timdex_record["run_id"],
239+
run_record_offset=timdex_record["run_record_offset"],
240+
embedding_strategy=_strategy,
241+
text=json.dumps(timdex_record["transformed_record"].decode()),
242+
)
243+
for timdex_record in timdex_records
244+
for _strategy in strategy
245+
)
246+
# DEBUG ------------------------------------------------------------------------------
247+
248+
# create an iterator of Embeddings via the embedding model
249+
# WIP NOTE: this will use the embedding class .create_embeddings() bulk method
250+
# WIP NOTE: the following simulates that...
251+
# DEBUG ------------------------------------------------------------------------------
252+
from embeddings.embedding import Embedding # noqa: PLC0415
253+
254+
embeddings = (
255+
Embedding(
256+
timdex_record_id=input_record.timdex_record_id,
257+
run_id=input_record.run_id,
258+
run_record_offset=input_record.run_record_offset,
259+
embedding_strategy=input_record.embedding_strategy,
260+
model_uri=model.model_uri,
261+
embedding={"coffee": 0.9, "seattle": 0.5},
262+
)
263+
for input_record in input_records
264+
)
265+
# DEBUG ------------------------------------------------------------------------------
266+
267+
# if requested, write embeddings to a local JSONLines file
268+
if output_jsonl:
269+
with jsonlines.open(
270+
output_jsonl,
271+
mode="w",
272+
dumps=lambda obj: json.dumps(
273+
obj,
274+
default=str,
275+
),
276+
) as writer:
277+
for embedding in embeddings:
278+
writer.write(embedding.to_dict())
279+
280+
# else, default writing embeddings back to TIMDEX dataset
281+
else:
282+
# WIP NOTE: write via anticipated timdex_dataset.embeddings.write(...)
283+
# NOTE: will likely use an imported TIMDEXEmbedding class from TDA, which the
284+
# Embedding instance will nearly 1:1 map to.
285+
raise NotImplementedError
286+
287+
logger.info("Embeddings creation complete.")
155288

156289

157290
if __name__ == "__main__": # pragma: no cover

embeddings/config.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@ def configure_logger(logger: logging.Logger, *, verbose: bool) -> str:
1111
"%(message)s"
1212
)
1313
logger.setLevel(logging.DEBUG)
14-
for handler in logging.root.handlers:
15-
handler.addFilter(logging.Filter("embeddings"))
1614
else:
1715
logging.basicConfig(
1816
format="%(asctime)s %(levelname)s %(name)s.%(funcName)s(): %(message)s"

embeddings/embedding.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import datetime
2+
import json
3+
from dataclasses import asdict, dataclass, field
4+
5+
6+
@dataclass
7+
class RecordText:
8+
"""Input record for creating an embedding for.
9+
10+
Args:
11+
(timdex_record_id, run_id, run_record_offset): composite key for TIMDEX record
12+
embedding_strategy: strategy used to create text for embedding
13+
text: text to embed, created from the TIMDEX record via the embedding_strategy
14+
"""
15+
16+
timdex_record_id: str
17+
run_id: str
18+
run_record_offset: int
19+
embedding_strategy: str
20+
text: str
21+
22+
23+
@dataclass
24+
class Embedding:
25+
"""Encapsulates a single embedding.
26+
27+
Args:
28+
(timdex_record_id, run_id, run_record_offset): composite key for TIMDEX record
29+
model_uri: model URI used to create the embedding
30+
embedding_strategy: strategy used to create text for embedding
31+
embedding: model embedding created from text
32+
"""
33+
34+
timdex_record_id: str
35+
run_id: str
36+
run_record_offset: int
37+
model_uri: str
38+
embedding_strategy: str
39+
embedding: dict | list[float]
40+
41+
timestamp: datetime.datetime = field(
42+
default_factory=lambda: datetime.datetime.now(datetime.UTC)
43+
)
44+
45+
def to_dict(self) -> dict:
46+
"""Marshal to dictionary."""
47+
return asdict(self)
48+
49+
def to_json(self) -> str:
50+
"""Serialize to JSON."""
51+
return json.dumps(self.to_dict(), default=str)

embeddings/models/base.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
"""Base class for embedding models."""
22

33
from abc import ABC, abstractmethod
4+
from collections.abc import Iterator
45
from pathlib import Path
56

7+
from embeddings.embedding import Embedding, RecordText
8+
69

710
class BaseEmbeddingModel(ABC):
811
"""Abstract base class for embedding models.
@@ -46,3 +49,22 @@ def download(self) -> Path:
4649
@abstractmethod
4750
def load(self) -> None:
4851
"""Load model from self.model_path."""
52+
53+
@abstractmethod
54+
def create_embedding(self, input_record: RecordText) -> Embedding:
55+
"""Create an Embedding for an RecordText.
56+
57+
Args:
58+
input_record: RecordText instance
59+
"""
60+
61+
def create_embeddings(
62+
self, input_records: Iterator[RecordText]
63+
) -> Iterator[Embedding]:
64+
"""Yield Embeddings for an iterator of InputRecords.
65+
66+
Args:
67+
input_records: iterator of InputRecords
68+
"""
69+
for input_text in input_records:
70+
yield self.create_embedding(input_text)

embeddings/models/os_neural_sparse_doc_v3_gte.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from huggingface_hub import snapshot_download
1212
from transformers import AutoModelForMaskedLM, AutoTokenizer
1313

14+
from embeddings.embedding import Embedding, RecordText
1415
from embeddings.models.base import BaseEmbeddingModel
1516

1617
if TYPE_CHECKING:
@@ -161,3 +162,6 @@ def load(self) -> None:
161162
self._id_to_token[token_id] = token
162163

163164
logger.info(f"Model loaded successfully, {time.perf_counter()-start_time}s")
165+
166+
def create_embedding(self, input_record: RecordText) -> Embedding:
167+
raise NotImplementedError

pyproject.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ requires-python = ">=3.12"
1010
dependencies = [
1111
"click>=8.2.1",
1212
"huggingface-hub>=0.26.0",
13+
"jsonlines>=4.0.0",
1314
"sentry-sdk>=2.34.1",
1415
"timdex-dataset-api",
1516
"torch>=2.9.0",
@@ -39,6 +40,11 @@ exclude = [
3940
"output/"
4041
]
4142

43+
[[tool.mypy.overrides]]
44+
module = ["timdex_dataset_api.*"]
45+
follow_untyped_imports = true
46+
47+
4248
[tool.pytest.ini_options]
4349
log_level = "INFO"
4450

@@ -88,6 +94,7 @@ fixture-parentheses = false
8894
"tests/**/*" = [
8995
"ANN",
9096
"ARG001",
97+
"PLR2004",
9198
"S101",
9299
]
93100

tests/conftest.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import pytest
77
from click.testing import CliRunner
88

9+
from embeddings.embedding import Embedding, RecordText
10+
from embeddings.models import registry
911
from embeddings.models.base import BaseEmbeddingModel
1012

1113
logger = logging.getLogger(__name__)
@@ -43,13 +45,30 @@ def download(self) -> Path:
4345
def load(self) -> None:
4446
logger.info("Model loaded successfully, 1.5s")
4547

48+
def create_embedding(self, input_record: RecordText) -> Embedding:
49+
return Embedding(
50+
timdex_record_id=input_record.timdex_record_id,
51+
run_id=input_record.run_id,
52+
run_record_offset=input_record.run_record_offset,
53+
embedding_strategy=input_record.embedding_strategy,
54+
model_uri=self.model_uri,
55+
embedding={"coffee": 0.9, "seattle": 0.5},
56+
)
57+
4658

4759
@pytest.fixture
4860
def mock_model(tmp_path):
4961
"""Fixture providing a MockEmbeddingModel instance."""
5062
return MockEmbeddingModel(tmp_path / "model")
5163

5264

65+
@pytest.fixture
66+
def register_mock_model(monkeypatch):
67+
"""Register MockEmbeddingModel in the model registry."""
68+
monkeypatch.setitem(registry.MODEL_REGISTRY, "test/mock-model", MockEmbeddingModel)
69+
monkeypatch.setenv("TE_MODEL_PATH", "/fake/path")
70+
71+
5372
@pytest.fixture
5473
def neural_sparse_doc_v3_gte_fake_model_directory(tmp_path):
5574
"""Create a fake downloaded model directory with required files."""

0 commit comments

Comments
 (0)