Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 73 additions & 10 deletions embeddings/cli.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,91 @@
import logging
import time
from datetime import timedelta
from time import perf_counter
from pathlib import Path

import click

from embeddings.config import configure_logger, configure_sentry
from embeddings.models.registry import get_model_class

logger = logging.getLogger(__name__)


@click.group("embeddings")
@click.option(
    "-v",
    "--verbose",
    is_flag=True,
    help="Pass to log at debug level instead of info",
)
@click.pass_context
def main(
    ctx: click.Context,
    *,
    verbose: bool,
) -> None:
    # Group callback for the "embeddings" CLI. It configures the root logger
    # and Sentry, records a start time on the Click context, and registers a
    # close callback that logs the total elapsed time.
    #
    # Args:
    #     ctx: Click context; its ``obj`` dict holds the "start_time" value.
    #     verbose: forwarded to ``configure_logger`` as ``verbose=``.
    #
    # NOTE: documented with comments rather than a docstring because Click
    # would surface a docstring as the group's --help text.
    ctx.ensure_object(dict)
    ctx.obj["start_time"] = time.perf_counter()

    root_logger = logging.getLogger()
    logger.info(configure_logger(root_logger, verbose=verbose))
    logger.info(configure_sentry())
    logger.info("Running process")

    def _log_command_elapsed_time() -> None:
        # Reads the start time stored on the context above.
        elapsed_time = time.perf_counter() - ctx.obj["start_time"]
        logger.info(
            "Total time to complete process: %s", str(timedelta(seconds=elapsed_time))
        )

    # Deferred to context teardown so the timing covers the subcommand too.
    ctx.call_on_close(_log_command_elapsed_time)


@main.command()
def ping() -> None:
    """Emit 'pong' to debug logs and stdout."""
    reply = "pong"
    # The same text goes to the debug log and to stdout.
    logger.debug(reply)
    click.echo(reply)


@main.command()
@click.option(
    "--model-uri",
    required=True,
    help="HuggingFace model URI (e.g., 'org/model-name')",
)
@click.option(
    "--output",
    required=True,
    type=click.Path(path_type=Path),
    help="Output path for zipped model (e.g., '/path/to/model.zip')",
)
def download_model(model_uri: str, output: Path) -> None:
    """Download a model from HuggingFace and save as zip file."""
    # Look up the class registered for this URI and instantiate it with the
    # same URI.
    embedding_model = get_model_class(model_uri)(model_uri)

    # Announce the download, then delegate the actual work to the model.
    logger.info(f"Downloading model: {model_uri}")
    result_path = embedding_model.download(output)

    logger.info(f"Model downloaded and saved to: {result_path}")
    # Only the resulting path is written to stdout.
    click.echo(result_path)


@main.command()
@click.option(
    "--model-uri",
    # Explicit parameter name: without it Click would pass the value as the
    # keyword ``model_uri``, which this function does not accept, and the
    # command would fail with TypeError before reaching the body.
    "_model_uri",
    required=True,
    help="HuggingFace model URI (e.g., 'org/model-name')",
)
def create_embeddings(_model_uri: str) -> None:
    # Placeholder command: the body is not implemented yet and always raises
    # NotImplementedError. ``_model_uri`` is accepted but currently unused.
    # TODO: docstring # noqa: FIX002
    raise NotImplementedError


# Script entry point. Rebinds the module-level ``logger`` to the
# "embeddings.main" logger before invoking the CLI group.
if __name__ == "__main__":  # pragma: no cover
    logger = logging.getLogger("embeddings.main")
    main()
1 change: 1 addition & 0 deletions embeddings/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

23 changes: 23 additions & 0 deletions embeddings/models/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""Base class for embedding models."""

from abc import ABC, abstractmethod
from pathlib import Path


class BaseEmbeddingModel(ABC):
    """Common interface that every embedding model implementation must follow.

    Concrete subclasses have to provide ``download``; the base class only
    stores the model identifier.

    Args:
        model_uri: HuggingFace model identifier (e.g., 'org/model-name').
    """

    def __init__(self, model_uri: str) -> None:
        # Kept as a plain public attribute for subclasses and callers.
        self.model_uri: str = model_uri

    @abstractmethod
    def download(self, output_path: Path) -> Path:
        """Fetch and prepare the model, writing the result to ``output_path``.

        Args:
            output_path: Path where the model zip should be saved.

        Returns:
            A path, as declared by the return annotation; the exact value is
            up to the implementing subclass.
        """
24 changes: 24 additions & 0 deletions embeddings/models/os_neural_sparse_doc_v3_gte.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""OpenSearch Neural Sparse Doc v3 GTE model."""

