Commit 39ef93e

Lean into model_required CLI decorator
Why these changes are being introduced:

Many of the CLI commands will require an embedding class and model to work. A decorator was originally created that injected a --model-uri CLI argument, but it also provides a place to load the class itself and act more like middleware.

How this addresses that need:

Updates the model_required decorator to also load the embedding model class. This DRYs up the CLI commands that use it and centralizes the logic and conventions for the CLI argument and env vars. Lastly, a 'model_path' is now required when instantiating a model class instance, and this location is used for both download and load.

Side effects of this change:
* None

Relevant ticket(s):
* https://mitlibraries.atlassian.net/browse/USE-112
1 parent 02264cb commit 39ef93e

File tree

9 files changed: +310 −129 lines changed


Dockerfile

Lines changed: 4 additions & 4 deletions
@@ -19,12 +19,12 @@ COPY embeddings ./embeddings
 RUN uv pip install --system .
 
 # Download the model and include in the Docker image
-# NOTE: The env vars "TE_MODEL_URI" and "TE_MODEL_DOWNLOAD_PATH" are set here to support
-# the downloading of the model into this image build, but persist in the container and
-# effectively also set this as the default model.
+# NOTE: The env vars "TE_MODEL_URI" and "TE_MODEL_PATH" are set here to support
+# the downloading of the model during image build, but also persist in the container and
+# effectively set the default model.
 ENV HF_HUB_DISABLE_PROGRESS_BARS=true
 ENV TE_MODEL_URI=opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte
-ENV TE_MODEL_DOWNLOAD_PATH=/model
+ENV TE_MODEL_PATH=/model
 RUN python -m embeddings.cli --verbose download-model
 
 ENTRYPOINT ["python", "-m", "embeddings.cli"]

README.md

Lines changed: 29 additions & 9 deletions
@@ -24,7 +24,7 @@ WORKSPACE=### Set to `dev` for local development, this will be set to `stage` an
 
 ```shell
 TE_MODEL_URI=# HuggingFace model URI
-TE_MODEL_DOWNLOAD_PATH=# Download location for model
+TE_MODEL_PATH=# Path where the model will be downloaded to and loaded from
 HF_HUB_DISABLE_PROGRESS_BARS=#boolean to use progress bars for HuggingFace model downloads; defaults to 'true' in deployed contexts
 ```
 
@@ -34,7 +34,7 @@ This CLI application is designed to create embeddings for input texts. To do th
 
 To this end, there is a base embedding class `BaseEmbeddingModel` that is designed to be extended and customized for a particular embedding model.
 
-Once an embedding class has been created, the preferred approach is to set env vars `TE_MODEL_URI` and `TE_MODEL_DOWNLOAD_PATH` directly in the `Dockerfile` to a) download a local snapshot of the model during image build, and b) set this model as the default for the CLI.
+Once an embedding class has been created, the preferred approach is to set env vars `TE_MODEL_URI` and `TE_MODEL_PATH` directly in the `Dockerfile` to a) download a local snapshot of the model during image build, and b) set this model as the default for the CLI.
 
 This allows invoking the CLI without specifying a model URI or local location, allowing this model to serve as the default, e.g.:
 
@@ -61,18 +61,38 @@ Usage: embeddings ping [OPTIONS]
 ```text
 Usage: embeddings download-model [OPTIONS]
 
-  Download a model from HuggingFace and save as zip file.
+  Download a model from HuggingFace and save locally.
 
 Options:
-  --model-uri TEXT  HuggingFace model URI (e.g., 'org/model-name')  [required]
-  --output PATH     Output path for zipped model (e.g., '/path/to/model.zip')
-                    [required]
-  --help            Show this message and exit.
+  --model-uri TEXT   HuggingFace model URI (e.g., 'org/model-name')
+                     [required]
+  --model-path PATH  Path where the model will be downloaded to and loaded
+                     from, e.g. '/path/to/model'.  [required]
+  --help             Show this message and exit.
 ```
 
-### `create-embeddings`
+### `test-model-load`
 ```text
-TODO...
+Usage: embeddings test-model-load [OPTIONS]
+
+  Test loading of embedding class and local model based on env vars.
+
+  In a deployed context, the following env vars are expected: -
+  TE_MODEL_URI - TE_MODEL_PATH
+
+  With these set, the embedding class should be registered successfully and
+  initialized, and the model loaded from a local copy.
+
+  This CLI command is NOT used during normal workflows. This is used primarily
+  during development and after model downloading/loading changes to ensure the
+  model loads correctly.
+
+Options:
+  --model-uri TEXT   HuggingFace model URI (e.g., 'org/model-name')
+                     [required]
+  --model-path PATH  Path where the model will be downloaded to and loaded
+                     from, e.g. '/path/to/model'.  [required]
+  --help             Show this message and exit.
 ```
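One way to sanity-check this documented help output is via Click's test runner; a sketch, assuming the Click group is importable as `main` from `embeddings.cli` (as the CLI diff below suggests):

```python
# Sketch: reproduce the `download-model --help` text with Click's CliRunner.
from click.testing import CliRunner

from embeddings.cli import main

runner = CliRunner()
result = runner.invoke(main, ["download-model", "--help"])
print(result.output)  # should match the Options block documented above
```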

embeddings/cli.py

Lines changed: 78 additions & 41 deletions
@@ -1,10 +1,10 @@
 import functools
 import logging
-import os
 import time
 from collections.abc import Callable
 from datetime import timedelta
 from pathlib import Path
+from typing import TYPE_CHECKING
 
 import click
 
@@ -13,21 +13,8 @@
 
 logger = logging.getLogger(__name__)
 
-
-def model_required(f: Callable) -> Callable:
-    """Decorator for commands that require a specific model."""
-
-    @click.option(
-        "--model-uri",
-        envvar="TE_MODEL_URI",
-        required=True,
-        help="HuggingFace model URI (e.g., 'org/model-name')",
-    )
-    @functools.wraps(f)
-    def wrapper(*args: list, **kwargs: dict) -> Callable:
-        return f(*args, **kwargs)
-
-    return wrapper
+if TYPE_CHECKING:
+    from embeddings.models.base import BaseEmbeddingModel
 
 
 @click.group("embeddings")
@@ -60,6 +47,60 @@ def _log_command_elapsed_time() -> None:
     ctx.call_on_close(_log_command_elapsed_time)
 
 
+def model_required(f: Callable) -> Callable:
+    """Middleware decorator for commands that require an embedding model.
+
+    This decorator adds two CLI options:
+    - "--model-uri": defaults to environment variable "TE_MODEL_URI"
+    - "--model-path": defaults to environment variable "TE_MODEL_PATH"
+
+    The decorator intercepts these parameters, uses the model URI to identify and
+    instantiate the appropriate embedding model class with the provided model path,
+    and stores the model instance in the Click context at ctx.obj["model"].
+
+    Both model_uri and model_path parameters are consumed by the decorator and not
+    passed to the decorated command function.
+    """
+
+    @click.option(
+        "--model-uri",
+        envvar="TE_MODEL_URI",
+        required=True,
+        help="HuggingFace model URI (e.g., 'org/model-name')",
+    )
+    @click.option(
+        "--model-path",
+        required=True,
+        envvar="TE_MODEL_PATH",
+        type=click.Path(path_type=Path),
+        help=(
+            "Path where the model will be downloaded to and loaded from, "
+            "e.g. '/path/to/model'."
+        ),
+    )
+    @functools.wraps(f)
+    def wrapper(*args: tuple, **kwargs: dict[str, str | Path]) -> Callable:
+        # pop "model_uri" and "model_path" from CLI args
+        model_uri: str = str(kwargs.pop("model_uri"))
+        model_path: str | Path = str(kwargs.pop("model_path"))
+
+        # initialize embedding class
+        model_class = get_model_class(str(model_uri))
+        model: BaseEmbeddingModel = model_class(model_path)
+        logger.info(
+            f"Embedding class '{model.__class__.__name__}' "
+            f"initialized from model URI '{model_uri}'."
+        )
+
+        # save embedding class instance to Context
+        ctx: click.Context = args[0]  # type: ignore[assignment]
+        ctx.obj["model"] = model
+
+        return f(*args, **kwargs)
+
+    return wrapper
+
+
 @main.command()
 def ping() -> None:
     """Emit 'pong' to debug logs and stdout."""
@@ -68,53 +109,49 @@ def ping() -> None:
 
 
 @main.command()
+@click.pass_context
 @model_required
-@click.option(
-    "--output",
-    required=True,
-    envvar="TE_MODEL_DOWNLOAD_PATH",
-    type=click.Path(path_type=Path),
-    help="Output path for zipped model (e.g., '/path/to/model.zip')",
-)
-def download_model(model_uri: str, output: Path) -> None:
-    """Download a model from HuggingFace and save as zip file."""
-    # load embedding model class
-    model_class = get_model_class(model_uri)
-    model = model_class()
+def download_model(
+    ctx: click.Context,
+) -> None:
+    """Download a model from HuggingFace and save locally."""
+    model: BaseEmbeddingModel = ctx.obj["model"]
 
-    # download model assets
-    logger.info(f"Downloading model: {model_uri}")
-    result_path = model.download(output)
+    logger.info(f"Downloading model: {model.model_uri}")
+    result_path = model.download()
 
     message = f"Model downloaded and saved to: {result_path}"
     logger.info(message)
     click.echo(result_path)
 
 
 @main.command()
-def test_model_load() -> None:
+@click.pass_context
+@model_required
+def test_model_load(ctx: click.Context) -> None:
     """Test loading of embedding class and local model based on env vars.
 
     In a deployed context, the following env vars are expected:
     - TE_MODEL_URI
-    - TE_MODEL_DOWNLOAD_PATH
+    - TE_MODEL_PATH
 
     With these set, the embedding class should be registered successfully and initialized,
     and the model loaded from a local copy.
-    """
-    # load embedding model class
-    model_class = get_model_class(os.environ["TE_MODEL_URI"])
-    model = model_class()
 
-    model.load(os.environ["TE_MODEL_DOWNLOAD_PATH"])
+    This CLI command is NOT used during normal workflows. This is used primarily
+    during development and after model downloading/loading changes to ensure the model
+    loads correctly.
+    """
+    model: BaseEmbeddingModel = ctx.obj["model"]
+    model.load()
     click.echo("OK")
 
 
 @main.command()
+@click.pass_context
 @model_required
-def create_embeddings(_model_uri: str) -> None:
-    # TODO: docstring  # noqa: FIX002
-    raise NotImplementedError
+def create_embedding(ctx: click.Context) -> None:
    """Create a single embedding for a single input text."""
 
 
 if __name__ == "__main__":  # pragma: no cover
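With the decorator acting as middleware, any future model-dependent command reduces to the same shape. A sketch of the intended pattern, using a hypothetical `embed-text` command that is not in this commit:

```python
# Hypothetical future command showing the decorator contract: @click.pass_context
# sits above @model_required so the Context arrives as args[0] inside the wrapper.
@main.command()
@click.pass_context
@model_required
def embed_text(ctx: click.Context) -> None:
    """Load the default model and report where it came from."""
    model: BaseEmbeddingModel = ctx.obj["model"]  # placed here by model_required
    model.load()  # loads from model.model_path
    click.echo(f"Loaded '{model.model_uri}' from '{model.model_path}'")
```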

embeddings/models/base.py

Lines changed: 14 additions & 10 deletions
@@ -12,6 +12,14 @@ class BaseEmbeddingModel(ABC):
 
     MODEL_URI: str  # Type hint to document the requirement
 
+    def __init__(self, model_path: str | Path) -> None:
+        """Initialize the embedding model with a model path.
+
+        Args:
+            model_path: Path where the model will be downloaded to and loaded from.
+        """
+        self.model_path = Path(model_path)
+
     def __init_subclass__(cls, **kwargs: dict) -> None:  # noqa: D105
         super().__init_subclass__(**kwargs)
 
@@ -28,17 +36,13 @@ def model_uri(self) -> str:
         return self.MODEL_URI
 
     @abstractmethod
-    def download(self, output_path: str | Path) -> Path:
-        """Download and prepare model, saving to output_path.
+    def download(self) -> Path:
+        """Download and prepare model, saving to self.model_path.
 
-        Args:
-            output_path: Path where the model zip should be saved.
+        Returns:
+            Path where the model was saved.
         """
 
     @abstractmethod
-    def load(self, model_path: str | Path) -> None:
-        """Load model from local, downloaded instance.
-
-        Args:
-            model_path: Path of local model directory.
-        """
+    def load(self) -> None:
+        """Load model from self.model_path."""
embeddings/models/os_neural_sparse_doc_v3_gte.py

Lines changed: 28 additions & 28 deletions
@@ -30,24 +30,27 @@ class OSNeuralSparseDocV3GTE(BaseEmbeddingModel):
 
     MODEL_URI = "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"
 
-    def __init__(self) -> None:
-        """Initialize the model."""
-        super().__init__()
+    def __init__(self, model_path: str | Path) -> None:
+        """Initialize the model.
+
+        Args:
+            model_path: Path where the model will be downloaded to and loaded from.
+        """
+        super().__init__(model_path)
         self._model: PreTrainedModel | None = None
         self._tokenizer: DistilBertTokenizerFast | None = None
         self._special_token_ids: list | None = None
         self._id_to_token: list | None = None
 
-    def download(self, output_path: str | Path) -> Path:
-        """Download and prepare model, saving to output_path.
+    def download(self) -> Path:
+        """Download and prepare model, saving to self.model_path.
 
-        Args:
-            output_path: Path where the model should be saved.
+        Returns:
+            Path where the model was saved.
         """
         start_time = time.perf_counter()
 
-        output_path = Path(output_path)
-        logger.info(f"Downloading model: {self.model_uri}, saving to: {output_path}.")
+        logger.info(f"Downloading model: {self.model_uri}, saving to: {self.model_path}.")
 
         with tempfile.TemporaryDirectory() as temp_dir:
             temp_path = Path(temp_dir)
@@ -60,19 +63,21 @@ def download(self, output_path: str | Path) -> Path:
             self._patch_local_model_with_alibaba_new_impl(temp_path)
 
             # compress model directory as a zip file
-            if output_path.suffix.lower() == ".zip":
+            if self.model_path.suffix.lower() == ".zip":
                 logger.debug("Creating zip file of model contents.")
-                shutil.make_archive(str(output_path.with_suffix("")), "zip", temp_path)
+                shutil.make_archive(
+                    str(self.model_path.with_suffix("")), "zip", temp_path
+                )
 
             # copy to output directory without zipping
             else:
-                logger.debug(f"Copying model contents to {output_path}")
-                if output_path.exists():
-                    shutil.rmtree(output_path)
-                shutil.copytree(temp_path, output_path)
+                logger.debug(f"Copying model contents to {self.model_path}")
+                if self.model_path.exists():
+                    shutil.rmtree(self.model_path)
+                shutil.copytree(temp_path, self.model_path)
 
         logger.info(f"Model downloaded successfully, {time.perf_counter() - start_time}s")
-        return output_path
+        return self.model_path
 
     def _patch_local_model_with_alibaba_new_impl(self, model_temp_path: Path) -> None:
         """Patch downloaded model with required assets from Alibaba-NLP/new-impl.
@@ -124,28 +129,23 @@ def _patch_local_model_with_alibaba_new_impl(self, model_temp_path: Path) -> Non
 
         logger.debug("Dependency model Alibaba-NLP/new-impl downloaded and used.")
 
-    def load(self, model_path: str | Path) -> None:
-        """Load the model from the specified path.
-
-        Args:
-            model_path: Path to the model directory.
-        """
+    def load(self) -> None:
+        """Load the model from self.model_path."""
         start_time = time.perf_counter()
-        logger.info(f"Loading model from: {model_path}")
-        model_path = Path(model_path)
+        logger.info(f"Loading model from: {self.model_path}")
 
         # ensure model exists locally
-        if not model_path.exists():
-            raise FileNotFoundError(f"Model not found at path: {model_path}")
+        if not self.model_path.exists():
+            raise FileNotFoundError(f"Model not found at path: {self.model_path}")
 
         # load local model and tokenizer
         self._model = AutoModelForMaskedLM.from_pretrained(
-            model_path,
+            self.model_path,
             trust_remote_code=True,
             local_files_only=True,
         )
         self._tokenizer = AutoTokenizer.from_pretrained(  # type: ignore[no-untyped-call]
-            model_path,
+            self.model_path,
             local_files_only=True,
         )
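End to end, the updated API keeps one path for the whole model lifecycle. A usage sketch (the `/model` path mirrors the Dockerfile default; `download()` requires network access):

```python
# Sketch: the path is fixed at construction and reused by download() and load().
from embeddings.models.os_neural_sparse_doc_v3_gte import OSNeuralSparseDocV3GTE

model = OSNeuralSparseDocV3GTE("/model")
model.download()  # snapshot from HuggingFace into /model, patched with new-impl assets
model.load()      # later runs load strictly from the local copy (local_files_only=True)
```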
