Skip to content

Commit 6a62cd5

Browse files
author
Damian Fastowiec
committed
add audio utils to handle model audio input
1 parent 611734c commit 6a62cd5

File tree

3 files changed

+554
-0
lines changed

3 files changed

+554
-0
lines changed

dspy/adapters/audio_utils.py

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import base64
2+
import os
3+
import pydantic
4+
import requests
5+
6+
from typing import Any, Dict, List, Union
7+
from urllib.parse import urlparse
8+
9+
10+
class Audio(pydantic.BaseModel):
    """Immutable reference to an audio clip, held as a single ``url`` string.

    ``url`` is whatever ``encode_audio`` produced: either a
    ``data:audio/<format>;base64,...`` URI or a plain URL. The model is frozen
    and rejects extra fields (see ``model_config``).

    Serialising the model (``model_dump`` / ``str``) yields the URL wrapped in
    ``<DSPY_AUDIO_START>`` / ``<DSPY_AUDIO_END>`` marker tags rather than a
    dict.
    """

    # Data URI or plain URL of the clip; surrounding whitespace is stripped.
    url: str

    model_config = {
        "frozen": True,
        "str_strip_whitespace": True,
        "validate_assignment": True,
        "extra": "forbid",
    }

    @pydantic.model_validator(mode="before")
    @classmethod
    def validate_input(cls, values):
        """Normalise the raw input into a ``{"url": ...}`` mapping.

        Args:
            values: A URL string, a dict whose only key is ``"url"``, or an
                existing instance of this class.

        Returns:
            A dict with the single key ``"url"``.

        Raises:
            TypeError: If ``values`` has any other shape.
        """
        if isinstance(values, str):
            return {"url": values}
        if isinstance(values, dict) and set(values.keys()) == {"url"}:
            return values
        if isinstance(values, cls):
            # Do not use model_dump() here: it goes through serialize_model()
            # and therefore returns the tagged string, not a field mapping.
            return {"url": values.url}
        raise TypeError("Expected a string URL or a dictionary with a key 'url'.")

    @classmethod
    def from_url(cls, url: str, download: bool = False):
        """Build an instance from a URL.

        Args:
            url: The audio URL.
            download: Forwarded to ``encode_audio`` as its second positional
                argument (``download_images``).
        """
        return cls(url=encode_audio(url, download))

    @classmethod
    def from_file(cls, file_path: str):
        """Build an instance from ``file_path`` via ``encode_audio``."""
        return cls(url=encode_audio(file_path))

    @classmethod
    def from_bytes(cls, audio_bytes: bytes, format: str = "wav"):
        """Build an instance from raw bytes, labelled with ``format``."""
        return cls(url=encode_audio(audio_bytes, format=format))

    @pydantic.model_serializer()
    def serialize_model(self):
        """Return the URL wrapped in the DSPY audio marker tags."""
        return "<DSPY_AUDIO_START>" + self.url + "<DSPY_AUDIO_END>"

    def __str__(self):
        return self.serialize_model()

    def __repr__(self):
        """Return a short representation that elides base64 payloads."""
        head, marker, payload = self.url.partition("base64,")
        # Require an actual data URI with a "base64," marker; a plain URL that
        # merely contains the word "base64" is shown verbatim instead of
        # raising IndexError.
        if marker and self.url.startswith("data:"):
            audio_type = head.split(";")[0].split("/")[-1]
            return (
                f"Audio(url=data:audio/{audio_type};base64,"
                f"<AUDIO_BASE_64_ENCODED({len(payload)})>)"
            )
        return f"Audio(url='{self.url}')"
58+
59+
60+
def is_url(string: str) -> bool:
    """Return True when ``string`` parses as an http(s) URL with a host part."""
    try:
        parsed = urlparse(string)
    except ValueError:
        # urlparse rejects some malformed inputs (e.g. unbalanced IPv6 brackets).
        return False
    has_web_scheme = parsed.scheme in ("http", "https")
    return has_web_scheme and bool(parsed.netloc)
67+
68+
69+
def encode_audio(
    audio: Union[str, bytes, dict], download_images: bool = False, format: Union[str, None] = None
) -> str:
    """
    Resolve an audio reference to a data URI or URL string.

    Args:
        audio: One of:
            - a dict containing a ``"url"`` key (its value is returned as is),
            - a string that is a ``data:audio/`` URI, an existing file path,
              or an http(s) URL,
            - raw ``bytes``,
            - an ``Audio`` instance (its ``url`` is returned).
        download_images: Whether to download audio from URLs. When False, an
            http(s) URL is returned unchanged. (The name is kept for
            backward compatibility.)
        format: The audio format when encoding from bytes (e.g., 'wav', 'mp3',
            etc.). Defaults to 'wav' when not given.

    Returns:
        str: The data URI of the audio, or the URL if download_images is False.

    Raises:
        ValueError: If the audio string or audio type is not supported.
    """
    if isinstance(audio, dict) and "url" in audio:
        return audio["url"]

    if isinstance(audio, str):
        if audio.startswith("data:audio/"):
            # Already a data URI
            return audio
        if os.path.isfile(audio):
            return _encode_audio_from_file(audio)
        if is_url(audio):
            return _encode_audio_from_url(audio) if download_images else audio
        # The exception carries the message; nothing is printed to stdout.
        raise ValueError(f"Unsupported audio string: {audio}")

    if isinstance(audio, bytes):
        return _encode_audio_from_bytes(audio, format or "wav")

    if isinstance(audio, Audio):
        return audio.url

    raise ValueError(f"Unsupported audio type: {type(audio)}")
