Skip to content

Commit

Permalink
Create abstraction for multi-modal LLM
Browse files Browse the repository at this point in the history
Change-Id: I5f90dc267ac778e4e7e918f2db2dee460c465c40
  • Loading branch information
chris-feldman committed Jan 17, 2025
1 parent f43b760 commit 7081bff
Show file tree
Hide file tree
Showing 4 changed files with 446 additions and 0 deletions.
122 changes: 122 additions & 0 deletions ai_metadata/file_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""Module for file management and interaction with LLMs."""

import os
import time
from typing import Callable
import google
import google.api_core
import google.api_core.exceptions
import google.generativeai as google_genai


class File:
"""Represents a file to be used with a multimodal LLM.
This class encapsulates the name/path of a local file and provides a mechanism
for registering cleanup callbacks. These callbacks are executed
when the `cleanup` method is called, typically after the file has
been processed by the LLM.
"""

def __init__(self, name: str | os.PathLike[str]):
self.name = name
self._cleanup_callbacks: list[Callable[[], None]] = []

def add_cleanup_callback(self, callback: Callable[[], None]) -> None:
"""Adds a callback function to be executed during cleanup.
The provided callback function will be called when the `cleanup` method
is invoked on the `File` object. This allows for registering actions
to be performed during file cleanup, such as deleting temporary files
or releasing resources.
Args:
callback: A callable function that takes no arguments and returns None.
This function will be executed during cleanup.
"""
self._cleanup_callbacks.append(callback)

def cleanup(self) -> None:
"""Executes all registered cleanup callbacks.
Callbacks are added using the `add_cleanup_callback` method.
"""
for cleanup_callback in self._cleanup_callbacks:
cleanup_callback()


class GeminiFileHandler:
    """Handles file interactions with Gemini (upload, lookup, readiness)."""

    # Seconds between polls while a file is in the "PROCESSING" state.
    _POLL_INTERVAL_SECONDS = 10

    def _convert_to_gemini_file_name(
        self, file_name: str | os.PathLike[str]
    ) -> str:
        """Normalizes a local path into a Gemini-compatible file name.

        `File.name` may be a `PathLike`, so coerce via `os.fspath` before
        iterating characters (the original code iterated the PathLike
        directly, which raises TypeError for Path objects). Only
        alphanumeric characters are kept and the result is lowercased.
        """
        path_str = os.fspath(file_name)
        return "".join(c for c in path_str if c.isalnum()).lower()

    def upload(self, file: File) -> google_genai.types.File:
        """Uploads the given file to Gemini.

        Args:
            file: The `File` object to upload.

        Returns:
            A `google_genai.types.File` object representing the uploaded file.
        """
        file_name = self._convert_to_gemini_file_name(file.name)
        return google_genai.upload_file(path=file.name, name=file_name)

    def get(self, file: File) -> google_genai.types.File | None:
        """Attempts to retrieve the file from Gemini.

        Args:
            file: The `File` object representing the file to retrieve.

        Returns:
            A `google_genai.types.File` object if the file is found,
            otherwise None.
        """
        try:
            file_name = self._convert_to_gemini_file_name(file.name)
            return google_genai.get_file(file_name)
        except google.api_core.exceptions.PermissionDenied:
            # The API signals both "no access" and "file not found" with
            # PermissionDenied; either way the file is unusable, return None.
            return None

    def wait_for_processing(
        self, file: google_genai.types.File
    ) -> google_genai.types.File:
        """Waits for a Gemini file to finish processing.

        Polls the file's state every `_POLL_INTERVAL_SECONDS` until it leaves
        the "PROCESSING" state. NOTE(review): there is no timeout — a file
        stuck in "PROCESSING" blocks forever, matching the original behavior.

        Args:
            file: The `google_genai.types.File` object to wait for.

        Returns:
            The `google_genai.types.File` object, updated with its final
            state.
        """
        gemini_file = file
        while gemini_file.state.name == "PROCESSING":
            time.sleep(self._POLL_INTERVAL_SECONDS)
            gemini_file = google_genai.get_file(gemini_file.name)
        return gemini_file

    def prepare(self, file: File) -> google_genai.types.File:
        """Prepares a file for use with Gemini.

        Retrieves the file if it already exists in Gemini storage, otherwise
        uploads it and registers deletion of the remote copy as a cleanup
        callback on `file`. Returns only once the file is ready for use.

        Args:
            file: The `File` object to prepare.

        Returns:
            A `google_genai.types.File` object ready for use with Gemini.
        """
        gemini_file = self.get(file)
        if not gemini_file:
            gemini_file = self.upload(file)
            # Only files we uploaded are deleted on cleanup; pre-existing
            # remote files are left untouched.
            file.add_cleanup_callback(gemini_file.delete)
        return self.wait_for_processing(gemini_file)
135 changes: 135 additions & 0 deletions ai_metadata/file_io_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""Unit tests for the file_io module."""

import time
import unittest
from unittest import mock

import file_io
import google
import google.api_core
import google.api_core.exceptions
import google.generativeai as google_genai