import logging
from pathlib import Path

from embeddings.models.base import BaseEmbeddingModel

logger = logging.getLogger(__name__)


class OSNeuralSparseDocV3GTE(BaseEmbeddingModel):
    """Embedding model class for OpenSearch Neural Sparse Encoding Doc v3 GTE.

    HuggingFace URI: opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte
    """

    def download(self, output_path: Path) -> Path:
        """Log the requested download; the download itself is not implemented.

        Args:
            output_path: Path where the model zip should be saved.

        Raises:
            NotImplementedError: Always, after the log line is emitted.
        """
        # The log line is emitted before raising.
        logger.info(f"Downloading model: { self.model_uri}, saving to: {output_path}.")
        raise NotImplementedError
31 changes: 31 additions & 0 deletions embeddings/models/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""Registry mapping model URIs to model classes."""

import logging

from embeddings.models.base import BaseEmbeddingModel
from embeddings.models.os_neural_sparse_doc_v3_gte import OSNeuralSparseDocV3GTE

logger = logging.getLogger(__name__)

# Lookup table from model URI string to the BaseEmbeddingModel subclass that
# handles it.
MODEL_REGISTRY: dict[str, type[BaseEmbeddingModel]] = {
    "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte": (
        OSNeuralSparseDocV3GTE
    ),
}


def get_model_class(model_uri: str) -> type[BaseEmbeddingModel]:
    """Look up the model class registered for ``model_uri``.

    Args:
        model_uri: HuggingFace model identifier.

    Returns:
        The class stored in ``MODEL_REGISTRY`` under ``model_uri``.

    Raises:
        ValueError: If ``model_uri`` is not a key of ``MODEL_REGISTRY``; the
            message (also logged at error level) lists the registered URIs.
    """
    # Registry values are classes, so None can only mean "not registered".
    registered_class = MODEL_REGISTRY.get(model_uri)
    if registered_class is not None:
        return registered_class

    available = ", ".join(sorted(MODEL_REGISTRY))
    msg = f"Unknown model URI: {model_uri}. Available models: {available}"
    logger.error(msg)
    raise ValueError(msg)
12 changes: 10 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ requires-python = ">=3.12"

dependencies = [
"click>=8.2.1",
"huggingface-hub>=0.26.0",
"sentry-sdk>=2.34.1",
"timdex-dataset-api",
]

[dependency-groups]
Expand Down Expand Up @@ -55,11 +57,14 @@ ignore = [
"D101",
"D102",
"D103",
"D104",
"G004",
"PLR0912",
"PLR0913",
"PLR0915",
"S321",
"TD002",
"TD003",
]

# allow autofix behavior for specified rules
Expand All @@ -84,9 +89,12 @@ max-doc-length = 90
[tool.ruff.lint.pydocstyle]
convention = "google"

[tool.uv.sources]
timdex-dataset-api = { git = "https://github.com/MITLibraries/timdex-dataset-api" }

[project.scripts]
embeddings = "embeddings.cli:main"

[build-system]
requires = ["setuptools>=61"]
build-backend = "setuptools.build_meta"
24 changes: 24 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import zipfile
from pathlib import Path

import pytest
from click.testing import CliRunner

from embeddings.models.base import BaseEmbeddingModel


@pytest.fixture(autouse=True)
def _test_env(monkeypatch):
Expand All @@ -11,3 +16,22 @@ def _test_env(monkeypatch):
@pytest.fixture
def runner():
    """Provide a new Click ``CliRunner`` instance to each test that requests it."""
    cli_runner = CliRunner()
    return cli_runner


class MockEmbeddingModel(BaseEmbeddingModel):
    """Test double whose download writes a local zip and makes no network calls."""

    def download(self, output_path: Path) -> Path:
        """Write a small fake model archive to ``output_path`` and return that path."""
        # (member name, content) pairs, written in this order.
        fake_members = (
            ("config.json", '{"model": "mock", "vocab_size": 30000}'),
            ("pytorch_model.bin", b"fake model weights"),
            ("tokenizer.json", '{"version": "1.0"}'),
        )

        # Make sure the destination directory exists before opening the archive.
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(output_path, "w") as archive:
            for member_name, content in fake_members:
                archive.writestr(member_name, content)
        return output_path


@pytest.fixture
def mock_model():
    """Build a MockEmbeddingModel with the URI 'test/mock-model'."""
    model = MockEmbeddingModel("test/mock-model")
    return model
42 changes: 34 additions & 8 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,43 @@
from embeddings.cli import main


def test_cli_default_logging(caplog, runner):
    """Running ``ping`` without --verbose logs the INFO-level setup messages."""
    result = runner.invoke(main, ["ping"])
    assert result.exit_code == 0
    assert "Logger 'root' configured with level=INFO" in caplog.text
    assert "Running process" in caplog.text
    # Emitted by the elapsed-time callback registered in the group callback.
    assert "Total time to complete process" in caplog.text


def test_cli_debug_logging(caplog, runner):
    """Running ``--verbose ping`` logs at DEBUG and emits 'pong' to log and stdout."""
    # Capture at DEBUG so the debug-level "pong" record reaches caplog.
    with caplog.at_level("DEBUG"):
        result = runner.invoke(main, ["--verbose", "ping"])
    assert result.exit_code == 0
    assert "Logger 'root' configured with level=DEBUG" in caplog.text
    assert "Running process" in caplog.text
    assert "Total time to complete process" in caplog.text
    assert "pong" in caplog.text
    assert "pong" in result.output


def test_download_model_unknown_uri(caplog, runner):
    """An unregistered model URI makes ``download-model`` fail and logs the reason."""
    cli_args = [
        "download-model",
        "--model-uri",
        "unknown/model",
        "--output",
        "out.zip",
    ]
    result = runner.invoke(main, cli_args)
    assert result.exit_code != 0
    assert "Unknown model URI" in caplog.text


def test_download_model_not_implemented(caplog, runner):
    """The registered model logs its download attempt, then the command fails."""
    model_uri = "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"
    output_name = "out.zip"
    expected_log = f"Downloading model: {model_uri}, saving to: {output_name}."

    caplog.set_level("INFO")
    cli_args = ["download-model", "--model-uri", model_uri, "--output", output_name]
    result = runner.invoke(main, cli_args)

    # Log assertion first, then the exit code, as in the original ordering.
    assert expected_log in caplog.text
    assert result.exit_code != 0
48 changes: 48 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import zipfile

import pytest

from embeddings.models.registry import MODEL_REGISTRY, get_model_class


def test_mock_model_instantiation(mock_model):
    """The fixture model exposes the URI it was constructed with."""
    expected_uri = "test/mock-model"
    assert mock_model.model_uri == expected_uri


def test_mock_model_download_creates_zip(mock_model, tmp_path):
    """``download`` returns the target path and leaves a valid zip file there."""
    target = tmp_path / "test_model.zip"
    returned = mock_model.download(target)

    assert returned == target
    assert target.exists()
    assert zipfile.is_zipfile(target)


def test_mock_model_download_contains_expected_files(mock_model, tmp_path):
    """The archive written by the mock model contains the three expected members."""
    target = tmp_path / "test_model.zip"
    mock_model.download(target)

    with zipfile.ZipFile(target, "r") as archive:
        archived_names = archive.namelist()

    # Checked in the same order as the original assertions.
    for expected_name in ("config.json", "pytorch_model.bin", "tokenizer.json"):
        assert expected_name in archived_names


def test_registry_contains_opensearch_model():
    """The OpenSearch neural sparse model URI is a key of the registry."""
    opensearch_uri = "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"
    assert opensearch_uri in MODEL_REGISTRY


def test_get_model_class_returns_correct_class():
    """Looking up the OpenSearch URI yields the class named OSNeuralSparseDocV3GTE."""
    opensearch_uri = "opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte"
    resolved = get_model_class(opensearch_uri)
    assert resolved.__name__ == "OSNeuralSparseDocV3GTE"


def test_get_model_class_raises_for_unknown_uri():
    """An unregistered URI raises ValueError mentioning 'Unknown model URI'."""
    unregistered_uri = "unknown/model-uri"
    with pytest.raises(ValueError, match="Unknown model URI"):
        get_model_class(unregistered_uri)
Loading