-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Create abstraction for multi-modal LLM
Change-Id: I5f90dc267ac778e4e7e918f2db2dee460c465c40
- Loading branch information
1 parent
f43b760
commit 7081bff
Showing
4 changed files
with
446 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
"""Module for file management and interaction with LLMs.""" | ||
|
||
import os | ||
import time | ||
from typing import Callable | ||
import google | ||
import google.api_core | ||
import google.api_core.exceptions | ||
import google.generativeai as google_genai | ||
|
||
|
||
class File: | ||
"""Represents a file to be used with a multimodal LLM. | ||
This class encapsulates the name/path of a local file and provides a mechanism | ||
for registering cleanup callbacks. These callbacks are executed | ||
when the `cleanup` method is called, typically after the file has | ||
been processed by the LLM. | ||
""" | ||
|
||
def __init__(self, name: str | os.PathLike[str]): | ||
self.name = name | ||
self._cleanup_callbacks: list[Callable[[], None]] = [] | ||
|
||
def add_cleanup_callback(self, callback: Callable[[], None]) -> None: | ||
"""Adds a callback function to be executed during cleanup. | ||
The provided callback function will be called when the `cleanup` method | ||
is invoked on the `File` object. This allows for registering actions | ||
to be performed during file cleanup, such as deleting temporary files | ||
or releasing resources. | ||
Args: | ||
callback: A callable function that takes no arguments and returns None. | ||
This function will be executed during cleanup. | ||
""" | ||
self._cleanup_callbacks.append(callback) | ||
|
||
def cleanup(self) -> None: | ||
"""Executes all registered cleanup callbacks. | ||
Callbacks are added using the `add_cleanup_callback` method. | ||
""" | ||
for cleanup_callback in self._cleanup_callbacks: | ||
cleanup_callback() | ||
|
||
|
||
class GeminiFileHandler:
    """Handles file interactions with Gemini."""

    # Seconds between polls while a file is in the "PROCESSING" state.
    _POLL_INTERVAL_SECONDS = 10

    def _convert_to_gemini_file_name(
        self, file_name: str | os.PathLike[str]
    ) -> str:
        """Normalizes a local path into a Gemini-acceptable file name.

        Strips every non-alphanumeric character and lowercases the result,
        e.g. "file/path/file_name.txt" -> "filepathfilenametxt".
        `os.fspath` is applied first because `File.name` may be a PathLike,
        which is not directly iterable as characters.
        """
        return "".join(c for c in os.fspath(file_name) if c.isalnum()).lower()

    def upload(self, file: File) -> google_genai.types.File:
        """Uploads the given file to Gemini.

        Args:
            file: The `File` object to upload.

        Returns:
            A `google_genai.types.File` object representing the uploaded file.
        """
        file_name = self._convert_to_gemini_file_name(file.name)
        return google_genai.upload_file(path=file.name, name=file_name)

    def get(self, file: File) -> google_genai.types.File | None:
        """Attempts to retrieve the file from Gemini.

        Args:
            file: The `File` object representing the file to retrieve.

        Returns:
            A `google_genai.types.File` object if the file is found, otherwise
            None.
        """
        try:
            file_name = self._convert_to_gemini_file_name(file.name)
            return google_genai.get_file(file_name)
        # A missing file can surface either as PermissionDenied (403, the
        # API hides existence from unauthorized callers) or as NotFound
        # (404); treat both as "file does not exist".
        except (
            google.api_core.exceptions.PermissionDenied,
            google.api_core.exceptions.NotFound,
        ):
            return None

    def wait_for_processing(
        self, file: google_genai.types.File
    ) -> google_genai.types.File:
        """Waits for a Gemini file to finish processing.

        Polls the file's state at `_POLL_INTERVAL_SECONDS`-second intervals
        until it is no longer in the "PROCESSING" state.

        Args:
            file: The `google_genai.types.File` object to wait for.

        Returns:
            The `google_genai.types.File` object, updated with its final state.
        """
        gemini_file = file
        while gemini_file.state.name == "PROCESSING":
            time.sleep(self._POLL_INTERVAL_SECONDS)
            gemini_file = google_genai.get_file(gemini_file.name)
        return gemini_file

    def prepare(self, file: File) -> google_genai.types.File:
        """Prepares a file for use with Gemini.

        Uploads the file (or retrieves it if it already exists) to Gemini
        file storage. Returns when the file is ready to be used with Gemini.

        Args:
            file: The `File` object to prepare.

        Returns:
            A `google_genai.types.File` object ready for use with Gemini.
        """
        gemini_file = self.get(file)
        if not gemini_file:
            gemini_file = self.upload(file)
            # Only a file we uploaded ourselves is scheduled for deletion at
            # cleanup; a pre-existing Gemini file is left untouched.
            file.add_cleanup_callback(gemini_file.delete)
        return self.wait_for_processing(gemini_file)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