class FileTest(unittest.TestCase):
    """Tests the File class."""

    def test_cleanup_calls_callbacks(self):
        # Register several callbacks, then verify cleanup runs each exactly
        # once.
        callbacks = [mock.MagicMock(), mock.MagicMock()]
        file = file_io.File("name")
        for callback in callbacks:
            file.add_cleanup_callback(callback)

        file.cleanup()

        for callback in callbacks:
            callback.assert_called_once()


class GeminiFileHandlerTest(unittest.TestCase):
    """Tests the GeminiFileHandler class."""

    def setUp(self):
        super().setUp()
        # Fresh handler per test; the handler itself is stateless.
        self.file_handler = file_io.GeminiFileHandler()

    @mock.patch.object(google_genai, "upload_file", autospec=True)
    def test_upload_calls_gemini_upload_with_valid_gemini_filename(
        self, mock_upload_file
    ):
        # The Gemini name keeps only alphanumeric characters, lowercased:
        # "file/path/file_name.txt" -> "filepathfilenametxt".
        file = file_io.File("file/path/file_name.txt")

        self.file_handler.upload(file)

        mock_upload_file.assert_called_once_with(
            path="file/path/file_name.txt", name="filepathfilenametxt"
        )

    @mock.patch.object(google_genai, "get_file", autospec=True)
    def test_get_calls_gemini_get_file(self, mock_get_file):
        file = file_io.File("file/name.txt")

        self.file_handler.get(file)

        # get() looks the file up by its sanitized Gemini name, not the path.
        mock_get_file.assert_called_once_with("filenametxt")

    @mock.patch.object(google_genai, "get_file", autospec=True)
    def test_get_returns_none_if_file_not_found(self, mock_get_file):
        # Gemini signals "not found" via PermissionDenied; get() converts it
        # to None instead of propagating.
        mock_get_file.side_effect = google.api_core.exceptions.PermissionDenied(
            "File not found"
        )
        file = file_io.File("file/name.txt")

        gemini_file = self.file_handler.get(file)

        self.assertIsNone(gemini_file)

    @mock.patch.object(time, "sleep", autospec=True)
    @mock.patch.object(google_genai, "get_file", autospec=True)
    def test_wait_for_processing(self, mock_get_file, _):
        # First poll sees PROCESSING, the refetched file reports COMPLETED;
        # sleep is patched so the 10-second wait is instantaneous.
        file = mock.MagicMock(spec=google_genai.types.File)
        file.state.name = "PROCESSING"
        mock_get_file.return_value.state.name = "COMPLETED"

        gemini_file = self.file_handler.wait_for_processing(file)

        self.assertEqual(gemini_file.state.name, "COMPLETED")

    @mock.patch.object(time, "sleep", autospec=True)
    @mock.patch.object(google_genai, "upload_file", autospec=True)
    @mock.patch.object(google_genai, "get_file", autospec=True)
    def test_prepare_gets_file_if_exists(
        self, mock_get_file, mock_upload_file, _
    ):
        """Tests prepare retrieves the file if it exists in Gemini.

        Verifies that if the file is found in Gemini, it's retrieved directly
        and not uploaded again.

        Args:
            mock_get_file: Mock for Gemini's get_file method.
            mock_upload_file: Mock for Gemini's upload_file method.
            _: unused sleep function mock.
        """
        file = file_io.File("file/name.txt")
        mock_gemini_file = mock.MagicMock()
        mock_get_file.return_value = mock_gemini_file
        # COMPLETED state lets wait_for_processing return immediately.
        mock_gemini_file.state.name = "COMPLETED"

        prepared_file = self.file_handler.prepare(file)

        self.assertEqual(prepared_file, mock_gemini_file)
        mock_upload_file.assert_not_called()

    @mock.patch.object(time, "sleep", autospec=True)
    @mock.patch.object(google_genai, "upload_file", autospec=True)
    @mock.patch.object(google_genai, "get_file", autospec=True)
    def test_prepare_uploads_and_waits_if_file_doesnt_exist(
        self, mock_get_file, mock_upload_file, _
    ):
        """Tests prepare uploads the file if it doesn't exist in Gemini.

        Verifies that if the file is not found in Gemini, it's uploaded,
        the processing is waited for.

        Args:
            mock_get_file: Mock for Gemini's get_file method.
            mock_upload_file: Mock for Gemini's upload_file method.
            _: unused sleep function mock.
        """
        file = file_io.File("file/name.txt")
        # Simulate "file not found" so prepare falls through to upload.
        mock_get_file.side_effect = google.api_core.exceptions.PermissionDenied(
            "File not found"
        )
        mock_gemini_file = mock.MagicMock()
        mock_upload_file.return_value = mock_gemini_file
        mock_gemini_file.state.name = "COMPLETED"

        prepared_file = self.file_handler.prepare(file)

        self.assertEqual(prepared_file, mock_gemini_file)
        mock_upload_file.assert_called_once()


if __name__ == "__main__":
    # Allows running this test module directly: `python file_io_test.py`.
    unittest.main()
