Skip to content

Commit 96465e6

Browse files
authored
Merge pull request #1 from liberate-org/exp/vad
added Silerio VAD and updated OpenAI to latest SDK
2 parents 6f7e9cd + 0f8a766 commit 96465e6

File tree

11 files changed

+2443
-1474
lines changed

11 files changed

+2443
-1474
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ benchmark_results/
1111
private.key
1212
dump.rdb
1313
.idea
14+
.venv

poetry.lock

Lines changed: 2117 additions & 1438 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ homepage = "https://github.com/vocodedev/vocode-python"
1111
python = ">=3.8.1,<3.12"
1212
pydub = "^0.25.1"
1313
nltk = "^3.8.1"
14-
openai = "^0.27.8"
14+
openai = "1.12.0"
1515
sounddevice = "^0.4.6"
1616
azure-cognitiveservices-speech = "^1.27.0"
1717
websockets = "^11.0.2"
@@ -42,6 +42,9 @@ langchain = "^0.0.198"
4242
google-cloud-aiplatform = {version = "^1.26.0", optional = true}
4343
miniaudio = "^1.59"
4444
boto3 = "^1.28.28"
45+
pandas = "2.0.3"
46+
torch = "2.1.1"
47+
torchaudio = "2.1.1"
4548

4649

4750
[tool.poetry.group.lint.dependencies]

vocode/streaming/agent/chat_gpt_agent.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@
22

33
from typing import Any, Dict, List, Optional, Tuple, Union
44

5-
import openai
5+
from openai import (
6+
OpenAI,
7+
AsyncOpenAI,
8+
AsyncAzureOpenAI,
9+
AzureOpenAI,
10+
)
611
from typing import AsyncGenerator, Optional, Tuple
712

813
import logging
@@ -37,16 +42,26 @@ def __init__(
3742
agent_config=agent_config, action_factory=action_factory, logger=logger
3843
)
3944
if agent_config.azure_params:
40-
openai.api_type = agent_config.azure_params.api_type
41-
openai.api_base = getenv("AZURE_OPENAI_API_BASE")
42-
openai.api_version = agent_config.azure_params.api_version
43-
openai.api_key = getenv("AZURE_OPENAI_API_KEY")
45+
self.openaiAsyncClient = AsyncAzureOpenAI(
46+
api_version = agent_config.azure_params.api_version,
47+
base_url = getenv("AZURE_OPENAI_API_BASE"),
48+
api_key = getenv("AZURE_OPENAI_API_KEY")
49+
)
50+
self.openaiSyncClient = AzureOpenAI(
51+
api_version = agent_config.azure_params.api_version,
52+
base_url = getenv("AZURE_OPENAI_API_BASE"),
53+
api_key = getenv("AZURE_OPENAI_API_KEY")
54+
)
4455
else:
45-
openai.api_type = "open_ai"
46-
openai.api_base = "https://api.openai.com/v1"
47-
openai.api_version = None
48-
openai.api_key = openai_api_key or getenv("OPENAI_API_KEY")
49-
if not openai.api_key:
56+
self.openaiAsyncClient = AsyncOpenAI(
57+
base_url = "https://api.openai.com/v1",
58+
api_key = openai_api_key or getenv("OPENAI_API_KEY")
59+
)
60+
self.openaiSyncClient = OpenAI(
61+
base_url = "https://api.openai.com/v1",
62+
api_key = openai_api_key or getenv("OPENAI_API_KEY")
63+
)
64+
if not self.openaiAsyncClient.api_key or not self.openaiSyncClient.api_key:
5065
raise ValueError("OPENAI_API_KEY must be set in environment or passed in")
5166
self.first_response = (
5267
self.create_first_response(agent_config.expected_first_prompt)
@@ -104,7 +119,7 @@ def create_first_response(self, first_prompt):
104119
]
105120

106121
parameters = self.get_chat_parameters(messages)
107-
return openai.ChatCompletion.create(**parameters)
122+
return self.openaiSyncClient.chat.completions.create(**parameters)
108123