"""Unit tests for the file_io module.""" | ||
|
||
import time | ||
import unittest | ||
from unittest import mock | ||
|
||
import file_io | ||
import google | ||
import google.api_core | ||
import google.api_core.exceptions | ||
import google.generativeai as google_genai | ||
|
||
|
||
class FileTest(unittest.TestCase):
    """Tests the cleanup-callback mechanism of `file_io.File`."""

    def test_cleanup_calls_callbacks(self):
        # Register several callbacks and verify each fires exactly once.
        callbacks = [mock.MagicMock(), mock.MagicMock()]
        file = file_io.File("name")
        for callback in callbacks:
            file.add_cleanup_callback(callback)

        file.cleanup()

        for callback in callbacks:
            callback.assert_called_once()
|
||
|
||
class GeminiFileHandlerTest(unittest.TestCase):
    """Tests the GeminiFileHandler class."""

    def setUp(self):
        super().setUp()
        # Fresh handler per test; the handler itself is stateless.
        self.file_handler = file_io.GeminiFileHandler()

    @mock.patch.object(google_genai, "upload_file", autospec=True)
    def test_upload_calls_gemini_upload_with_valid_gemini_filename(
        self, mock_upload_file
    ):
        """Verifies upload sanitizes the path into a Gemini file name."""
        file = file_io.File("file/path/file_name.txt")

        self.file_handler.upload(file)

        # Separators, underscores and dots are stripped; result lowercased.
        mock_upload_file.assert_called_once_with(
            path="file/path/file_name.txt", name="filepathfilenametxt"
        )

    @mock.patch.object(google_genai, "get_file", autospec=True)
    def test_get_calls_gemini_get_file(self, mock_get_file):
        """Verifies get looks up the sanitized Gemini file name."""
        file = file_io.File("file/name.txt")

        self.file_handler.get(file)

        mock_get_file.assert_called_once_with("filenametxt")

    @mock.patch.object(google_genai, "get_file", autospec=True)
    def test_get_returns_none_if_file_not_found(self, mock_get_file):
        """Verifies get maps a PermissionDenied error to None."""
        # The API reports an unknown file as PermissionDenied (403).
        mock_get_file.side_effect = google.api_core.exceptions.PermissionDenied(
            "File not found"
        )
        file = file_io.File("file/name.txt")

        gemini_file = self.file_handler.get(file)

        self.assertIsNone(gemini_file)

    @mock.patch.object(time, "sleep", autospec=True)
    @mock.patch.object(google_genai, "get_file", autospec=True)
    def test_wait_for_processing(self, mock_get_file, _):
        """Verifies polling continues until the file leaves PROCESSING."""
        file = mock.MagicMock(spec=google_genai.types.File)
        file.state.name = "PROCESSING"
        # The re-fetched file reports a terminal state, ending the loop.
        mock_get_file.return_value.state.name = "COMPLETED"

        gemini_file = self.file_handler.wait_for_processing(file)

        self.assertEqual(gemini_file.state.name, "COMPLETED")

    @mock.patch.object(time, "sleep", autospec=True)
    @mock.patch.object(google_genai, "upload_file", autospec=True)
    @mock.patch.object(google_genai, "get_file", autospec=True)
    def test_prepare_gets_file_if_exists(
        self, mock_get_file, mock_upload_file, _
    ):
        """Tests prepare retrieves the file if it exists in Gemini.

        Verifies that if the file is found in Gemini, it's retrieved directly
        and not uploaded again.

        Args:
            mock_get_file: Mock for Gemini's get_file method.
            mock_upload_file: Mock for Gemini's upload_file method.
            _: unused sleep function mock.
        """
        file = file_io.File("file/name.txt")
        mock_gemini_file = mock.MagicMock()
        mock_get_file.return_value = mock_gemini_file
        # COMPLETED state lets wait_for_processing return immediately.
        mock_gemini_file.state.name = "COMPLETED"

        prepared_file = self.file_handler.prepare(file)

        self.assertEqual(prepared_file, mock_gemini_file)
        mock_upload_file.assert_not_called()

    @mock.patch.object(time, "sleep", autospec=True)
    @mock.patch.object(google_genai, "upload_file", autospec=True)
    @mock.patch.object(google_genai, "get_file", autospec=True)
    def test_prepare_uploads_and_waits_if_file_doesnt_exist(
        self, mock_get_file, mock_upload_file, _
    ):
        """Tests prepare uploads the file if it doesn't exist in Gemini.

        Verifies that if the file is not found in Gemini, it's uploaded,
        the processing is waited for.

        Args:
            mock_get_file: Mock for Gemini's get_file method.
            mock_upload_file: Mock for Gemini's upload_file method.
            _: unused sleep function mock.
        """
        file = file_io.File("file/name.txt")
        # Simulate "not found": the API raises PermissionDenied.
        mock_get_file.side_effect = google.api_core.exceptions.PermissionDenied(
            "File not found"
        )
        mock_gemini_file = mock.MagicMock()
        mock_upload_file.return_value = mock_gemini_file
        mock_gemini_file.state.name = "COMPLETED"

        prepared_file = self.file_handler.prepare(file)

        self.assertEqual(prepared_file, mock_gemini_file)
        mock_upload_file.assert_called_once()
