|
7 | 7 | from typing import TYPE_CHECKING |
8 | 8 |
|
9 | 9 | import click |
| 10 | +import jsonlines |
| 11 | +from timdex_dataset_api import TIMDEXDataset |
10 | 12 |
|
11 | 13 | from embeddings.config import configure_logger, configure_sentry |
12 | 14 | from embeddings.models.registry import get_model_class |
@@ -150,8 +152,140 @@ def test_model_load(ctx: click.Context) -> None: |
@main.command()
@click.pass_context
@model_required
@click.option(
    "-d",
    "--dataset-location",
    required=True,
    type=click.Path(),
    help="TIMDEX dataset location, e.g. 's3://timdex/dataset', to read records from.",
)
@click.option(
    "--run-id",
    required=True,
    type=str,
    help="TIMDEX ETL run id.",
)
@click.option(
    "--run-record-offset",
    # not required: a default is supplied, so the option may always be omitted
    required=False,
    type=int,
    default=0,
    help="TIMDEX ETL run record offset to start from, default = 0.",
)
@click.option(
    "--record-limit",
    # not required: with default=None, required=True would make click reject any
    # invocation that omits the option, contradicting the documented "unlimited"
    required=False,
    type=int,
    default=None,
    help="Limit number of records after --run-record-offset, default = None (unlimited).",
)
@click.option(
    "--strategy",
    type=str,  # WIP: establish an enum of supported strategies
    required=True,
    multiple=True,
    help="Pre-embedding record transformation strategy to use. Repeatable.",
)
@click.option(
    "--output-jsonl",
    required=False,
    type=str,
    default=None,
    help="Optionally write embeddings to local JSONLines file (primarily for testing).",
)
def create_embeddings(
    ctx: click.Context,
    dataset_location: str,
    run_id: str,
    run_record_offset: int,
    record_limit: int | None,
    strategy: tuple[str, ...],
    output_jsonl: str | None,
) -> None:
    """Create embeddings for TIMDEX records.

    Reads records for a single ETL run from a TIMDEX dataset, builds one embedding
    input per (record, strategy) pair, creates an embedding for each input, and
    writes the embeddings out.

    NOTE: the input-text and embedding creation steps are currently WIP placeholders
    (fenced by DEBUG markers below) that emit fixed, simulated embedding values.

    Args:
        ctx: click context; the embedding model is read from ctx.obj["model"].
        dataset_location: location of the TIMDEX dataset to read records from.
        run_id: TIMDEX ETL run id used to select records.
        run_record_offset: only records with run_record_offset >= this value are read.
        record_limit: maximum number of records to read, or None for no limit.
        strategy: one or more pre-embedding transformation strategy names; click
            supplies these as a tuple because the option is declared multiple=True.
        output_jsonl: path of a local JSONLines file to write embeddings to, or None
            to write back to the TIMDEX dataset (not yet implemented).

    Raises:
        NotImplementedError: if output_jsonl is not provided, because writing
            embeddings back to the TIMDEX dataset is not implemented yet.
    """
    # imported outside the DEBUG blocks because the JSONLines writer below needs it
    # even after the simulated steps are replaced
    import json  # noqa: PLC0415

    model: BaseEmbeddingModel = ctx.obj["model"]

    # init TIMDEXDataset
    timdex_dataset = TIMDEXDataset(dataset_location)

    # query TIMDEX dataset for an iterator of records
    timdex_records = timdex_dataset.read_dicts_iter(
        columns=[
            "timdex_record_id",
            "run_id",
            "run_record_offset",
            "transformed_record",
        ],
        run_id=run_id,
        # run_record_offset is coerced to int by click, so interpolation is safe here
        where=f"run_record_offset >= {run_record_offset}",
        limit=record_limit,
        action="index",
    )

    # create an iterator of InputTexts applying all requested strategies to all records
    # WIP NOTE: this will leverage some kind of pre-embedding transformer class(es) that
    #   create texts based on the requested strategies (e.g. "full record"), which are
    #   captured in --strategy CLI args
    # WIP NOTE: the following simulates that...
    # DEBUG ------------------------------------------------------------------------------
    from embeddings.embedding import EmbeddingInput  # noqa: PLC0415

    # lazy generator: one EmbeddingInput per (record, strategy) combination
    input_records = (
        EmbeddingInput(
            timdex_record_id=timdex_record["timdex_record_id"],
            run_id=timdex_record["run_id"],
            run_record_offset=timdex_record["run_record_offset"],
            embedding_strategy=_strategy,
            # NOTE(review): json.dumps() of an already-decoded string yields a quoted,
            # escaped JSON string literal; kept as-is for the simulation -- confirm
            # the intended text when the real transformer lands
            text=json.dumps(timdex_record["transformed_record"].decode()),
        )
        for timdex_record in timdex_records
        for _strategy in strategy
    )
    # DEBUG ------------------------------------------------------------------------------

    # create an iterator of Embeddings via the embedding model
    # WIP NOTE: this will use the embedding class .create_embeddings() bulk method
    # WIP NOTE: the following simulates that...
    # DEBUG ------------------------------------------------------------------------------
    from embeddings.embedding import Embedding  # noqa: PLC0415

    # lazy generator: nothing is read from the dataset until this is iterated
    embeddings = (
        Embedding(
            timdex_record_id=input_record.timdex_record_id,
            run_id=input_record.run_id,
            run_record_offset=input_record.run_record_offset,
            embedding_strategy=input_record.embedding_strategy,
            model_uri=model.model_uri,
            embedding_vector=[0.1, 0.2, 0.3],
            embedding_token_weights={"coffee": 0.9, "seattle": 0.5},
        )
        for input_record in input_records
    )
    # DEBUG ------------------------------------------------------------------------------

    # if requested, write embeddings to a local JSONLines file
    if output_jsonl:
        with jsonlines.open(
            output_jsonl,
            mode="w",
            # default=str stringifies any value json cannot serialize natively
            dumps=lambda obj: json.dumps(obj, default=str),
        ) as writer:
            for embedding in embeddings:
                writer.write(embedding.to_dict())

    # else, default writing embeddings back to TIMDEX dataset
    else:
        # WIP NOTE: write via anticipated timdex_dataset.embeddings.write(...)
        # NOTE: will likely use an imported TIMDEXEmbedding class from TDA, which the
        #   Embedding instance will nearly 1:1 map to.
        raise NotImplementedError(
            "Writing embeddings to the TIMDEX dataset is not yet implemented; "
            "use --output-jsonl."
        )

    logger.info("Embeddings creation complete.")
155 | 289 |
|
156 | 290 |
|
157 | 291 | if __name__ == "__main__": # pragma: no cover |
|
0 commit comments