Skip to content

Commit 0ad433c

Browse files
maciejmajek authored and rachwalk committed
feat: add ElevenLabsTTS
1 parent 746997d commit 0ad433c

File tree

2 files changed

+114
-1
lines changed

2 files changed

+114
-1
lines changed

src/rai_tts/rai_tts/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
from .base import TTSModel, TTSModelError
16+
from .elevenlabs_tts import ElevenLabsTTS
1617
from .open_tts import OpenTTS
1718

18-
__all__ = ["OpenTTS", "TTSModel", "TTSModelError"]
19+
__all__ = ["ElevenLabsTTS", "OpenTTS", "TTSModel", "TTSModelError"]
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# Copyright (C) 2024 Robotec.AI
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
from io import BytesIO
17+
from typing import Tuple
18+
19+
from elevenlabs.client import ElevenLabs
20+
from elevenlabs.types import Voice
21+
from elevenlabs.types.voice_settings import VoiceSettings
22+
from pydub import AudioSegment
23+
24+
from rai_tts.models import TTSModel, TTSModelError
25+
26+
27+
class ElevenLabsTTS(TTSModel):
    """
    A text-to-speech (TTS) model interface for ElevenLabs.

    The API key is read from the ``ELEVENLABS_API_KEY`` environment variable.

    Parameters
    ----------
    voice : str
        Name of the ElevenLabs voice to use. The name is resolved to a voice id
        by listing the available voices when the model is created.
    base_url : str, optional
        The API endpoint for the ElevenLabs API, by default None.

    Raises
    ------
    TTSModelError
        If the ``ELEVENLABS_API_KEY`` environment variable is not set.
        If no available voice is named `voice`.
    """

    def __init__(
        self,
        voice: str,
        base_url: str | None = None,
    ):
        api_key = os.getenv("ELEVENLABS_API_KEY")
        if api_key is None:
            raise TTSModelError("ELEVENLABS_API_KEY environment variable is not set.")

        self.client = ElevenLabs(base_url=base_url, api_key=api_key)
        self.voice_settings = VoiceSettings(
            stability=0.7,
            similarity_boost=0.5,
        )

        # Voices are addressed by id in requests, so map the human-readable
        # name onto the id of the first voice carrying that name.
        voices = self.client.voices.get_all().voices
        voice_id = next((v.voice_id for v in voices if v.name == voice), None)
        if voice_id is None:
            raise TTSModelError(f"Voice {voice} not found")
        self.voice = Voice(voice_id=voice_id, settings=self.voice_settings)

    def get_speech(self, text: str) -> AudioSegment:
        """
        Converts text into speech using the ElevenLabs API.

        Parameters
        ----------
        text : str
            The input text to be converted into speech.

        Returns
        -------
        AudioSegment
            The generated speech as an `AudioSegment` object.

        Raises
        ------
        TTSModelError
            If there is an issue with the request or the ElevenLabs API is unreachable.
            If the response does not contain valid audio data.
        """
        try:
            response = self.client.generate(
                text=text,
                voice=self.voice,
                optimize_streaming_latency=4,
            )
            # The response is consumed as an iterable of byte chunks; joining
            # it inside the ``try`` also captures errors raised mid-transfer.
            audio_data = b"".join(response)
        except Exception as e:
            raise TTSModelError(f"Error occurred while fetching audio: {e}") from e

        if not audio_data:
            raise TTSModelError("ElevenLabs API returned no audio data.")

        # Load audio into memory (ElevenLabs returns MP3). Decoding failures are
        # reported as TTSModelError so callers only need to handle one
        # exception type, as documented above.
        try:
            return AudioSegment.from_mp3(BytesIO(audio_data))
        except Exception as e:
            raise TTSModelError(f"Error occurred while decoding audio: {e}") from e

    def get_tts_params(self) -> Tuple[int, int]:
        """
        Returns TTS sampling rate and channels.

        The sampling rate is retrieved by running a sample speech generation
        request, so that it matches what the API actually produces.

        Returns
        -------
        Tuple[int, int]
            sample rate, channels

        Raises
        ------
        TTSModelError
            If there is an issue with the request or the ElevenLabs API is unreachable.
            If the response does not contain valid audio data.
        """
        data = self.get_speech("A")
        # NOTE(review): the channel count is hard-coded to 1 rather than read
        # from the decoded audio (`data.channels`) -- confirm the generated
        # audio is always mono, or that consumers expect mono regardless.
        return data.frame_rate, 1

0 commit comments

Comments
 (0)