Skip to content

Commit bd9de2b

Browse files
authored
Merge pull request #16 from MITLibraries/USE-112-scaffold-creating-embeddings
Stub CLI command methods to create embeddings
2 parents b3dbbb8 + d4e6e54 commit bd9de2b

File tree

14 files changed

+400
-36
lines changed

14 files changed

+400
-36
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,3 +155,5 @@ cython_debug/
155155
.DS_Store
156156
output/
157157
.vscode/
158+
159+
CLAUDE.md

README.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,4 +95,28 @@ Options:
9595
--help Show this message and exit.
9696
```
9797

98+
### `create-embeddings`
99+
```text
100+
Usage: embeddings create-embeddings [OPTIONS]
101+
102+
Create embeddings for TIMDEX records.
98103
104+
Options:
105+
--model-uri TEXT HuggingFace model URI (e.g., 'org/model-name')
106+
[required]
107+
--model-path PATH Path where the model will be downloaded to and
108+
loaded from, e.g. '/path/to/model'. [required]
109+
-d, --dataset-location PATH TIMDEX dataset location, e.g.
110+
's3://timdex/dataset', to read records from.
111+
[required]
112+
--run-id TEXT TIMDEX ETL run id. [required]
113+
--run-record-offset INTEGER TIMDEX ETL run record offset to start from,
114+
default = 0. [required]
115+
--record-limit INTEGER Limit number of records after --run-record-
116+
offset, default = None (unlimited). [required]
117+
--strategy TEXT Pre-embedding record transformation strategy to
118+
use. Repeatable. [required]
119+
--output-jsonl TEXT Optionally write embeddings to local JSONLines
120+
file (primarily for testing).
121+
--help Show this message and exit.
122+
```

embeddings/cli.py

Lines changed: 136 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from typing import TYPE_CHECKING
88

99
import click
10+
import jsonlines
11+
from timdex_dataset_api import TIMDEXDataset
1012

1113
from embeddings.config import configure_logger, configure_sentry
1214
from embeddings.models.registry import get_model_class
@@ -150,8 +152,140 @@ def test_model_load(ctx: click.Context) -> None:
150152
@main.command()
151153
@click.pass_context
152154
@model_required
153-
def create_embedding(ctx: click.Context) -> None:
154-
"""Create a single embedding for a single input text."""
155+
@click.option(
156+
"-d",
157+
"--dataset-location",
158+
required=True,
159+
type=click.Path(),
160+
help="TIMDEX dataset location, e.g. 's3://timdex/dataset', to read records from.",
161+
)
162+
@click.option(
163+
"--run-id",
164+
required=True,
165+
type=str,
166+
help="TIMDEX ETL run id.",
167+
)
168+
@click.option(
169+
"--run-record-offset",
170+
required=True,
171+
type=int,
172+
default=0,
173+
help="TIMDEX ETL run record offset to start from, default = 0.",
174+
)
175+
@click.option(
176+
"--record-limit",
177+
required=True,
178+
type=int,
179+
default=None,
180+
help="Limit number of records after --run-record-offset, default = None (unlimited).",
181+
)
182+
@click.option(
183+
"--strategy",
184+
type=str, # WIP: establish an enum of supported strategies
185+
required=True,
186+
multiple=True,
187+
help="Pre-embedding record transformation strategy to use. Repeatable.",
188+
)
189+
@click.option(
190+
"--output-jsonl",
191+
required=False,
192+
type=str,
193+
default=None,
194+
help="Optionally write embeddings to local JSONLines file (primarily for testing).",
195+
)
196+
def create_embeddings(
197+
ctx: click.Context,
198+
dataset_location: str,
199+
run_id: str,
200+
run_record_offset: int,
201+
record_limit: int,
202+
strategy: list[str],
203+
output_jsonl: str,
204+
) -> None:
205+
"""Create embeddings for TIMDEX records."""
206+
model: BaseEmbeddingModel = ctx.obj["model"]
207+
208+
# init TIMDEXDataset
209+
timdex_dataset = TIMDEXDataset(dataset_location)
210+
211+
# query TIMDEX dataset for an iterator of records
212+
timdex_records = timdex_dataset.read_dicts_iter(
213+
columns=[
214+
"timdex_record_id",
215+
"run_id",
216+
"run_record_offset",
217+
"transformed_record",
218+
],
219+
run_id=run_id,
220+
where=f"""run_record_offset >= {run_record_offset}""",
221+
limit=record_limit,
222+
action="index",
223+
)
224+
225+
# create an iterator of InputTexts applying all requested strategies to all records
226+
# WIP NOTE: this will leverage some kind of pre-embedding transformer class(es) that
227+
# create texts based on the requested strategies (e.g. "full record"), which are
228+
# captured in --strategy CLI args
229+
# WIP NOTE: the following simulates that...
230+
# DEBUG ------------------------------------------------------------------------------
231+
import json # noqa: PLC0415
232+
233+
from embeddings.embedding import EmbeddingInput # noqa: PLC0415
234+
235+
input_records = (
236+
EmbeddingInput(
237+
timdex_record_id=timdex_record["timdex_record_id"],
238+
run_id=timdex_record["run_id"],
239+
run_record_offset=timdex_record["run_record_offset"],
240+
embedding_strategy=_strategy,
241+
text=json.dumps(timdex_record["transformed_record"].decode()),
242+
)
243+
for timdex_record in timdex_records
244+
for _strategy in strategy
245+
)
246+
# DEBUG ------------------------------------------------------------------------------
247+
248+
# create an iterator of Embeddings via the embedding model
249+
# WIP NOTE: this will use the embedding class .create_embeddings() bulk method
250+
# WIP NOTE: the following simulates that...
251+
# DEBUG ------------------------------------------------------------------------------
252+
from embeddings.embedding import Embedding # noqa: PLC0415
253+
254+
embeddings = (
255+
Embedding(
256+
timdex_record_id=input_record.timdex_record_id,
257+
run_id=input_record.run_id,
258+
run_record_offset=input_record.run_record_offset,
259+
embedding_strategy=input_record.embedding_strategy,
260+
model_uri=model.model_uri,
261+
embedding_vector=[0.1, 0.2, 0.3],
262+
embedding_token_weights={"coffee": 0.9, "seattle": 0.5},
263+
)
264+
for input_record in input_records
265+
)
266+
# DEBUG ------------------------------------------------------------------------------
267+
268+
# if requested, write embeddings to a local JSONLines file
269+
if output_jsonl:
270+
with jsonlines.open(
271+
output_jsonl,
272+
mode="w",
273+
dumps=lambda obj: json.dumps(
274+
obj,
275+
default=str,
276+
),
277+
) as writer:
278+
for embedding in embeddings:
279+
writer.write(embedding.to_dict())
280+
281+
# else, default writing embeddings back to TIMDEX dataset
282+
else:
283+
# WIP NOTE: write via anticipated timdex_dataset.embeddings.write(...)
284+
# NOTE: will likely use an imported TIMDEXEmbedding class from TDA, which the
285+
# Embedding instance will nearly 1:1 map to.
286+
raise NotImplementedError
287+
288+
logger.info("Embeddings creation complete.")
155289

156290

157291
if __name__ == "__main__": # pragma: no cover

embeddings/config.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,8 @@ def configure_logger(logger: logging.Logger, *, verbose: bool) -> str:
1010
format="%(asctime)s %(levelname)s %(name)s.%(funcName)s() line %(lineno)d: "
1111
"%(message)s"
1212
)
13-
logger.setLevel(logging.DEBUG)
14-
for handler in logging.root.handlers:
15-
handler.addFilter(logging.Filter("embeddings"))
13+
logging.getLogger("embeddings").setLevel(logging.DEBUG)
14+
logging.getLogger("timdex_dataset_api").setLevel(logging.DEBUG)
1615
else:
1716
logging.basicConfig(
1817
format="%(asctime)s %(levelname)s %(name)s.%(funcName)s(): %(message)s"

embeddings/embedding.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import datetime
2+
import json
3+
from dataclasses import asdict, dataclass, field
4+
5+
6+
@dataclass
class EmbeddingInput:
    """Text to embed, tagged with the TIMDEX record and strategy it came from.

    An embedding has to be traceable back to its source TIMDEX record and to the
    strategy that turned that record into text, so both travel with the text.

    Args:
        (timdex_record_id, run_id, run_record_offset): composite key for TIMDEX record
        embedding_strategy: strategy used to create text for embedding
        text: text to embed, created from the TIMDEX record via the embedding_strategy
    """

    # composite key of the source TIMDEX record
    timdex_record_id: str
    run_id: str
    run_record_offset: int

    # how the text was derived, and the text itself
    embedding_strategy: str
    text: str
25+
26+
27+
@dataclass
class Embedding:
    """A single embedding plus the metadata identifying how it was produced.

    Args:
        (timdex_record_id, run_id, run_record_offset): composite key for TIMDEX record
        model_uri: model URI used to create the embedding
        embedding_strategy: strategy used to create text for embedding
        embedding_vector: vector representation of embedding
        embedding_token_weights: decoded token:weight pairs from sparse vector
            - only applicable to models that produce this output
    """

    # composite key of the source TIMDEX record
    timdex_record_id: str
    run_id: str
    run_record_offset: int

    # provenance of the embedding
    model_uri: str
    embedding_strategy: str

    # embedding payload
    embedding_vector: list[float]
    embedding_token_weights: dict

    # timezone-aware UTC time, captured when the instance is created
    timestamp: datetime.datetime = field(
        default_factory=lambda: datetime.datetime.now(tz=datetime.timezone.utc)
    )

    def to_dict(self) -> dict:
        """Return the fields as a dictionary (timestamp stays a datetime)."""
        return asdict(self)

    def to_json(self) -> str:
        """Return the fields as a JSON string.

        Values json cannot encode natively, such as the timestamp, are converted
        with str().
        """
        payload = self.to_dict()
        return json.dumps(payload, default=str)

embeddings/models/base.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
"""Base class for embedding models."""
22

33
from abc import ABC, abstractmethod
4+
from collections.abc import Iterator
45
from pathlib import Path
56

7+
from embeddings.embedding import Embedding, EmbeddingInput
8+
69

710
class BaseEmbeddingModel(ABC):
811
"""Abstract base class for embedding models.
@@ -46,3 +49,22 @@ def download(self) -> Path:
4649
@abstractmethod
def load(self) -> None:
    """Load the model found at self.model_path.

    Abstract: each concrete model class provides its own implementation.
    """
52+
53+
@abstractmethod
def create_embedding(self, input_record: EmbeddingInput) -> Embedding:
    """Produce the Embedding for a single EmbeddingInput.

    Abstract: each concrete model class provides its own implementation.

    Args:
        input_record: EmbeddingInput instance
    """
60+
61+
def create_embeddings(
    self, input_records: Iterator[EmbeddingInput]
) -> Iterator[Embedding]:
    """Lazily yield one Embedding per EmbeddingInput, in input order.

    Each input is handed to self.create_embedding() only when the next Embedding
    is requested.

    Args:
        input_records: iterator of EmbeddingInput instances
    """
    yield from (self.create_embedding(record) for record in input_records)

embeddings/models/os_neural_sparse_doc_v3_gte.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from huggingface_hub import snapshot_download
1212
from transformers import AutoModelForMaskedLM, AutoTokenizer
1313

14+
from embeddings.embedding import Embedding, EmbeddingInput
1415
from embeddings.models.base import BaseEmbeddingModel
1516

1617
if TYPE_CHECKING:
@@ -161,3 +162,6 @@ def load(self) -> None:
161162
self._id_to_token[token_id] = token
162163

163164
logger.info(f"Model loaded successfully, {time.perf_counter()-start_time}s")
165+
166+
def create_embedding(self, input_record: EmbeddingInput) -> Embedding:
    """Create an Embedding for an EmbeddingInput (not implemented yet).

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError

pyproject.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ requires-python = ">=3.12"
1010
dependencies = [
1111
"click>=8.2.1",
1212
"huggingface-hub>=0.26.0",
13+
"jsonlines>=4.0.0",
1314
"sentry-sdk>=2.34.1",
1415
"timdex-dataset-api",
1516
"torch>=2.9.0",
@@ -39,6 +40,11 @@ exclude = [
3940
"output/"
4041
]
4142

43+
[[tool.mypy.overrides]]
44+
module = ["timdex_dataset_api.*"]
45+
follow_untyped_imports = true
46+
47+
4248
[tool.pytest.ini_options]
4349
log_level = "INFO"
4450

@@ -88,6 +94,7 @@ fixture-parentheses = false
8894
"tests/**/*" = [
8995
"ANN",
9096
"ARG001",
97+
"PLR2004",
9198
"S101",
9299
]
93100

tests/conftest.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import pytest
77
from click.testing import CliRunner
88

9+
from embeddings.embedding import Embedding, EmbeddingInput
10+
from embeddings.models import registry
911
from embeddings.models.base import BaseEmbeddingModel
1012

1113
logger = logging.getLogger(__name__)
@@ -43,13 +45,31 @@ def download(self) -> Path:
4345
def load(self) -> None:
    """Pretend to load the model: only logs a fixed success message."""
    logger.info("Model loaded successfully, 1.5s")
4547

48+
def create_embedding(self, input_record: EmbeddingInput) -> Embedding:
    """Return a canned Embedding that echoes the input record's identifiers.

    The vector and token weights are fixed values; only the record key, the
    strategy and the model URI vary between calls.
    """
    record = input_record
    return Embedding(
        model_uri=self.model_uri,
        timdex_record_id=record.timdex_record_id,
        run_id=record.run_id,
        run_record_offset=record.run_record_offset,
        embedding_strategy=record.embedding_strategy,
        embedding_vector=[0.1, 0.2, 0.3],
        embedding_token_weights={"coffee": 0.9, "seattle": 0.5},
    )
58+
4659

4760
@pytest.fixture
def mock_model(tmp_path):
    """Build a MockEmbeddingModel whose model path is "model" under tmp_path."""
    model_dir = tmp_path / "model"
    return MockEmbeddingModel(model_dir)
5164

5265

66+
@pytest.fixture
def register_mock_model(monkeypatch):
    """Map "test/mock-model" to MockEmbeddingModel in the model registry.

    Also sets the TE_MODEL_PATH environment variable to a placeholder path. Both
    changes go through monkeypatch.
    """
    mock_uri = "test/mock-model"
    monkeypatch.setitem(registry.MODEL_REGISTRY, mock_uri, MockEmbeddingModel)
    monkeypatch.setenv("TE_MODEL_PATH", "/fake/path")
71+
72+
5373
@pytest.fixture
5474
def neural_sparse_doc_v3_gte_fake_model_directory(tmp_path):
5575
"""Create a fake downloaded model directory with required files."""

0 commit comments

Comments
 (0)