109124
def attach_transcript(self, transcript: Transcript):
110125
self.transcript = transcript
@@ -126,7 +141,8 @@ async def respond(
126141
text = self.first_response
127142
else:
128143
chat_parameters = self.get_chat_parameters()
129-
chat_completion = await openai.ChatCompletion.acreate(**chat_parameters)
144+
# chat_completion = await openai.ChatCompletion.acreate(**chat_parameters)
145+
chat_completion = await self.openaiAsyncClient.chat.completions.create(**chat_parameters)
130146
text = chat_completion.choices[0].message.content
131147
self.logger.debug(f"LLM response: {text}")
132148
return text, False
@@ -172,7 +188,8 @@ async def generate_response(
172188
else:
173189
chat_parameters = self.get_chat_parameters()
174190
chat_parameters["stream"] = True
175-
stream = await openai.ChatCompletion.acreate(**chat_parameters)
191+
# stream = await openai.ChatCompletion.acreate(**chat_parameters)
192+
stream = await self.openaiAsyncClient.chat.completions.create(**chat_parameters)
176193
async for message in collate_response_async(
177194
openai_get_tokens(stream), get_functions=True
178195
):

vocode/streaming/agent/utils.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from copy import deepcopy
22
import re
3+
import time
34
from typing import (
45
Dict,
56
Any,
@@ -12,8 +13,8 @@
1213
TypeVar,
1314
Union,
1415
)
16+
import logging
1517

16-
from openai.openai_object import OpenAIObject
1718
from vocode.streaming.models.actions import FunctionCall, FunctionFragment
1819
from vocode.streaming.models.events import Sender
1920
from vocode.streaming.models.transcript import (
@@ -31,6 +32,8 @@ async def collate_response_async(
3132
gen: AsyncIterable[Union[str, FunctionFragment]],
3233
sentence_endings: List[str] = SENTENCE_ENDINGS,
3334
get_functions: Literal[True, False] = False,
35+
logger: Optional[logging.Logger] = None,
36+
start_token_processing: Optional[float] = time.time()
3437
) -> AsyncGenerator[Union[str, FunctionCall], None]:
3538
sentence_endings_pattern = "|".join(map(re.escape, sentence_endings))
3639
list_item_ending_pattern = r"\n"
@@ -43,6 +46,10 @@ async def collate_response_async(
4346
continue
4447
if isinstance(token, str):
4548
if prev_ends_with_money and token.startswith(" "):
49+
if logger:
50+
logger.debug("Took %s to generate [%s]",
51+
time.time() - start_token_processing,
52+
buffer.strip())
4653
yield buffer.strip()
4754
buffer = ""
4855

@@ -58,6 +65,10 @@ async def collate_response_async(
5865
if not ends_with_money:
5966
to_return = buffer.strip()
6067
if to_return:
68+
if logger:
69+
logger.debug("Took %s to generate [%s]",
70+
time.time() - start_token_processing,
71+
to_return)
6172
yield to_return
6273
buffer = ""
6374
prev_ends_with_money = ends_with_money
@@ -66,35 +77,47 @@ async def collate_response_async(
6677
function_args_buffer += token.arguments
6778
to_return = buffer.strip()
6879
if to_return:
80+
if logger:
81+
logger.debug("Took %s to generate [%s]",
82+
time.time() - start_token_processing,
83+
to_return)
6984
yield to_return
7085
if function_name_buffer and get_functions:
7186
yield FunctionCall(name=function_name_buffer, arguments=function_args_buffer)
7287

7388

7489
async def openai_get_tokens(gen) -> AsyncGenerator[Union[str, FunctionFragment], None]:
7590
async for event in gen:
76-
choices = event.get("choices", [])
91+
choices = event.choices or []
7792
if len(choices) == 0:
78-
continue
93+
break
7994
choice = choices[0]
8095
if choice.finish_reason:
8196
break
82-
delta = choice.get("delta", {})
83-
if "text" in delta and delta["text"] is not None:
84-
token = delta["text"]
97+
delta = choice.delta or {}
98+
if hasattr(delta, "text") and delta.text:
99+
token = delta.text
85100
yield token
86-
if "content" in delta and delta["content"] is not None:
87-
token = delta["content"]
101+
if hasattr(delta, "content") and delta.content:
102+
token = delta.content
88103
yield token
89-
elif "function_call" in delta and delta["function_call"] is not None:
90-
yield FunctionFragment(
91-
name=delta["function_call"]["name"]
92-
if "name" in delta["function_call"]
93-
else "",
94-
arguments=delta["function_call"]["arguments"]
95-
if "arguments" in delta["function_call"]
96-
else "",
97-
)
104+
105+
elif hasattr(delta, "tool_calls") and delta.tool_calls:
106+
for tool_call in delta.tool_calls:
107+
if tool_call.function is not None:
108+
function = tool_call.function
109+
yield FunctionFragment(
110+
name =(
111+
function.name
112+
if hasattr(function, "name") and function.name
113+
else ""
114+
),
115+
arguments=(
116+
function.arguments
117+
if hasattr(function, "arguments") and function.arguments
118+
else ""
119+
)
120+
)
98121

99122

100123
def find_last_punctuation(buffer: str) -> Optional[int]:
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import logging
2+
import torch
3+
from importlib import resources as impresources
4+
5+
6+
class SileroVAD:
7+
INT16_NORM_CONST = 32768.0
8+
9+
def __init__(self, sample_rate: int, window_size: int, threshold: float = 0.5):
10+
# Silero VAD is optimized for performance on single CPU thread
11+
torch.set_num_threads(1)
12+
13+
self.logger = logging.getLogger(__name__)
14+
self.model = self._load_model(use_onnx=False)
15+
self.sample_rate = sample_rate
16+
self.threshold = threshold
17+
self.window_size = window_size
18+
19+
def _load_model(self, use_onnx: bool = False) -> torch.nn.Module:
20+
try:
21+
model, _ = torch.hub.load(
22+
repo_or_dir='silero-vad',
23+
model='silero_vad',
24+
source='local',
25+
onnx=use_onnx
26+
)
27+
except FileNotFoundError:
28+
self.logger.warning("Could not find local VAD model, downloading from GitHub!")
29+
model, _ = torch.hub.load(
30+
repo_or_dir='snakers4/silero-vad',
31+
model='silero_vad',
32+
source='github',
33+
onnx=use_onnx
34+
)
35+
return model
36+
37+
def process_chunk(self, chunk: bytes) -> bool:
38+
if len(chunk) != self.window_size:
39+
raise ValueError(f"Chunk size must be {self.window_size} bytes")
40+
chunk_array = torch.frombuffer(chunk, dtype=torch.int16).to(torch.float32) / self.INT16_NORM_CONST
41+
speech_prob = self.model(chunk_array, self.sample_rate).item()
42+
if speech_prob > self.threshold:
43+
return True
44+
return False
45+
46+
def reset_states(self) -> None:
47+
self.model.reset_states()

0 commit comments

Comments
 (0)