feat: embed_image for liteLLM #458

Closed · wants to merge 1 commit
1 change: 1 addition & 0 deletions packages/ragbits-core/CHANGELOG.md
@@ -1,6 +1,7 @@
# CHANGELOG

## Unreleased
- Add image embeddings via LiteLLM for vertex_ai models

## 0.12.0 (2025-03-25)
- Allow Prompt class to accept the asynchronous response_parser. Change the signature of parse_response method.
76 changes: 76 additions & 0 deletions packages/ragbits-core/src/ragbits/core/embeddings/litellm.py
@@ -1,4 +1,9 @@
import asyncio
import base64

import litellm
from litellm.main import VertexMultimodalEmbedding
from litellm.types.llms.vertex_ai import Instance, InstanceImage

from ragbits.core.audit import trace
from ragbits.core.embeddings import Embedder
@@ -110,3 +115,74 @@ async def embed_text(self, data: list[str], options: LiteLLMEmbedderOptions | No
outputs.total_tokens = response.usage.total_tokens

return outputs.embeddings

def image_support(self) -> bool: # noqa: PLR6301
"""
Check if the model supports image embeddings.

Returns:
True if the model supports image embeddings, False otherwise.
"""
# Strip the provider prefix and normalize the model name before
# checking it against LiteLLM's list of supported multimodal models.
model_name = self.model.replace("vertex_ai/", "").lower()

supported_models = VertexMultimodalEmbedding().SUPPORTED_MULTIMODAL_EMBEDDING_MODELS

return model_name in supported_models
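As a usage sketch, this check lets callers fail fast before building image payloads. It assumes the class this diff extends is `LiteLLMEmbedder`, and the model name is illustrative:

```python
# Hypothetical usage; "vertex_ai/multimodalembedding@001" is an illustrative model name.
embedder = LiteLLMEmbedder(model="vertex_ai/multimodalembedding@001")
if not embedder.image_support():
    raise ValueError(f"{embedder.model} does not support image embeddings")
```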

async def process_image(self, image_instance: Instance, options: LiteLLMEmbedderOptions) -> list[float]:
"""
Embeds a single image from the given instance.

Args:
image_instance: Instance of the image to embed.
options: Additional options to pass to the LiteLLM API.

Returns:
A list of floats representing the embedded image.
"""
response = await litellm.aembedding(
model=self.model,
input=image_instance,
api_base=self.api_base,
api_key=self.api_key,
api_version=self.api_version,
**options.dict(),
)
return response.data[0].embedding

async def embed_image(self, images: list[bytes], options: LiteLLMEmbedderOptions | None = None) -> list[list[float]]:
"""
Embeds a list of images into a list of vectors.

Args:
images: A list of input image bytes to embed.
options: Additional settings used by the Embedder model.

Returns:
A list of embedding vectors, one for each input image.
"""
merged_options = (self.default_options | options) if options else self.default_options
with trace(
model=self.model,
api_base=self.api_base,
api_version=self.api_version,
options=merged_options.dict(),
) as outputs:
base64_images = [base64.b64encode(img).decode("utf-8") for img in images]
instances = [Instance(image=InstanceImage(bytesBase64Encoded=base64_img)) for base64_img in base64_images]
try:
embeddings = await asyncio.gather(
*[self.process_image(instance, merged_options) for instance in instances]
)
except litellm.openai.APIConnectionError as exc:
raise EmbeddingConnectionError() from exc
except litellm.openai.APIStatusError as exc:
raise EmbeddingStatusError(exc.message, exc.status_code) from exc
except litellm.openai.APIResponseValidationError as exc:
raise EmbeddingResponseError() from exc

outputs.embeddings = embeddings
return embeddings
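For reference, a minimal end-to-end sketch of the new `embed_image` API. The import path follows the file location above; the model name and `photo.jpg` input are illustrative, and the 1408-dimension figure is what Vertex AI's `multimodalembedding@001` typically returns:

```python
import asyncio

from ragbits.core.embeddings.litellm import LiteLLMEmbedder  # path per this diff


async def main() -> None:
    embedder = LiteLLMEmbedder(model="vertex_ai/multimodalembedding@001")
    with open("photo.jpg", "rb") as f:  # illustrative input image
        image_bytes = f.read()
    vectors = await embedder.embed_image([image_bytes])
    print(len(vectors), len(vectors[0]))  # e.g. "1 1408"


asyncio.run(main())
```

Note that `embed_image` fires one request per image via `asyncio.gather` with no concurrency cap, unlike the semaphore-bounded `_embed` in the multimodal embedder below.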
@@ -33,7 +33,7 @@ def __init__(
model: str = "multimodalembedding",
api_base: str | None = None,
api_key: str | None = None,
concurency: int = 10,
concurrency: int = 10,
default_options: LiteLLMEmbedderOptions | None = None,
) -> None:
"""
@@ -43,7 +43,7 @@
model: One of the VertexAI multimodal models to be used. Default is "multimodalembedding".
api_base: The API endpoint you want to call the model with.
api_key: API key to be used. If not specified, an environment variable will be used.
concurency: The number of concurrent requests to make to the API.
concurrency: The number of concurrent requests to make to the API.
default_options: Additional options to pass to the API.

Raises:
@@ -60,7 +60,7 @@ def __init__(
self.model = model
self.api_base = api_base
self.api_key = api_key
self.concurency = concurency
self.concurrency = concurrency

supported_models = VertexMultimodalEmbedding().SUPPORTED_MULTIMODAL_EMBEDDING_MODELS
if model not in supported_models:
@@ -90,7 +90,7 @@ async def _embed(self, data: list[dict], options: LiteLLMEmbedderOptions | None
api_base=self.api_base,
options=merged_options.dict(),
) as outputs:
semaphore = asyncio.Semaphore(self.concurency)
semaphore = asyncio.Semaphore(self.concurrency)
try:
response = await asyncio.gather(
*[self._call_litellm(instance, semaphore, merged_options) for instance in data],
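For context, a sketch of the semaphore pattern this hunk relies on. `_call_litellm` exists in the file but its body is not shown in this diff, so the implementation below is an assumption about how it bounds concurrent requests:

```python
async def _call_litellm(
    self, instance: dict, semaphore: asyncio.Semaphore, options: LiteLLMEmbedderOptions
) -> litellm.EmbeddingResponse:
    # Only `self.concurrency` coroutines may hold the semaphore at once;
    # the remaining gathered tasks wait here until a slot frees up.
    async with semaphore:
        return await litellm.aembedding(
            model=self.model,
            input=[instance],
            api_base=self.api_base,
            api_key=self.api_key,
            **options.dict(),
        )
```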