Add hub integration #58
**Merged**: Pringled merged 16 commits from `davidberenstein1957:add-hub-integration` into `MinishLab:main` on Feb 28, 2025.
All 16 commits were authored by davidberenstein1957:

- `afc3fb4` Refactor README and Vicinity class to support any serializable item type
- `9ffb491` Update README.md to include examples for saving/loading vector stores…
- `7b2bb53` Refactor Vicinity class to streamline token handling
- `a5ce987` Refactor item handling in tests and Vicinity class
- `022c7b1` Apply suggestions from code review
- `eaabbfa` Refactor token insertion in Vicinity class to simplify duplicate hand…
- `031c136` Refactor token deletion logic in Vicinity class to improve error hand…
- `26e7ed6` Enhance error handling in Vicinity class for JSON serialization
- `6fb6305` Add non-serializable items fixture and test for Vicinity class
- `c86f7e5` Add Hugging Face integration for Vicinity class
- `a410686` Merge branch 'MinishLab:main' into add-hub-integration
- `4f30d45` Enhance Hugging Face integration with improved error handling and dat…
- `cab15e5` Update pyproject.toml and README.md for improved package installation…
- `65465f3` Add test for Vicinity.load_from_hub method
- `06545dd` Remove test files for utils and vicinity modules
- `cc3fbf4` Add comprehensive test suites for Vicinity and utility functions
New file: a test for `Vicinity.load_from_hub` (33 added lines):

```python
from __future__ import annotations

import io
import sys

from vicinity import Vicinity
from vicinity.datatypes import Backend
from vicinity.integrations.huggingface import _MODEL_NAME_OR_PATH_PRINT_STATEMENT

BackendType = tuple[Backend, str]


def test_load_from_hub(vicinity_instance: Vicinity) -> None:
    """
    Test Vicinity.load_from_hub.

    :param vicinity_instance: A Vicinity instance.
    """
    repo_id = "davidberenstein1957/my-vicinity-repo"
    # get the first part of the print statement to test if model name or path is printed
    expected_print_statement = _MODEL_NAME_OR_PATH_PRINT_STATEMENT.split(":")[0]

    # Capture the output
    captured_output = io.StringIO()
    sys.stdout = captured_output

    Vicinity.load_from_hub(repo_id=repo_id)

    # Reset redirect.
    sys.stdout = sys.__stdout__

    # Check if the expected message is in the output
    assert expected_print_statement in captured_output.getvalue()
```
New file: `dataset_card_template.md` (30 added lines). The `{repo_id}`, `{num_items}`, and `{config}` placeholders are filled in by `push_to_hub`:

````markdown
---
tags:
- vicinity
- vector-store
---

# Dataset Card for {repo_id}

This dataset was created using the [vicinity](https://github.com/MinishLab/vicinity) library, a lightweight nearest neighbors library with flexible backends.

It contains a vector space with {num_items} items.

## Usage

You can load this dataset using the following code:

```python
from vicinity import Vicinity
vicinity = Vicinity.load_from_hub("{repo_id}")
```

After loading the dataset, you can use the `vicinity.query` method to find the nearest neighbors to a vector.

## Configuration

The configuration of the dataset is stored in the `config.json` file. The vector backend is stored in the `backend` folder.

```bash
{config}
```
````
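The card is rendered with Python's `str.format`. A minimal sketch of that step, using a trimmed stand-in template and illustrative values rather than the real file:

```python
import json

# Trimmed stand-in for dataset_card_template.md; push_to_hub reads the real file from disk.
template = "# Dataset Card for {repo_id}\n\nIt contains a vector space with {num_items} items.\n\nConfig:\n{config}\n"

config = {"backend_type": "basic", "model_name_or_path": "my-model"}  # illustrative values
content = template.format(
    repo_id="user/my-vicinity-repo",
    num_items=3,
    config=json.dumps(config, indent=4),
)

assert "Dataset Card for user/my-vicinity-repo" in content
assert "3 items" in content
```

Note that the braces emitted by `json.dumps` are safe here because formatting has already happened by the time they appear in `content`.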
New file: `vicinity/integrations/huggingface.py` (138 added lines):

```python
from __future__ import annotations

import json
import logging
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Any

from vicinity.backends import BasicVectorStore, get_backend_class
from vicinity.datatypes import Backend

if TYPE_CHECKING:
    from huggingface_hub import CommitInfo

    from vicinity.vicinity import Vicinity

_HUB_IMPORT_ERROR = ImportError(
    "`datasets` and `huggingface_hub` are required to push to the Hugging Face Hub. Please install them with `pip install 'vicinity[huggingface]'`"
)
_MODEL_NAME_OR_PATH_PRINT_STATEMENT = (
    "Embeddings in Vicinity instance were created from model name or path: {model_name_or_path}"
)

logger = logging.getLogger(__name__)


class HuggingFaceMixin:
    def push_to_hub(
        self,
        model_name_or_path: str,
        repo_id: str,
        token: str | None = None,
        private: bool = False,
        **kwargs: Any,
    ) -> "CommitInfo":
        """
        Push the Vicinity instance to the Hugging Face Hub.

        :param model_name_or_path: The name of the model or the path to the local directory
            that was used to create the embeddings in the Vicinity instance.
        :param repo_id: The repository ID on the Hugging Face Hub.
        :param token: Optional authentication token for private repositories.
        :param private: Whether to create a private repository.
        :param **kwargs: Additional arguments passed to Dataset.push_to_hub().
        :return: The commit info.
        """
        try:
            from datasets import Dataset
            from huggingface_hub import DatasetCard, upload_file, upload_folder
        except ImportError:
            raise _HUB_IMPORT_ERROR

        # Create and push dataset with items and vectors
        if isinstance(self.items[0], dict):
            dataset_dict = {k: [item[k] for item in self.items] for k in self.items[0].keys()}
        else:
            dataset_dict = {"items": self.items}
        if self.vector_store is not None:
            dataset_dict["vectors"] = self.vector_store.vectors
        dataset = Dataset.from_dict(dataset_dict)
        dataset.push_to_hub(repo_id, token=token, private=private, **kwargs)

        # Save backend and config files to temp directory and upload
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)

            # Save and upload backend
            self.backend.save(temp_path)
            upload_folder(
                repo_id=repo_id,
                folder_path=temp_path,
                token=token,
                repo_type="dataset",
                path_in_repo="backend",
            )

            # Save and upload config
            config = {
                "metadata": self.metadata,
                "backend_type": self.backend.backend_type.value,
                "model_name_or_path": model_name_or_path,
            }
            config_path = temp_path / "config.json"
            config_path.write_text(json.dumps(config))
            upload_file(
                repo_id=repo_id,
                path_or_fileobj=config_path,
                token=token,
                repo_type="dataset",
                path_in_repo="config.json",
            )

        # Load the dataset card template from the related path
        template_path = Path(__file__).parent / "dataset_card_template.md"
        template = template_path.read_text()
        content = template.format(repo_id=repo_id, num_items=len(self.items), config=json.dumps(config, indent=4))
        return DatasetCard(content=content).push_to_hub(repo_id=repo_id, token=token, repo_type="dataset")

    @classmethod
    def load_from_hub(cls, repo_id: str, token: str | None = None, **kwargs: Any) -> "Vicinity":
        """
        Load a Vicinity instance from the Hugging Face Hub.

        :param repo_id: The repository ID on the Hugging Face Hub.
        :param token: Optional authentication token for private repositories.
        :param **kwargs: Additional arguments passed to load_dataset.
        :return: A Vicinity instance loaded from the Hub.
        """
        try:
            from datasets import load_dataset
            from huggingface_hub import snapshot_download
        except ImportError:
            raise _HUB_IMPORT_ERROR

        # Load dataset and extract items and vectors
        dataset = load_dataset(repo_id, token=token, split="train", **kwargs)
        if "items" in dataset.column_names:
            items = dataset["items"]
        else:
            # Create items from all columns except 'vectors'
            items = []
            columns = [col for col in dataset.column_names if col != "vectors"]
            for i in range(len(dataset)):
                items.append({col: dataset[col][i] for col in columns})
        has_vectors = "vectors" in dataset.column_names
        vector_store = BasicVectorStore(vectors=dataset["vectors"]) if has_vectors else None

        # Download and load config and backend
        repo_path = Path(snapshot_download(repo_id=repo_id, token=token, repo_type="dataset"))
        with open(repo_path / "config.json") as f:
            config = json.load(f)
        model_name_or_path = config.pop("model_name_or_path")

        print(_MODEL_NAME_OR_PATH_PRINT_STATEMENT.format(model_name_or_path=model_name_or_path))
        backend_type = Backend(config["backend_type"])
        backend = get_backend_class(backend_type).load(repo_path / "backend")

        return cls(items=items, backend=backend, metadata=config["metadata"], vector_store=vector_store)
```
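The row/column pivots used above can be illustrated standalone, without touching the Hub: `push_to_hub` turns a list of dict items into the columnar dict that `Dataset.from_dict` expects, and `load_from_hub` reconstructs the rows from columns, skipping `vectors`. A self-contained sketch with illustrative items:

```python
# Illustrative items; real Vicinity items can be any serializable objects.
items = [{"text": "a", "label": 0}, {"text": "b", "label": 1}]

# push_to_hub: rows -> columns (the shape datasets.Dataset.from_dict expects)
dataset_dict = {k: [item[k] for item in items] for k in items[0].keys()}
assert dataset_dict == {"text": ["a", "b"], "label": [0, 1]}

# load_from_hub: columns -> rows, skipping the 'vectors' column if present
columns = [col for col in dataset_dict.keys() if col != "vectors"]
restored = [{col: dataset_dict[col][i] for col in columns} for i in range(len(items))]
assert restored == items
```

This round-trip is why items pushed as dicts come back as equivalent dicts, while non-dict items are stored under a single `items` column instead.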