|
||
|
||
if __name__ == "__main__":
    # Allow running this test module directly: `python file_io_test.py`.
    unittest.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
"""Module for interacting with multimodal large language models (LLMs). | ||
This module provides abstractions for working with LLMs that can process | ||
multimodal inputs (text, images, videos, etc.). | ||
""" | ||
|
||
import os | ||
from typing import Protocol | ||
import file_io | ||
import google.generativeai as google_genai | ||
|
||
|
||
# A single element of a prompt: either raw text or a file attachment.
PromptPart = str | file_io.File
|
||
|
||
# NOTE(review): the class name is misspelled ("Adapater" -> "Adapter").
# It is the public interface name, so renaming would break implementers;
# left as-is and flagged for a coordinated rename.
class MultiModalLLMAdapater(Protocol):
    """Structural interface for model-specific multimodal LLM adapters."""

    def generate(
        self, prompt_parts: list[PromptPart], response_schema, temperature: float
    ) -> str:
        """Generates a text response for the given multimodal prompt.

        Args:
            prompt_parts: Ordered prompt elements (text and/or files).
            response_schema: Desired response schema; implementations treat
                `str` as a request for free-form text.
            temperature: Sampling temperature.

        Returns:
            The model's response as a string.
        """
        pass
|
||
|
||
class GeminiLLMAdapter(MultiModalLLMAdapater):
    """Adapter for interacting with the Gemini LLM.

    This adapter implements the `MultiModalLLMAdapater` protocol, providing
    a way to interact with the Gemini LLM using a consistent interface.
    It handles prompt construction, including incorporating files (images,
    videos) and managing the interaction with the Gemini API.
    """

    def __init__(
        self,
        model: google_genai.GenerativeModel,
        file_handler: file_io.GeminiFileHandler | None = None,
    ):
        """Initializes the adapter.

        Args:
            model: The Gemini model used to generate content.
            file_handler: Handler used to prepare file prompt parts. A fresh
                `GeminiFileHandler` is created when omitted. (A default of
                `GeminiFileHandler()` in the signature would be evaluated
                once at definition time and shared across all instances —
                the mutable-default-argument pitfall.)
        """
        self._model = model
        self.file_handler = (
            file_handler if file_handler is not None else file_io.GeminiFileHandler()
        )

    def _parse_prompt_part(
        self, prompt_part: PromptPart
    ) -> google_genai.types.ContentType:
        """Converts a prompt part into Gemini content.

        Plain strings pass through unchanged; `File` parts are uploaded (or
        fetched) via the file handler and returned as Gemini file objects.
        """
        if isinstance(prompt_part, str):
            return prompt_part

        return self.file_handler.prepare(prompt_part)

    def generate(
        self,
        prompt_parts: list[PromptPart],
        response_schema,
        temperature: float,
    ) -> str:
        """Generates a response from Gemini for the given prompt.

        Args:
            prompt_parts: Ordered prompt elements (text and/or files).
            response_schema: Schema for a structured JSON response, or `str`
                for free-form text.
            temperature: Sampling temperature.

        Returns:
            The model's response text.
        """
        # `str` means "plain text"; anything else requests structured JSON
        # constrained to the supplied schema.
        is_structured = response_schema != str
        content_types = [self._parse_prompt_part(part) for part in prompt_parts]
        response = self._model.generate_content(
            content_types,
            generation_config=google_genai.GenerationConfig(
                response_mime_type="application/json" if is_structured else None,
                response_schema=response_schema if is_structured else None,
                temperature=temperature,
            ),
        )
        return response.text