115+
116+
117+
def _encode_audio_from_file(file_path: str) -> str:
    """Read ``file_path`` in binary mode and return it as a base64 ``data:audio/`` URI.

    The subtype in the URI is whatever ``_get_file_extension`` reports for the path.
    """
    with open(file_path, "rb") as handle:
        raw = handle.read()
    subtype = _get_file_extension(file_path)
    payload = base64.b64encode(raw).decode("utf-8")
    return "data:audio/{};base64,{}".format(subtype, payload)
124+
125+
126+
def _encode_audio_from_url(audio_url: str, timeout: float = 30.0) -> str:
    """Download an audio file from a URL and encode it as a base64 data URI.

    Args:
        audio_url: The URL to fetch.
        timeout: Seconds passed to ``requests.get`` so that an unresponsive
            server cannot block the caller indefinitely.

    Returns:
        A ``data:audio/<subtype>;base64,...`` URI. The subtype comes from the
        response's Content-Type when it is an ``audio/`` type, otherwise from
        the URL's file extension, falling back to ``wav``.

    Raises:
        requests.HTTPError: If the response status indicates an error
            (raised by ``raise_for_status``).
    """
    response = requests.get(audio_url, timeout=timeout)
    response.raise_for_status()

    # Keep only the media type: drop parameters such as "; charset=..." or
    # "; codecs=..." so they do not end up inside the data URI, and normalise
    # case because media types are case-insensitive.
    media_type = response.headers.get("Content-Type", "").split(";")[0].strip().lower()
    subtype = media_type.split("/")[-1] if media_type.startswith("audio/") else ""
    file_extension = subtype or _get_file_extension(audio_url) or "wav"

    encoded_audio = base64.b64encode(response.content).decode("utf-8")
    return f"data:audio/{file_extension};base64,{encoded_audio}"
137+
138+
139+
def _encode_audio_from_bytes(audio_bytes: bytes, format: str) -> str:
140+
"""Encode audio bytes to a base64 data URI."""
141+
encoded_audio = base64.b64encode(audio_bytes).decode("utf-8")
142+
return f"data:audio/{format};base64,{encoded_audio}"
143+
144+
145+
def _get_file_extension(path_or_url: str) -> str:
146+
"""Extract the file extension from a file path or URL."""
147+
extension = os.path.splitext(urlparse(path_or_url).path)[1].lstrip(".").lower()
148+
return extension or "wav" # Default to 'wav' if no extension found
149+
150+
151+
def is_audio(obj) -> bool:
    """Report whether ``obj`` is a string that could reference audio.

    A string qualifies when it is a ``data:audio/`` URI, the path of an
    existing file (the extension is not inspected), or an http(s) URL.
    Anything that is not a string is rejected.
    """
    if not isinstance(obj, str):
        return False
    # Checks run in the original order and short-circuit on the first hit.
    return obj.startswith("data:audio/") or os.path.isfile(obj) or is_url(obj)

dspy/adapters/media_utils.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import re
2+
3+
from typing import Any, Dict, List, Union
4+
5+
6+
def try_expand_media_tags(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Expand DSPY audio or image tags found in string message contents.

    Each message whose ``"content"`` is a string containing a media start tag
    has that content replaced, in place, by the result of
    ``expand_media_content``. Audio tags take precedence: image tags are only
    considered when the content holds no audio start tag.

    Returns:
        The same ``messages`` list that was passed in.
    """
    for entry in messages:
        body = entry.get("content")
        if not isinstance(body, str):
            continue

        # Audio is checked first; image is the fallback.
        if "<DSPY_AUDIO_START>" in body:
            kind = "audio"
        elif "<DSPY_IMAGE_START>" in body:
            kind = "image"
        else:
            continue

        entry["content"] = expand_media_content(body, media_type=kind)

    return messages
20+
21+
22+
def expand_media_content(
    text: str, media_type: str
) -> Union[str, List[Dict[str, Any]]]:
    """Expand media tags in the text into a content list with text and media URLs.

    Args:
        text: The text content that may contain media tags
        media_type: Either "audio" or "image"; selects the
            ``<DSPY_<TYPE>_START>`` / ``<DSPY_<TYPE>_END>`` tag pair to look for.

    Returns:
        ``text`` unchanged when it contains no complete tag pair. Otherwise a
        list of parts in document order: ``{"type": "text", "text": ...}`` for
        each non-blank run of surrounding text (whitespace-stripped) and
        ``{"type": "image_url", "image_url": {"url": ...}}`` for each tagged
        URL. Both media types are emitted under the ``image_url`` part type.
    """
    tag_start = f"<DSPY_{media_type.upper()}_START>"
    tag_end = f"<DSPY_{media_type.upper()}_END>"
    # A tag may be wrapped in double quotes; the optional quotes are consumed
    # with the tag. The tag text is escaped so it is always matched literally.
    tag_pattern = re.compile(rf'"?{re.escape(tag_start)}(.*?){re.escape(tag_end)}"?')

    matches = list(tag_pattern.finditer(text))
    # If no media tags, return original text
    if not matches:
        return text

    final_list = []
    cursor = 0
    for match in matches:
        # Text between the previous tag (or the start) and this tag
        prefix = text[cursor : match.start()].strip()
        if prefix:
            final_list.append({"type": "text", "text": prefix})

        # Keep every non-empty URL, whether it is a data URI or a plain
        # http(s) URL; discarding non-data URLs would silently lose the media.
        media_url = match.group(1)
        if media_url:
            final_list.append({"type": "image_url", "image_url": {"url": media_url}})

        cursor = match.end()

    # Text after the last tag
    suffix = text[cursor:].strip()
    if suffix:
        final_list.append({"type": "text", "text": suffix})

    return final_list

0 commit comments

Comments
 (0)