Skip to content

Commit 611734c

Browse files
author
Damian Fastowiec
committed
add audio utils to handle model audio input
1 parent b37f282 commit 611734c

File tree

6 files changed

+31
-66
lines changed

6 files changed

+31
-66
lines changed

dspy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from dspy.evaluate import Evaluate # isort: skip
1010
from dspy.clients import * # isort: skip
11-
from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, Image # isort: skip
11+
from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, Image, Audio # isort: skip
1212
from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
1313
from dspy.utils.asyncify import asyncify
1414
from dspy.utils.saving import load

dspy/adapters/__init__.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
11
from dspy.adapters.base import Adapter
22
from dspy.adapters.chat_adapter import ChatAdapter
33
from dspy.adapters.json_adapter import JSONAdapter
4-
from dspy.adapters.image_utils import Image
4+
from dspy.adapters.image_utils import Image, encode_image, is_image
5+
from dspy.adapters.audio_utils import Audio, encode_audio, is_audio
6+
from dspy.adapters.media_utils import try_expand_media_tags
57

68
__all__ = [
7-
"Adapter",
8-
"ChatAdapter",
9-
"JSONAdapter",
10-
"Image",
9+
'Adapter',
10+
'ChatAdapter',
11+
'JSONAdapter',
12+
'Image',
13+
'Audio',
14+
'encode_image',
15+
'encode_audio',
16+
'is_image',
17+
'is_audio',
18+
'try_expand_media_tags',
1119
]

dspy/adapters/chat_adapter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
from pydantic.fields import FieldInfo
1212

1313
from dspy.adapters.base import Adapter
14-
from dspy.adapters.image_utils import try_expand_image_tags
1514
from dspy.adapters.utils import format_field_value, get_annotation_name, parse_value
1615
from dspy.signatures.field import OutputField
1716
from dspy.signatures.signature import Signature, SignatureMeta
1817
from dspy.signatures.utils import get_dspy_field_type
18+
from dspy.adapters.media_utils import try_expand_media_tags
1919

2020
field_header_pattern = re.compile(r"\[\[ ## (\w+) ## \]\]")
2121

@@ -54,7 +54,7 @@ def format(self, signature: Signature, demos: list[dict[str, Any]], inputs: dict
5454
messages.append(format_turn(signature, demo, role="assistant", incomplete=demo in incomplete_demos))
5555

5656
messages.append(format_turn(signature, inputs, role="user"))
57-
messages = try_expand_image_tags(messages)
57+
messages = try_expand_media_tags(messages)
5858
return messages
5959

6060
def parse(self, signature, completion):

dspy/adapters/image_utils.py

Lines changed: 12 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@
1818

1919
class Image(pydantic.BaseModel):
2020
url: str
21-
21+
2222
model_config = {
23-
'frozen': True,
24-
'str_strip_whitespace': True,
25-
'validate_assignment': True,
26-
'extra': 'forbid',
23+
"frozen": True,
24+
"str_strip_whitespace": True,
25+
"validate_assignment": True,
26+
"extra": "forbid",
2727
}
28-
28+
2929
@pydantic.model_validator(mode="before")
3030
@classmethod
3131
def validate_input(cls, values):
@@ -68,6 +68,7 @@ def __repr__(self):
6868
return f"Image(url=data:image/{image_type};base64,<IMAGE_BASE_64_ENCODED({str(len_base64)})>)"
6969
return f"Image(url='{self.url}')"
7070

71+
7172
def is_url(string: str) -> bool:
7273
"""Check if a string is a valid URL."""
7374
try:
@@ -77,7 +78,9 @@ def is_url(string: str) -> bool:
7778
return False
7879

7980

80-
def encode_image(image: Union[str, bytes, "PILImage.Image", dict], download_images: bool = False) -> str:
81+
def encode_image(
82+
image: Union[str, bytes, "PILImage.Image", dict], download_images: bool = False
83+
) -> str:
8184
"""
8285
Encode an image to a base64 data URI.
8386
@@ -150,7 +153,8 @@ def _encode_image_from_url(image_url: str) -> str:
150153
encoded_image = base64.b64encode(response.content).decode("utf-8")
151154
return f"data:image/{file_extension};base64,{encoded_image}"
152155

153-
def _encode_pil_image(image: 'PILImage') -> str:
156+
157+
def _encode_pil_image(image: "PILImage") -> str:
154158
"""Encode a PIL Image object to a base64 data URI."""
155159
buffered = io.BytesIO()
156160
file_extension = (image.format or "PNG").lower()
@@ -177,52 +181,3 @@ def is_image(obj) -> bool:
177181
elif is_url(obj):
178182
return True
179183
return False
180-
181-
def try_expand_image_tags(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
182-
"""Try to expand image tags in the messages."""
183-
for message in messages:
184-
# NOTE: Assumption that content is a string
185-
if "content" in message and "<DSPY_IMAGE_START>" in message["content"]:
186-
message["content"] = expand_image_tags(message["content"])
187-
return messages
188-
189-
def expand_image_tags(text: str) -> Union[str, List[Dict[str, Any]]]:
190-
"""Expand image tags in the text. If there are any image tags,
191-
turn it from a content string into a content list of texts and image urls.
192-
193-
Args:
194-
text: The text content that may contain image tags
195-
196-
Returns:
197-
Either the original string if no image tags, or a list of content dicts
198-
with text and image_url entries
199-
"""
200-
image_tag_regex = r'"?<DSPY_IMAGE_START>(.*?)<DSPY_IMAGE_END>"?'
201-
202-
# If no image tags, return original text
203-
if not re.search(image_tag_regex, text):
204-
return text
205-
206-
final_list = []
207-
remaining_text = text
208-
209-
while remaining_text:
210-
match = re.search(image_tag_regex, remaining_text)
211-
if not match:
212-
if remaining_text.strip():
213-
final_list.append({"type": "text", "text": remaining_text.strip()})
214-
break
215-
216-
# Get text before the image tag
217-
prefix = remaining_text[:match.start()].strip()
218-
if prefix:
219-
final_list.append({"type": "text", "text": prefix})
220-
221-
# Add the image
222-
image_url = match.group(1)
223-
final_list.append({"type": "image_url", "image_url": {"url": image_url}})
224-
225-
# Update remaining text
226-
remaining_text = remaining_text[match.end():].strip()
227-
228-
return final_list

dspy/adapters/json_adapter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from dspy.adapters.base import Adapter
1616
from dspy.adapters.image_utils import Image
17+
from dspy.adapters.audio_utils import Audio
1718
from dspy.adapters.utils import parse_value, format_field_value, get_annotation_name, serialize_for_json
1819
from dspy.signatures.signature import SignatureMeta
1920
from dspy.signatures.utils import get_dspy_field_type
@@ -131,7 +132,7 @@ def _format_field_value(field_info: FieldInfo, value: Any) -> str:
131132
The formatted value of the field, represented as a string.
132133
"""
133134
# TODO: Wasnt this easy to fix?
134-
if field_info.annotation is Image:
135+
if field_info.annotation is Image or field_info.annotation is Audio:
135136
raise NotImplementedError("Images are not yet supported in JSON mode.")
136137

137138
return format_field_value(field_info=field_info, value=value)

dspy/signatures/signature.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class MySignature(dspy.Signature):
2828
from pydantic.fields import FieldInfo
2929

3030
from dspy.adapters.image_utils import Image # noqa: F401
31+
from dspy.adapters.audio_utils import Audio # noqa: F401
3132
from dspy.signatures.field import InputField, OutputField
3233

3334

0 commit comments

Comments
 (0)