|
||
|
||
class MultiModalLLM:
    """Model-agnostic entry point for multimodal generation.

    Thin facade that forwards generation requests to the configured
    `MultiModalLLMAdapater` implementation.
    """

    def __init__(self, adapter: MultiModalLLMAdapater):
        self.adapter = adapter

    def generate(
        self,
        prompt_parts: list[PromptPart],
        response_schema,
        temperature: float = 1.0,
    ) -> str:
        """Delegates generation to the underlying adapter.

        Args:
            prompt_parts: Ordered prompt elements (text and/or files).
            response_schema: Desired response schema; `str` for plain text.
            temperature: Sampling temperature (defaults to 1.0).

        Returns:
            The model's response as a string.
        """
        return self.adapter.generate(prompt_parts, response_schema, temperature)
|
||
|
||
def create_gemini_llm(
    system_prompt: str | None = None,
    file_handler: file_io.GeminiFileHandler | None = None,
    api_key: str | None = None,
) -> MultiModalLLM:
    """Creates a `MultiModalLLM` backed by Gemini.

    Args:
        system_prompt: Optional system instruction for the model.
        file_handler: Handler for preparing file prompt parts; a fresh
            `GeminiFileHandler` is created when omitted. (A
            `GeminiFileHandler()` default in the signature would be built
            once at import time and shared by every call.)
        api_key: Gemini API key. Falls back to the GEMINI_API_KEY
            environment variable; raises KeyError if neither is set.

    Returns:
        A configured `MultiModalLLM` instance.
    """
    google_genai.configure(api_key=api_key or os.environ["GEMINI_API_KEY"])
    gemini = google_genai.GenerativeModel(system_instruction=system_prompt)
    if file_handler is None:
        file_handler = file_io.GeminiFileHandler()
    adapter = GeminiLLMAdapter(gemini, file_handler)
    return MultiModalLLM(adapter)
Oops, something went wrong.