91 changes: 91 additions & 0 deletions ai_metadata/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Module for interacting with multimodal large language models (LLMs).
This module provides abstractions for working with LLMs that can process
multimodal inputs (text, images, videos, etc.).
"""

import os
from typing import Protocol
import file_io
import google.generativeai as google_genai


PromptPart = str | file_io.File


class MultiModalLLMAdapater(Protocol):
    """Structural interface for multimodal LLM backends.

    NOTE(review): the class name is misspelled ("Adapater"); renaming would
    break existing callers, so it is kept as-is.
    """

    def generate(
        self, prompt_parts: list[PromptPart], response_schema, temperature: float
    ) -> str:
        """Generates a response from mixed text/file prompt parts.

        Args:
            prompt_parts: Ordered prompt pieces; each is a string or a
                `file_io.File`.
            response_schema: Desired response schema. Implementations compare
                it against `str` to choose plain-text vs. structured output
                (see GeminiLLMAdapter) — presumably a type or schema object;
                no annotation in the original, so the exact contract is
                implementation-defined.
            temperature: Sampling temperature for generation.

        Returns:
            The model's response text.
        """
        pass


class GeminiLLMAdapter(MultiModalLLMAdapater):
    """Adapter for interacting with the Gemini LLM.

    Implements the `MultiModalLLMAdapater` protocol, providing a consistent
    interface to Gemini. Handles prompt construction — including preparing
    file parts (images, videos) via the file handler — and the call to the
    Gemini API.
    """

    def __init__(
        self,
        model: google_genai.GenerativeModel,
        file_handler: file_io.GeminiFileHandler | None = None,
    ):
        """Initializes the adapter.

        Args:
            model: The Gemini model used for generation.
            file_handler: Handler used to upload/fetch file prompt parts. A
                fresh `GeminiFileHandler` is created when omitted. (The
                original signature used a handler instance as a mutable
                default argument, evaluated once and shared by every
                adapter; a None sentinel avoids that.)
        """
        self._model = model
        self.file_handler = (
            file_handler if file_handler is not None else file_io.GeminiFileHandler()
        )

    def _parse_prompt_part(
        self, prompt_part: PromptPart
    ) -> google_genai.types.ContentType:
        # Strings pass through unchanged; File parts are uploaded/fetched so
        # Gemini can reference them.
        if isinstance(prompt_part, str):
            return prompt_part
        return self.file_handler.prepare(prompt_part)

    def generate(
        self,
        prompt_parts: list[PromptPart],
        response_schema,
        temperature: float,
    ) -> str:
        """Generates a response from the Gemini model.

        Args:
            prompt_parts: Ordered prompt pieces (strings and/or Files).
            response_schema: Pass `str` for plain text; any other value
                requests structured JSON output conforming to the schema.
            temperature: Sampling temperature.

        Returns:
            The model's response text.
        """
        contents = [self._parse_prompt_part(part) for part in prompt_parts]
        # `str` is a sentinel meaning "plain text"; compare by identity since
        # it designates the type itself, not an equal value.
        wants_structured_output = response_schema is not str
        generation_config = google_genai.GenerationConfig(
            response_mime_type=(
                "application/json" if wants_structured_output else None
            ),
            response_schema=response_schema if wants_structured_output else None,
            temperature=temperature,
        )
        response = self._model.generate_content(
            contents, generation_config=generation_config
        )
        return response.text


class MultiModalLLM:
    """Thin facade over a backend-specific multimodal LLM adapter.

    Delegates generation to the injected adapter, supplying a default
    temperature so callers can omit it.
    """

    def __init__(self, adapter: MultiModalLLMAdapater):
        self.adapter = adapter

    def generate(
        self,
        prompt_parts: list[PromptPart],
        response_schema,
        temperature: float = 1.0,
    ) -> str:
        # Pure delegation; all backend behavior lives in the adapter.
        return self.adapter.generate(
            prompt_parts, response_schema, temperature
        )


def create_gemini_llm(
    system_prompt: str | None = None,
    file_handler: file_io.GeminiFileHandler | None = None,
    api_key: str | None = None,
) -> MultiModalLLM:
    """Creates a `MultiModalLLM` backed by Gemini.

    Args:
        system_prompt: Optional system instruction for the model. (The
            original annotation was `str` with a None default; corrected to
            `str | None`.)
        file_handler: Handler for file prompt parts; a fresh
            `GeminiFileHandler` is created when omitted. (The original used a
            handler instance as a mutable default argument, shared across all
            calls; a None sentinel avoids that.)
        api_key: Gemini API key. Falls back to the GEMINI_API_KEY environment
            variable.

    Returns:
        A ready-to-use `MultiModalLLM`.

    Raises:
        KeyError: If no api_key is given and GEMINI_API_KEY is not set.
    """
    google_genai.configure(api_key=api_key or os.environ["GEMINI_API_KEY"])
    gemini = google_genai.GenerativeModel(system_instruction=system_prompt)
    adapter = GeminiLLMAdapter(
        gemini, file_handler if file_handler is not None else file_io.GeminiFileHandler()
    )
    return MultiModalLLM(adapter)
Loading

0 comments on commit 7081bff

Please sign in to comment.